jingwwu committed on
Commit 295118d · 1 Parent(s): 16c6dad

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/comparision.png filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,92 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ pipeline_tag: text-to-image
+ library_name: transformers
+ ---
+
+ ## NextStep-1.1
+
+ [Homepage](https://stepfun.ai/research/en/nextstep-1)
+ | [GitHub](https://github.com/stepfun-ai/NextStep-1)
+ | [Paper](https://arxiv.org/abs/2508.10711)
+
+ We introduce **NextStep-1.1**, a new model that represents a significant leap forward in the NextStep series. This version effectively resolves the visualization failures seen in **NextStep-1** and substantially elevates image quality through extended training and a Flow-based Reinforcement Learning (RL) post-training paradigm.
+
+ <div align='center'>
+ <img src="assets/comparision.png" class="interpolation-image" alt="arch." width="100%" />
+ </div>
+
+ ## What's New in 1.1?
+
+ NextStep-1.1 is not just a fine-tune; it is a re-engineered version focused on stability and high-fidelity output. Key improvements include:
+
+ - RL-Enhanced Visual Fidelity: RL post-training markedly improves image texture and substantially reduces visual artifacts, yielding cleaner, more professional outputs.
+
+ - Technical Stability: Resolves the numerical instability inherent in RL training of autoregressive flow-based models.
+
+ ## Environment Setup
+
+ To avoid potential errors when loading and running the model, we recommend the following setup:
+
+ ```shell
+ conda create -n nextstep python=3.11 -y
+ conda activate nextstep
+
+ pip install uv # optional
+
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/stepfun-ai/NextStep-1.1-Pretrain && cd NextStep-1.1-Pretrain
+ uv pip install -r requirements.txt
+
+ hf download stepfun-ai/NextStep-1.1-Pretrain "vae/checkpoint.pt" --local-dir ./
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ from models.gen_pipeline import NextStepPipeline
+
+ HF_HUB = "stepfun-ai/NextStep-1.1-Pretrain"
+
+ # load model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=True, trust_remote_code=True)
+ model = AutoModel.from_pretrained(HF_HUB, local_files_only=True, trust_remote_code=True)
+ pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device="cuda", dtype=torch.bfloat16)
+
+ # set prompts
+ positive_prompt = ""
+ negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
+ example_prompt = "A REALISTIC PHOTOGRAPH OF A WALL WITH \"TOWARD AUTOREGRESSIVE IMAGE GENERATION WITH CONTINUOUS TOKENS AT SCALE\" PROMINENTLY DISPLAYED"
+
+ # generate image from text
+ IMG_SIZE = 512
+ image = pipeline.generate_image(
+     example_prompt,
+     hw=(IMG_SIZE, IMG_SIZE),
+     num_images_per_caption=1,
+     positive_prompt=positive_prompt,
+     negative_prompt=negative_prompt,
+     cfg=7.5,
+     cfg_img=1.0,
+     cfg_schedule="constant",
+     use_norm=False,
+     num_sampling_steps=28,
+     timesteps_shift=1.0,
+     seed=3407,
+ )[0]
+ image.save("./assets/output.jpg")
+ ```
+
+ ## Citation
+
+ If you find NextStep useful for your research and applications, please consider starring this repository and citing:
+
+ ```bibtex
+ @article{nextstepteam2025nextstep1,
+   title={NextStep-1: Toward Autoregressive Image Generation with Continuous Tokens at Scale},
+   author={NextStep Team and Chunrui Han and Guopeng Li and Jingwei Wu and Quan Sun and Yan Cai and Yuang Peng and Zheng Ge and Deyu Zhou and Haomiao Tang and Hongyu Zhou and Kenkun Liu and Ailin Huang and Bin Wang and Changxin Miao and Deshan Sun and En Yu and Fukun Yin and Gang Yu and Hao Nie and Haoran Lv and Hanpeng Hu and Jia Wang and Jian Zhou and Jianjian Sun and Kaijun Tan and Kang An and Kangheng Lin and Liang Zhao and Mei Chen and Peng Xing and Rui Wang and Shiyu Liu and Shutao Xia and Tianhao You and Wei Ji and Xianfang Zeng and Xin Han and Xuelin Zhang and Yana Wei and Yanming Xu and Yimin Jiang and Yingming Wang and Yu Zhou and Yucheng Han and Ziyang Meng and Binxing Jiao and Daxin Jiang and Xiangyu Zhang and Yibo Zhu},
+   journal={arXiv preprint arXiv:2508.10711},
+   year={2025}
+ }
+ ```
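
As a hedged follow-up to the Usage block above, a small sweep over seeds and guidance scales might look like the sketch below. It reuses the `pipeline`, `example_prompt`, `positive_prompt`, and `negative_prompt` defined in that block, and the specific seed and cfg values are illustrative assumptions, not recommended settings.

```python
# Hedged sketch: sweep a few seeds and guidance scales with the generate_image
# call shown in the Usage block above. `pipeline`, `example_prompt`,
# `positive_prompt`, and `negative_prompt` are assumed to be defined as in that
# block; the seed and cfg values below are illustrative only.
import itertools

IMG_SIZE = 512
seeds = [3407, 42, 2024]
cfg_scales = [5.0, 7.5]

for seed, cfg in itertools.product(seeds, cfg_scales):
    image = pipeline.generate_image(
        example_prompt,
        hw=(IMG_SIZE, IMG_SIZE),
        num_images_per_caption=1,
        positive_prompt=positive_prompt,
        negative_prompt=negative_prompt,
        cfg=cfg,
        cfg_img=1.0,
        cfg_schedule="constant",
        use_norm=False,
        num_sampling_steps=28,
        timesteps_shift=1.0,
        seed=seed,
    )[0]
    image.save(f"./assets/output_seed{seed}_cfg{cfg}.jpg")
```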
added_tokens.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "</tool_call>": 151658,
+   "<tool_call>": 151657,
+   "<|begin_of_image|>": 151667,
+   "<|begin_of_prompt_refinement|>": 151670,
+   "<|begin_of_thinking|>": 151672,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|end_of_image|>": 151668,
+   "<|end_of_prompt_refinement|>": 151671,
+   "<|end_of_thinking|>": 151673,
+   "<|beginoftext|>": 151674,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_area|>": 151666,
+   "<|image_pad|>": 151655,
+   "<|image_placeholder|>": 151669,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652,
+   "[PAD]": 151665
+ }
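
added_tokens.json registers the image-, thinking-, and prompt-refinement-related special tokens (IDs 151643-151674) on top of the base vocabulary. A minimal sanity-check sketch, assuming the tokenizer loads with trust_remote_code as in the README, confirming a few of the IDs listed above:

```python
# Sketch: check that a few of the special tokens added in added_tokens.json
# resolve to the listed IDs through the released tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("stepfun-ai/NextStep-1.1-Pretrain", trust_remote_code=True)

expected = {
    "<|begin_of_image|>": 151667,
    "<|end_of_image|>": 151668,
    "<|image_placeholder|>": 151669,
    "[PAD]": 151665,
}
for token, token_id in expected.items():
    assert tok.convert_tokens_to_ids(token) == token_id, (token, token_id)
print("special-token IDs match added_tokens.json")
```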
assets/comparision.png ADDED

Git LFS Details

  • SHA256: c03496181fccd0cb84da7554c305cdeaf7f7f4e4af41b73fe97f36b7626504dd
  • Pointer size: 133 Bytes
  • Size of remote file: 16.5 MB
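
Because assets/comparision.png is tracked by Git LFS (see the .gitattributes change above), a clone with GIT_LFS_SKIP_SMUDGE=1 leaves only a pointer file in place of the image. A small sketch of fetching just this asset with huggingface_hub's hf_hub_download; the repo id is taken from the README, everything else is the standard Hub API:

```python
# Sketch: download only the LFS-backed comparison figure (about 16.5 MB per the
# LFS details above) without pulling the multi-gigabyte model shards.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="stepfun-ai/NextStep-1.1-Pretrain",
    filename="assets/comparision.png",  # filename as committed in this repo
)
print(path)
```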
config.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "_attn_implementation_autoset": true,
+   "architectures": [
+     "LlamaForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "models/config.NextStepConfig",
+     "AutoModel": "models/nextstep_model.NextStep"
+   },
+   "attention_bias": true,
+   "attention_dropout": 0.0,
+   "base_image_grid_size": 64,
+   "boi": 151667,
+   "bos_token_id": 151643,
+   "create_kwargs": {
+     "snr_type": "lognorm"
+   },
+   "eoi": 151668,
+   "eos_token_id": 151643,
+   "genloss_batch_mul": 4,
+   "genloss_depth": 12,
+   "genloss_net_arch": "mlp",
+   "genloss_num_sampling_steps": "100",
+   "genloss_type": "transport",
+   "genloss_width": 1536,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "hidden_size": 5120,
+   "image_decoder_arch": "Trans_E",
+   "image_encoder_name": null,
+   "image_feature_layer": -2,
+   "image_loss_weight": 1.0,
+   "image_placeholder_id": 151669,
+   "image_size": 64,
+   "initializer_range": 0.02,
+   "intermediate_size": 13824,
+   "lm_loss_weight": 0.01,
+   "max_position_embeddings": 131072,
+   "max_window_layers": 48,
+   "mlp_bias": false,
+   "model_type": "nextstep",
+   "noise_strength": 0.0,
+   "num_attention_heads": 40,
+   "num_channels": 16,
+   "num_hidden_layers": 48,
+   "num_key_value_heads": 8,
+   "o_attention_bias": false,
+   "pad_token_id_added": 151665,
+   "patch_size": 2,
+   "pretraining_tp": 1,
+   "rms_norm_eps": 1e-05,
+   "rope_scaling": null,
+   "rope_theta": 1000000.0,
+   "sliding_window": 131072,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.55.0",
+   "use_2d_rope": false,
+   "use_cache": true,
+   "use_gen_pos_embed": false,
+   "use_mlp_before_lm_head": false,
+   "use_sliding_window": false,
+   "use_token_length_weight": false,
+   "vae_name_or_path": "vae/",
+   "vocab_size": 152064
+ }
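
Since model_type is "nextstep" and auto_map points at classes shipped inside this repository, the configuration only resolves with trust_remote_code=True. Below is a minimal sketch of inspecting a few fields from the config above without downloading the weight shards; it assumes the custom config class exposes these JSON entries as attributes, as PretrainedConfig subclasses normally do.

```python
# Sketch: load the configuration only (no weights) and read fields that appear
# in config.json above. trust_remote_code=True is required because the
# "nextstep" model_type is defined by code in the repository (see auto_map).
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("stepfun-ai/NextStep-1.1-Pretrain", trust_remote_code=True)

print(cfg.model_type)          # "nextstep"
print(cfg.hidden_size)         # 5120
print(cfg.num_hidden_layers)   # 48
print(cfg.num_attention_heads, cfg.num_key_value_heads)  # 40, 8
print(cfg.image_size, cfg.patch_size, cfg.num_channels)  # 64, 2, 16
```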
model.safetensors.index.json ADDED
@@ -0,0 +1,698 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 29907628160
4
+ },
5
+ "weight_map": {
6
+ "embed_tokens.weight": "pytorch-model-00004-of-00004.safetensors",
7
+ "image_head.net.cond_embed.bias": "pytorch-model-00003-of-00004.safetensors",
8
+ "image_head.net.cond_embed.weight": "pytorch-model-00003-of-00004.safetensors",
9
+ "image_head.net.final_layer.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
10
+ "image_head.net.final_layer.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
11
+ "image_head.net.final_layer.linear.bias": "pytorch-model-00003-of-00004.safetensors",
12
+ "image_head.net.final_layer.linear.weight": "pytorch-model-00003-of-00004.safetensors",
13
+ "image_head.net.input_proj.bias": "pytorch-model-00003-of-00004.safetensors",
14
+ "image_head.net.input_proj.weight": "pytorch-model-00003-of-00004.safetensors",
15
+ "image_head.net.res_blocks.0.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
16
+ "image_head.net.res_blocks.0.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
17
+ "image_head.net.res_blocks.0.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
18
+ "image_head.net.res_blocks.0.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
19
+ "image_head.net.res_blocks.0.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
20
+ "image_head.net.res_blocks.0.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
21
+ "image_head.net.res_blocks.0.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
22
+ "image_head.net.res_blocks.0.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
23
+ "image_head.net.res_blocks.1.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
24
+ "image_head.net.res_blocks.1.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
25
+ "image_head.net.res_blocks.1.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
26
+ "image_head.net.res_blocks.1.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
27
+ "image_head.net.res_blocks.1.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
28
+ "image_head.net.res_blocks.1.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
29
+ "image_head.net.res_blocks.1.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
30
+ "image_head.net.res_blocks.1.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
31
+ "image_head.net.res_blocks.10.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
32
+ "image_head.net.res_blocks.10.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
33
+ "image_head.net.res_blocks.10.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
34
+ "image_head.net.res_blocks.10.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
35
+ "image_head.net.res_blocks.10.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
36
+ "image_head.net.res_blocks.10.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
37
+ "image_head.net.res_blocks.10.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
38
+ "image_head.net.res_blocks.10.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
39
+ "image_head.net.res_blocks.11.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
40
+ "image_head.net.res_blocks.11.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
41
+ "image_head.net.res_blocks.11.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
42
+ "image_head.net.res_blocks.11.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
43
+ "image_head.net.res_blocks.11.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
44
+ "image_head.net.res_blocks.11.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
45
+ "image_head.net.res_blocks.11.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
46
+ "image_head.net.res_blocks.11.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
47
+ "image_head.net.res_blocks.2.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
48
+ "image_head.net.res_blocks.2.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
49
+ "image_head.net.res_blocks.2.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
50
+ "image_head.net.res_blocks.2.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
51
+ "image_head.net.res_blocks.2.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
52
+ "image_head.net.res_blocks.2.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
53
+ "image_head.net.res_blocks.2.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
54
+ "image_head.net.res_blocks.2.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
55
+ "image_head.net.res_blocks.3.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
56
+ "image_head.net.res_blocks.3.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
57
+ "image_head.net.res_blocks.3.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
58
+ "image_head.net.res_blocks.3.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
59
+ "image_head.net.res_blocks.3.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
60
+ "image_head.net.res_blocks.3.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
61
+ "image_head.net.res_blocks.3.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
62
+ "image_head.net.res_blocks.3.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
63
+ "image_head.net.res_blocks.4.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
64
+ "image_head.net.res_blocks.4.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
65
+ "image_head.net.res_blocks.4.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
66
+ "image_head.net.res_blocks.4.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
67
+ "image_head.net.res_blocks.4.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
68
+ "image_head.net.res_blocks.4.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
69
+ "image_head.net.res_blocks.4.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
70
+ "image_head.net.res_blocks.4.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
71
+ "image_head.net.res_blocks.5.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
72
+ "image_head.net.res_blocks.5.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
73
+ "image_head.net.res_blocks.5.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
74
+ "image_head.net.res_blocks.5.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
75
+ "image_head.net.res_blocks.5.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
76
+ "image_head.net.res_blocks.5.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
77
+ "image_head.net.res_blocks.5.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
78
+ "image_head.net.res_blocks.5.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
79
+ "image_head.net.res_blocks.6.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
80
+ "image_head.net.res_blocks.6.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
81
+ "image_head.net.res_blocks.6.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
82
+ "image_head.net.res_blocks.6.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
83
+ "image_head.net.res_blocks.6.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
84
+ "image_head.net.res_blocks.6.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
85
+ "image_head.net.res_blocks.6.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
86
+ "image_head.net.res_blocks.6.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
87
+ "image_head.net.res_blocks.7.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
88
+ "image_head.net.res_blocks.7.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
89
+ "image_head.net.res_blocks.7.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
90
+ "image_head.net.res_blocks.7.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
91
+ "image_head.net.res_blocks.7.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
92
+ "image_head.net.res_blocks.7.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
93
+ "image_head.net.res_blocks.7.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
94
+ "image_head.net.res_blocks.7.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
95
+ "image_head.net.res_blocks.8.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
96
+ "image_head.net.res_blocks.8.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
97
+ "image_head.net.res_blocks.8.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
98
+ "image_head.net.res_blocks.8.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
99
+ "image_head.net.res_blocks.8.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
100
+ "image_head.net.res_blocks.8.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
101
+ "image_head.net.res_blocks.8.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
102
+ "image_head.net.res_blocks.8.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
103
+ "image_head.net.res_blocks.9.adaLN_modulation.1.bias": "pytorch-model-00003-of-00004.safetensors",
104
+ "image_head.net.res_blocks.9.adaLN_modulation.1.weight": "pytorch-model-00003-of-00004.safetensors",
105
+ "image_head.net.res_blocks.9.in_ln.bias": "pytorch-model-00003-of-00004.safetensors",
106
+ "image_head.net.res_blocks.9.in_ln.weight": "pytorch-model-00003-of-00004.safetensors",
107
+ "image_head.net.res_blocks.9.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
108
+ "image_head.net.res_blocks.9.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
109
+ "image_head.net.res_blocks.9.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
110
+ "image_head.net.res_blocks.9.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
111
+ "image_head.net.time_embed.mlp.0.bias": "pytorch-model-00003-of-00004.safetensors",
112
+ "image_head.net.time_embed.mlp.0.weight": "pytorch-model-00003-of-00004.safetensors",
113
+ "image_head.net.time_embed.mlp.2.bias": "pytorch-model-00003-of-00004.safetensors",
114
+ "image_head.net.time_embed.mlp.2.weight": "pytorch-model-00003-of-00004.safetensors",
115
+ "image_in_projector.bias": "pytorch-model-00003-of-00004.safetensors",
116
+ "image_in_projector.weight": "pytorch-model-00003-of-00004.safetensors",
117
+ "image_out_projector.bias": "pytorch-model-00003-of-00004.safetensors",
118
+ "image_out_projector.weight": "pytorch-model-00003-of-00004.safetensors",
119
+ "layers.0.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
120
+ "layers.0.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
121
+ "layers.0.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
122
+ "layers.0.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
123
+ "layers.0.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
124
+ "layers.0.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
125
+ "layers.0.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
126
+ "layers.0.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
127
+ "layers.0.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
128
+ "layers.0.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
129
+ "layers.0.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
130
+ "layers.0.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
131
+ "layers.1.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
132
+ "layers.1.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
133
+ "layers.1.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
134
+ "layers.1.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
135
+ "layers.1.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
136
+ "layers.1.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
137
+ "layers.1.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
138
+ "layers.1.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
139
+ "layers.1.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
140
+ "layers.1.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
141
+ "layers.1.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
142
+ "layers.1.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
143
+ "layers.10.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
144
+ "layers.10.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
145
+ "layers.10.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
146
+ "layers.10.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
147
+ "layers.10.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
148
+ "layers.10.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
149
+ "layers.10.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
150
+ "layers.10.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
151
+ "layers.10.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
152
+ "layers.10.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
153
+ "layers.10.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
154
+ "layers.10.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
155
+ "layers.11.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
156
+ "layers.11.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
157
+ "layers.11.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
158
+ "layers.11.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
159
+ "layers.11.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
160
+ "layers.11.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
161
+ "layers.11.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
162
+ "layers.11.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
163
+ "layers.11.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
164
+ "layers.11.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
165
+ "layers.11.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
166
+ "layers.11.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
167
+ "layers.12.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
168
+ "layers.12.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
169
+ "layers.12.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
170
+ "layers.12.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
171
+ "layers.12.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
172
+ "layers.12.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
173
+ "layers.12.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
174
+ "layers.12.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
175
+ "layers.12.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
176
+ "layers.12.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
177
+ "layers.12.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
178
+ "layers.12.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
179
+ "layers.13.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
180
+ "layers.13.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
181
+ "layers.13.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
182
+ "layers.13.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
183
+ "layers.13.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
184
+ "layers.13.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
185
+ "layers.13.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
186
+ "layers.13.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
187
+ "layers.13.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
188
+ "layers.13.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
189
+ "layers.13.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
190
+ "layers.13.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
191
+ "layers.14.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
192
+ "layers.14.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
193
+ "layers.14.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
194
+ "layers.14.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
195
+ "layers.14.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
196
+ "layers.14.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
197
+ "layers.14.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
198
+ "layers.14.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
199
+ "layers.14.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
200
+ "layers.14.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
201
+ "layers.14.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
202
+ "layers.14.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
203
+ "layers.15.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
204
+ "layers.15.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
205
+ "layers.15.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
206
+ "layers.15.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
207
+ "layers.15.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
208
+ "layers.15.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
209
+ "layers.15.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
210
+ "layers.15.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
211
+ "layers.15.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
212
+ "layers.15.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
213
+ "layers.15.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
214
+ "layers.15.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
215
+ "layers.16.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
216
+ "layers.16.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
217
+ "layers.16.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
218
+ "layers.16.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
219
+ "layers.16.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
220
+ "layers.16.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
221
+ "layers.16.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
222
+ "layers.16.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
223
+ "layers.16.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
224
+ "layers.16.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
225
+ "layers.16.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
226
+ "layers.16.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
227
+ "layers.17.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
228
+ "layers.17.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
229
+ "layers.17.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
230
+ "layers.17.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
231
+ "layers.17.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
232
+ "layers.17.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
233
+ "layers.17.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
234
+ "layers.17.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
235
+ "layers.17.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
236
+ "layers.17.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
237
+ "layers.17.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
238
+ "layers.17.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
239
+ "layers.18.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
240
+ "layers.18.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
241
+ "layers.18.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
242
+ "layers.18.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
243
+ "layers.18.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
244
+ "layers.18.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
245
+ "layers.18.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
246
+ "layers.18.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
247
+ "layers.18.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
248
+ "layers.18.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
249
+ "layers.18.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
250
+ "layers.18.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
251
+ "layers.19.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
252
+ "layers.19.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
253
+ "layers.19.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
254
+ "layers.19.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
255
+ "layers.19.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
256
+ "layers.19.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
257
+ "layers.19.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
258
+ "layers.19.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
259
+ "layers.19.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
260
+ "layers.19.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
261
+ "layers.19.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
262
+ "layers.19.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
263
+ "layers.2.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
264
+ "layers.2.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
265
+ "layers.2.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
266
+ "layers.2.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
267
+ "layers.2.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
268
+ "layers.2.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
269
+ "layers.2.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
270
+ "layers.2.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
271
+ "layers.2.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
272
+ "layers.2.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
273
+ "layers.2.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
274
+ "layers.2.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
275
+ "layers.20.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
276
+ "layers.20.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
277
+ "layers.20.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
278
+ "layers.20.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
279
+ "layers.20.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
280
+ "layers.20.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
281
+ "layers.20.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
282
+ "layers.20.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
283
+ "layers.20.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
284
+ "layers.20.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
285
+ "layers.20.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
286
+ "layers.20.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
287
+ "layers.21.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
288
+ "layers.21.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
289
+ "layers.21.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
290
+ "layers.21.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
291
+ "layers.21.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
292
+ "layers.21.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
293
+ "layers.21.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
294
+ "layers.21.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
295
+ "layers.21.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
296
+ "layers.21.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
297
+ "layers.21.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
298
+ "layers.21.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
299
+ "layers.22.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
300
+ "layers.22.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
301
+ "layers.22.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
302
+ "layers.22.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
303
+ "layers.22.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
304
+ "layers.22.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
305
+ "layers.22.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
306
+ "layers.22.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
307
+ "layers.22.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
308
+ "layers.22.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
309
+ "layers.22.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
310
+ "layers.22.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
311
+ "layers.23.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
312
+ "layers.23.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
313
+ "layers.23.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
314
+ "layers.23.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
315
+ "layers.23.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
316
+ "layers.23.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
317
+ "layers.23.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
318
+ "layers.23.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
319
+ "layers.23.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
320
+ "layers.23.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
321
+ "layers.23.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
322
+ "layers.23.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
323
+ "layers.24.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
324
+ "layers.24.mlp.down_proj.weight": "pytorch-model-00001-of-00004.safetensors",
325
+ "layers.24.mlp.gate_proj.weight": "pytorch-model-00001-of-00004.safetensors",
326
+ "layers.24.mlp.up_proj.weight": "pytorch-model-00001-of-00004.safetensors",
327
+ "layers.24.post_attention_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
328
+ "layers.24.self_attn.k_proj.bias": "pytorch-model-00001-of-00004.safetensors",
329
+ "layers.24.self_attn.k_proj.weight": "pytorch-model-00001-of-00004.safetensors",
330
+ "layers.24.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
331
+ "layers.24.self_attn.q_proj.bias": "pytorch-model-00001-of-00004.safetensors",
332
+ "layers.24.self_attn.q_proj.weight": "pytorch-model-00001-of-00004.safetensors",
333
+ "layers.24.self_attn.v_proj.bias": "pytorch-model-00001-of-00004.safetensors",
334
+ "layers.24.self_attn.v_proj.weight": "pytorch-model-00001-of-00004.safetensors",
335
+ "layers.25.input_layernorm.weight": "pytorch-model-00001-of-00004.safetensors",
336
+ "layers.25.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
337
+ "layers.25.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
338
+ "layers.25.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
339
+ "layers.25.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
340
+ "layers.25.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
341
+ "layers.25.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
342
+ "layers.25.self_attn.o_proj.weight": "pytorch-model-00001-of-00004.safetensors",
343
+ "layers.25.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
344
+ "layers.25.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
345
+ "layers.25.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
346
+ "layers.25.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
347
+ "layers.26.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
348
+ "layers.26.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
349
+ "layers.26.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
350
+ "layers.26.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
351
+ "layers.26.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
352
+ "layers.26.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
353
+ "layers.26.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
354
+ "layers.26.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
355
+ "layers.26.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
356
+ "layers.26.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
357
+ "layers.26.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
358
+ "layers.26.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
359
+ "layers.27.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
360
+ "layers.27.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
361
+ "layers.27.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
362
+ "layers.27.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
363
+ "layers.27.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
364
+ "layers.27.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
365
+ "layers.27.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
366
+ "layers.27.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
367
+ "layers.27.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
368
+ "layers.27.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
369
+ "layers.27.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
370
+ "layers.27.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
371
+ "layers.28.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
372
+ "layers.28.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
373
+ "layers.28.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
374
+ "layers.28.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
375
+ "layers.28.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
376
+ "layers.28.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
377
+ "layers.28.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
378
+ "layers.28.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
379
+ "layers.28.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
380
+ "layers.28.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
381
+ "layers.28.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
382
+ "layers.28.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
383
+ "layers.29.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
384
+ "layers.29.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
385
+ "layers.29.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
386
+ "layers.29.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
387
+ "layers.29.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
388
+ "layers.29.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
389
+ "layers.29.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
390
+ "layers.29.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
391
+ "layers.29.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
392
+ "layers.29.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
393
+ "layers.29.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
394
+ "layers.29.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
395
+ "layers.3.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
396
+ "layers.3.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
397
+ "layers.3.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
398
+ "layers.3.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
399
+ "layers.3.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
400
+ "layers.3.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
401
+ "layers.3.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
402
+ "layers.3.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
403
+ "layers.3.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
404
+ "layers.3.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
405
+ "layers.3.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
406
+ "layers.3.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
407
+ "layers.30.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
408
+ "layers.30.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
409
+ "layers.30.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
410
+ "layers.30.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
411
+ "layers.30.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
412
+ "layers.30.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
413
+ "layers.30.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
414
+ "layers.30.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
415
+ "layers.30.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
416
+ "layers.30.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
417
+ "layers.30.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
418
+ "layers.30.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
419
+ "layers.31.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
420
+ "layers.31.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
421
+ "layers.31.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
422
+ "layers.31.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
423
+ "layers.31.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
424
+ "layers.31.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
425
+ "layers.31.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
426
+ "layers.31.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
427
+ "layers.31.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
428
+ "layers.31.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
429
+ "layers.31.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
430
+ "layers.31.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
431
+ "layers.32.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
432
+ "layers.32.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
433
+ "layers.32.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
434
+ "layers.32.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
435
+ "layers.32.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
436
+ "layers.32.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
437
+ "layers.32.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
438
+ "layers.32.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
439
+ "layers.32.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
440
+ "layers.32.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
441
+ "layers.32.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
442
+ "layers.32.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
443
+ "layers.33.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
444
+ "layers.33.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
445
+ "layers.33.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
446
+ "layers.33.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
447
+ "layers.33.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
448
+ "layers.33.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
449
+ "layers.33.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
450
+ "layers.33.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
451
+ "layers.33.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
452
+ "layers.33.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
453
+ "layers.33.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
454
+ "layers.33.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
455
+ "layers.34.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
456
+ "layers.34.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
457
+ "layers.34.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
458
+ "layers.34.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
459
+ "layers.34.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
460
+ "layers.34.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
461
+ "layers.34.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
462
+ "layers.34.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
463
+ "layers.34.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
464
+ "layers.34.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
465
+ "layers.34.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
466
+ "layers.34.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
467
+ "layers.35.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
468
+ "layers.35.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
469
+ "layers.35.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
470
+ "layers.35.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
471
+ "layers.35.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
472
+ "layers.35.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
473
+ "layers.35.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
474
+ "layers.35.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
475
+ "layers.35.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
476
+ "layers.35.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
477
+ "layers.35.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
478
+ "layers.35.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
479
+ "layers.36.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
480
+ "layers.36.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
481
+ "layers.36.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
482
+ "layers.36.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
483
+ "layers.36.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
484
+ "layers.36.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
485
+ "layers.36.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
486
+ "layers.36.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
487
+ "layers.36.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
488
+ "layers.36.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
489
+ "layers.36.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
490
+ "layers.36.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
491
+ "layers.37.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
492
+ "layers.37.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
493
+ "layers.37.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
494
+ "layers.37.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
495
+ "layers.37.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
496
+ "layers.37.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
497
+ "layers.37.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
498
+ "layers.37.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
499
+ "layers.37.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
500
+ "layers.37.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
501
+ "layers.37.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
502
+ "layers.37.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
503
+ "layers.38.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
504
+ "layers.38.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
505
+ "layers.38.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
506
+ "layers.38.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
507
+ "layers.38.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
508
+ "layers.38.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
509
+ "layers.38.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
510
+ "layers.38.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
511
+ "layers.38.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
512
+ "layers.38.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
513
+ "layers.38.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
514
+ "layers.38.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
515
+ "layers.39.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
516
+ "layers.39.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
517
+ "layers.39.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
518
+ "layers.39.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
519
+ "layers.39.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
520
+ "layers.39.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
521
+ "layers.39.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
522
+ "layers.39.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
523
+ "layers.39.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
524
+ "layers.39.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
525
+ "layers.39.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
526
+ "layers.39.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
527
+ "layers.4.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
528
+ "layers.4.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
529
+ "layers.4.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
530
+ "layers.4.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
531
+ "layers.4.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
532
+ "layers.4.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
533
+ "layers.4.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
534
+ "layers.4.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
535
+ "layers.4.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
536
+ "layers.4.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
537
+ "layers.4.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
538
+ "layers.4.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
539
+ "layers.40.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
540
+ "layers.40.mlp.down_proj.weight": "pytorch-model-00002-of-00004.safetensors",
541
+ "layers.40.mlp.gate_proj.weight": "pytorch-model-00002-of-00004.safetensors",
542
+ "layers.40.mlp.up_proj.weight": "pytorch-model-00002-of-00004.safetensors",
543
+ "layers.40.post_attention_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
544
+ "layers.40.self_attn.k_proj.bias": "pytorch-model-00002-of-00004.safetensors",
545
+ "layers.40.self_attn.k_proj.weight": "pytorch-model-00002-of-00004.safetensors",
546
+ "layers.40.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
547
+ "layers.40.self_attn.q_proj.bias": "pytorch-model-00002-of-00004.safetensors",
548
+ "layers.40.self_attn.q_proj.weight": "pytorch-model-00002-of-00004.safetensors",
549
+ "layers.40.self_attn.v_proj.bias": "pytorch-model-00002-of-00004.safetensors",
550
+ "layers.40.self_attn.v_proj.weight": "pytorch-model-00002-of-00004.safetensors",
551
+ "layers.41.input_layernorm.weight": "pytorch-model-00002-of-00004.safetensors",
552
+ "layers.41.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
553
+ "layers.41.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
554
+ "layers.41.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
555
+ "layers.41.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
556
+ "layers.41.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
557
+ "layers.41.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
558
+ "layers.41.self_attn.o_proj.weight": "pytorch-model-00002-of-00004.safetensors",
559
+ "layers.41.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
560
+ "layers.41.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
561
+ "layers.41.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
562
+ "layers.41.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
563
+ "layers.42.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
564
+ "layers.42.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
565
+ "layers.42.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
566
+ "layers.42.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
567
+ "layers.42.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
568
+ "layers.42.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
569
+ "layers.42.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
570
+ "layers.42.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
571
+ "layers.42.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
572
+ "layers.42.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
573
+ "layers.42.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
574
+ "layers.42.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
575
+ "layers.43.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
576
+ "layers.43.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
577
+ "layers.43.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
578
+ "layers.43.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
579
+ "layers.43.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
580
+ "layers.43.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
581
+ "layers.43.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
582
+ "layers.43.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
583
+ "layers.43.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
584
+ "layers.43.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
585
+ "layers.43.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
586
+ "layers.43.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
587
+ "layers.44.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
588
+ "layers.44.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
589
+ "layers.44.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
590
+ "layers.44.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
591
+ "layers.44.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
592
+ "layers.44.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
593
+ "layers.44.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
594
+ "layers.44.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
595
+ "layers.44.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
596
+ "layers.44.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
597
+ "layers.44.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
598
+ "layers.44.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
599
+ "layers.45.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
600
+ "layers.45.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
601
+ "layers.45.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
602
+ "layers.45.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
603
+ "layers.45.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
604
+ "layers.45.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
605
+ "layers.45.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
606
+ "layers.45.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
607
+ "layers.45.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
608
+ "layers.45.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
609
+ "layers.45.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
610
+ "layers.45.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
611
+ "layers.46.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
612
+ "layers.46.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
613
+ "layers.46.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
614
+ "layers.46.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
615
+ "layers.46.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
616
+ "layers.46.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
617
+ "layers.46.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
618
+ "layers.46.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
619
+ "layers.46.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
620
+ "layers.46.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
621
+ "layers.46.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
622
+ "layers.46.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
623
+ "layers.47.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
624
+ "layers.47.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
625
+ "layers.47.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
626
+ "layers.47.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
627
+ "layers.47.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
628
+ "layers.47.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
629
+ "layers.47.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
630
+ "layers.47.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
631
+ "layers.47.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
632
+ "layers.47.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
633
+ "layers.47.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
634
+ "layers.47.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
635
+ "layers.5.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
636
+ "layers.5.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
637
+ "layers.5.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
638
+ "layers.5.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
639
+ "layers.5.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
640
+ "layers.5.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
641
+ "layers.5.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
642
+ "layers.5.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
643
+ "layers.5.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
644
+ "layers.5.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
645
+ "layers.5.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
646
+ "layers.5.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
647
+ "layers.6.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
648
+ "layers.6.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
649
+ "layers.6.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
650
+ "layers.6.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
651
+ "layers.6.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
652
+ "layers.6.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
653
+ "layers.6.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
654
+ "layers.6.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
655
+ "layers.6.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
656
+ "layers.6.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
657
+ "layers.6.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
658
+ "layers.6.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
659
+ "layers.7.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
660
+ "layers.7.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
661
+ "layers.7.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
662
+ "layers.7.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
663
+ "layers.7.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
664
+ "layers.7.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
665
+ "layers.7.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
666
+ "layers.7.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
667
+ "layers.7.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
668
+ "layers.7.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
669
+ "layers.7.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
670
+ "layers.7.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
671
+ "layers.8.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
672
+ "layers.8.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
673
+ "layers.8.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
674
+ "layers.8.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
675
+ "layers.8.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
676
+ "layers.8.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
677
+ "layers.8.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
678
+ "layers.8.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
679
+ "layers.8.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
680
+ "layers.8.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
681
+ "layers.8.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
682
+ "layers.8.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
683
+ "layers.9.input_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
684
+ "layers.9.mlp.down_proj.weight": "pytorch-model-00003-of-00004.safetensors",
685
+ "layers.9.mlp.gate_proj.weight": "pytorch-model-00003-of-00004.safetensors",
686
+ "layers.9.mlp.up_proj.weight": "pytorch-model-00003-of-00004.safetensors",
687
+ "layers.9.post_attention_layernorm.weight": "pytorch-model-00003-of-00004.safetensors",
688
+ "layers.9.self_attn.k_proj.bias": "pytorch-model-00003-of-00004.safetensors",
689
+ "layers.9.self_attn.k_proj.weight": "pytorch-model-00003-of-00004.safetensors",
690
+ "layers.9.self_attn.o_proj.weight": "pytorch-model-00003-of-00004.safetensors",
691
+ "layers.9.self_attn.q_proj.bias": "pytorch-model-00003-of-00004.safetensors",
692
+ "layers.9.self_attn.q_proj.weight": "pytorch-model-00003-of-00004.safetensors",
693
+ "layers.9.self_attn.v_proj.bias": "pytorch-model-00003-of-00004.safetensors",
694
+ "layers.9.self_attn.v_proj.weight": "pytorch-model-00003-of-00004.safetensors",
695
+ "lm_head.weight": "pytorch-model-00003-of-00004.safetensors",
696
+ "norm.weight": "pytorch-model-00003-of-00004.safetensors"
697
+ }
698
+ }
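
The weight map above assigns every parameter name to one of the four safetensors shards, following the standard Hugging Face index layout with a `weight_map` entry. As a rough, hedged sketch (not part of this commit), a single tensor can be pulled out of the sharded checkpoint via that index; the index filename below is an assumption inferred from the shard names, since it is not shown in this diff:

```python
import json

from safetensors import safe_open  # requires the `safetensors` package

# Hypothetical index filename, inferred from the shard names in the map above.
INDEX_FILE = "pytorch_model.bin.index.json"

with open(INDEX_FILE) as f:
    weight_map = json.load(f)["weight_map"]

# Look up which shard stores a given parameter, then read only that tensor.
name = "layers.37.self_attn.q_proj.weight"
shard = weight_map[name]  # e.g. "pytorch-model-00002-of-00004.safetensors"

with safe_open(shard, framework="pt", device="cpu") as shard_file:
    tensor = shard_file.get_tensor(name)

print(name, tuple(tensor.shape))
```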
models/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.47 kB)

models/__pycache__/gen_pipeline.cpython-310.pyc ADDED
Binary file (10.6 kB)

models/__pycache__/heads.cpython-310.pyc ADDED
Binary file (10.2 kB)

models/__pycache__/llama_model.cpython-310.pyc ADDED
Binary file (14 kB)

models/__pycache__/nextstep_model.cpython-310.pyc ADDED
Binary file (15.8 kB)

models/config.py ADDED
@@ -0,0 +1,45 @@
1
+ from transformers.models.llama.configuration_llama import LlamaConfig
2
+
3
+ class NextStepConfig(LlamaConfig):
4
+
5
+ model_type = "nextstep"
6
+
7
+ def __init__(
8
+ self,
9
+ vae_name_or_path: str | None = None,
10
+ latent_size: int = 32,
11
+ latent_patch_size: int = 2,
12
+ latent_channels: int = 16,
13
+ boi: int | None = None,
14
+ eoi: int | None = None,
15
+ image_placeholder_id: int | None = None,
16
+ pad_token_id_added: int | None = None,
17
+ lm_loss_weight: float = 0.01,
18
+ im_loss_weight: float = 1.0,
19
+ fm_head_dim: int = 1536,
20
+ fm_head_layers: int = 12,
21
+ fm_head_batch_mul: int = 4,
22
+ o_attention_bias: bool | None = None,
23
+ **kwargs,
24
+ ):
25
+ super().__init__(**kwargs)
26
+
27
+ self.vae_name_or_path = vae_name_or_path
28
+
29
+ self.latent_size = latent_size
30
+ self.latent_patch_size = latent_patch_size
31
+ self.latent_channels = latent_channels
32
+
33
+ self.boi = boi
34
+ self.eoi = eoi
35
+ self.image_placeholder_id = image_placeholder_id
36
+ self.pad_token_id_added = pad_token_id_added
37
+
38
+ self.lm_loss_weight = lm_loss_weight
39
+ self.im_loss_weight = im_loss_weight
40
+
41
+ self.fm_head_dim = fm_head_dim
42
+ self.fm_head_layers = fm_head_layers
43
+ self.fm_head_batch_mul = fm_head_batch_mul
44
+
45
+ self.o_attention_bias = self.attention_bias if o_attention_bias is None else o_attention_bias
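
`NextStepConfig` above extends `LlamaConfig` with VAE/latent bookkeeping, loss weights, and flow-matching-head hyperparameters. A minimal construction sketch follows; the values are illustrative only, the released `config.json` carries the real ones:

```python
from models.config import NextStepConfig

# Values below are illustrative, not the shipped defaults.
cfg = NextStepConfig(
    vae_name_or_path="vae",   # assumed local path to the VAE checkpoint
    latent_size=32,
    latent_patch_size=2,
    latent_channels=16,
    lm_loss_weight=0.01,      # weight of the language-modeling loss
    im_loss_weight=1.0,       # weight of the image (flow-matching) loss
    fm_head_dim=1536,
    fm_head_layers=12,
)
print(cfg.model_type)  # -> "nextstep"
```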
models/gen_pipeline.py ADDED
@@ -0,0 +1,398 @@
1
+ import re
2
+ import copy
3
+ from typing import Literal
4
+
5
+ from PIL import Image
6
+ from tqdm.auto import tqdm
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchvision.transforms as transforms
11
+
12
+ from transformers import AutoTokenizer
13
+ from transformers.cache_utils import Cache, StaticCache
14
+
15
+ from models.nextstep_model import NextStep
16
+ from vae.nextstep_ae import AutoencoderKL
17
+ from utils.image_utils import to_pil
18
+ from utils.model_utils import layer_norm
19
+ from utils.compile_utils import compile_manager
20
+ from utils.misc import set_seed
21
+
22
+ DEFAULT_IMAGE_AREA_TOKEN = "<|image_area|>"
23
+
24
+ def hw2str(h: int, w: int) -> str:
25
+ return f"{h}*{w}"
26
+
27
+
28
+ class NextStepPipeline:
29
+ def __init__(
30
+ self,
31
+ model_name_or_path: str | None = None,
32
+ vae_name_or_path: str | None = None,
33
+ tokenizer: AutoTokenizer | None = None,
34
+ model: nn.Module | None = None,
35
+ vae: AutoencoderKL | None = None,
36
+ ):
37
+ if model is not None:
38
+ self.tokenizer = copy.deepcopy(tokenizer)
39
+ self.tokenizer.padding_side = "left"
40
+ self.model = model
41
+ elif model_name_or_path is not None:
42
+ self.tokenizer = AutoTokenizer.from_pretrained(
43
+ model_name_or_path,
44
+ local_files_only=True,
45
+ padding_side="left",
46
+ use_fast=True,
47
+ )
48
+ self.model: NextStep = NextStep.from_pretrained(model_name_or_path, local_files_only=True)
49
+ else:
50
+ raise ValueError("model or model_name_or_path is required")
51
+
52
+ self.tokenizer.add_eos_token = False
53
+ if vae_name_or_path is None:
54
+ vae_name_or_path = getattr(self.model.config, "vae_name_or_path", None)
55
+ if vae is not None:
56
+ self.vae = vae
57
+ elif vae_name_or_path is not None:
58
+ self.vae = AutoencoderKL.from_pretrained(vae_name_or_path)
59
+ else:
60
+ raise ValueError("vae or vae_name_or_path is required")
61
+
62
+ self.model.eval()
63
+ self.vae.eval()
64
+
65
+ vae_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
66
+ self.down_factor = vae_factor * self.model.config.latent_patch_size
67
+ self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
68
+ self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
69
+
70
+ self.boi = self.model.config.boi
71
+ self.eoi = self.model.config.eoi
72
+
73
+ self.image_placeholder_id = self.model.config.image_placeholder_id
74
+ self.pil2tensor = transforms.Compose(
75
+ [
76
+ transforms.ToTensor(),
77
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
78
+ ]
79
+ )
80
+ self.__device = self.model.device
81
+ self.__dtype = self.model.dtype
82
+ self.to(self.device, self.dtype)
83
+
84
+ @property
85
+ def device(self):
86
+ return self.__device
87
+
88
+ @property
89
+ def device_type(self):
90
+ if isinstance(self.__device, str):
91
+ return self.__device
92
+ return self.__device.type
93
+
94
+ @property
95
+ def dtype(self):
96
+ return self.__dtype
97
+
98
+ def to(self, device: str | None = None, dtype: torch.dtype | None = None):
99
+ if device is not None:
100
+ self.__device = device
101
+ if dtype is not None:
102
+ self.__dtype = dtype
103
+ self.model.to(self.__device, dtype=self.__dtype)
104
+ self.vae.to(self.__device, dtype=self.__dtype)
105
+ return self
106
+
107
+ def _image_str(self, hw: tuple[int, int] = (256, 256)):
108
+ latent_hw = (hw[0] // self.down_factor, hw[1] // self.down_factor)
109
+ image_ids = [self.boi] + [self.image_placeholder_id] * (latent_hw[0] * latent_hw[1]) + [self.eoi]
110
+ image_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(*latent_hw) + self.tokenizer.decode(image_ids)
111
+ return image_str
112
+
113
+ def _check_input(
114
+ self, captions: str | list[str], images: Image.Image | list[Image.Image] | None
115
+ ) -> tuple[list[str], list[Image.Image] | None]:
116
+ if not isinstance(captions, list):
117
+ captions = [captions]
118
+ if images is not None:
119
+ if not isinstance(images, list):
120
+ images = [images]
121
+ # Validate image count matches <image> tokens in captions
122
+ image_token_count = 0
123
+ for caption in captions:
124
+ num_image_token = len(re.findall(r"<image>", caption))
125
+ assert num_image_token == 1, f"Caption `{caption}` has {num_image_token} image tokens, but only 1 is allowed."
126
+ image_token_count += num_image_token
127
+ if image_token_count != len(images):
128
+ raise ValueError(
129
+ f"Number of images ({len(images)}) does not match number of image tokens ({image_token_count}).\n"
130
+ f"Captions: {captions}"
131
+ )
132
+ hws = [(image.size[1], image.size[0]) for image in images]
133
+ # Replace <image> tokens sequentially with corresponding image_str based on hw
134
+ processed_captions = []
135
+ image_idx = 0
136
+ for caption in captions:
137
+ # Process each caption
138
+ processed_caption = caption
139
+ num_image_tokens = processed_caption.count("<image>")
140
+ # Replace each <image> token in order
141
+ for _ in range(num_image_tokens):
142
+ processed_caption = processed_caption.replace("<image>", self._image_str(hws[image_idx]), 1)
143
+ image_idx += 1
144
+ processed_captions.append(processed_caption)
145
+ captions = processed_captions
146
+ return captions, images
147
+
148
+ def _build_captions(
149
+ self,
150
+ captions: str | list[str],
151
+ images: list[Image.Image] | None = None,
152
+ num_images_per_caption: int = 1,
153
+ positive_prompt: str | None = None,
154
+ negative_prompt: str | None = None,
155
+ cfg: float = 1.0,
156
+ cfg_img: float = 1.0,
157
+ ):
158
+ # 1. repeat captions and images
159
+ if not isinstance(captions, list):
160
+ captions = [captions]
161
+
162
+ captions = [caption for caption in captions for _ in range(num_images_per_caption)]
163
+ if images is not None:
164
+ images = [image for image in images for _ in range(num_images_per_caption)]
165
+
166
+ # 2. add positive prompt
167
+ if positive_prompt is not None and positive_prompt != "":
168
+ captions = [f"{caption} {positive_prompt}" for caption in captions]
169
+
170
+ # 3. add negative prompt
171
+ if negative_prompt is None:
172
+ negative_prompt = ""
173
+
174
+ num_samples = len(captions)
175
+ if cfg != 1.0 and cfg_img != 1.0: # use both image and text CFG
176
+ w, h = images[0].size
177
+ captions = (
178
+ captions + [self._image_str((h, w)) + negative_prompt] * num_samples
179
+ )
180
+ images = images + images
181
+ captions = captions + [negative_prompt] * num_samples
182
+ elif cfg != 1.0 and cfg_img == 1.0: # use text CFG
183
+ captions = captions + [negative_prompt] * num_samples
184
+ elif cfg == 1.0 and cfg_img == 1.0:
185
+ pass
186
+
187
+ return captions, images
188
+
189
+ def _add_prefix_ids(self, hw: tuple[int, int], input_ids: torch.Tensor, attention_mask: torch.Tensor):
190
+ prefix_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(hw[0] // self.down_factor, hw[1] // self.down_factor)
191
+ prefix_output = self.tokenizer(
192
+ prefix_str,
193
+ truncation=False,
194
+ add_special_tokens=True,
195
+ return_tensors="pt"
196
+ )
197
+ prefix_input_ids = prefix_output.input_ids.to(input_ids.device, dtype=input_ids.dtype)
198
+ prefix_attention_mask = prefix_output.attention_mask.to(attention_mask.device, dtype=attention_mask.dtype)
199
+ # remove bos token
200
+ if self.tokenizer.bos_token is not None:
201
+ prefix_input_ids = prefix_input_ids[:, 1:]
202
+ prefix_attention_mask = prefix_attention_mask[:, 1:]
203
+ # add boi token
204
+ prefix_input_ids = torch.cat(
205
+ [
206
+ prefix_input_ids,
207
+ prefix_input_ids.new_tensor([self.model.config.boi]).unsqueeze(0),
208
+ ],
209
+ dim=1,
210
+ )
211
+ prefix_attention_mask = torch.cat(
212
+ [
213
+ prefix_attention_mask,
214
+ prefix_attention_mask.new_ones((prefix_attention_mask.shape[0], 1)),
215
+ ],
216
+ dim=1,
217
+ )
218
+ bsz = input_ids.shape[0]
219
+ input_ids = torch.cat([input_ids, prefix_input_ids.expand(bsz, -1)], dim=1)
220
+ attention_mask = torch.cat([attention_mask, prefix_attention_mask.expand(bsz, -1)], dim=1)
221
+
222
+ return input_ids, attention_mask
223
+
224
+ @torch.no_grad()
225
+ def decoding(
226
+ self,
227
+ c: torch.Tensor,
228
+ attention_mask: torch.Tensor,
229
+ past_key_values: Cache,
230
+ max_new_len: int,
231
+ num_images_per_caption: int,
232
+ use_norm: bool = False,
233
+ cfg: float = 1.0,
234
+ cfg_img: float = 1.0,
235
+ cfg_schedule: Literal["linear", "constant"] = "constant",
236
+ timesteps_shift: float = 1.0,
237
+ num_sampling_steps: int = 20,
238
+ progress: bool = True,
239
+ hw: tuple[int, int] = (256, 256),
240
+ step: int = 0,
241
+ ):
242
+ indices = list(range(max_new_len))
243
+ indices = tqdm(indices, unit="tokens") if progress else indices
244
+ tokens = None
245
+ for step in indices:
246
+ # cfg schedule follow Muse
247
+ if cfg_schedule == "linear":
248
+ tokens_len = 0 if tokens is None else tokens.shape[1]
249
+ cfg_iter = max(cfg / 2, 1 + (cfg - 1) * tokens_len / max_new_len)
250
+ cfg_img_iter = max(cfg_img / 2, 1 + (cfg_img - 1) * tokens_len / max_new_len)
251
+ elif cfg_schedule == "constant":
252
+ cfg_iter = cfg
253
+ cfg_img_iter = cfg_img
254
+ else:
255
+ raise NotImplementedError
256
+
257
+ c = self.model.image_out_projector(c)
258
+ token_sampled = self.model.image_head.sample(
259
+ c=c.squeeze(1),
260
+ cfg=cfg_iter,
261
+ cfg_img=cfg_img_iter,
262
+ timesteps_shift=timesteps_shift,
263
+ num_sampling_steps=num_sampling_steps,
264
+ noise_repeat=num_images_per_caption,
265
+ )
266
+
267
+ if use_norm:
268
+ token_sampled = layer_norm(token_sampled, normalized_shape=token_sampled.size()[1:])
269
+ if tokens is not None:
270
+ tokens = torch.cat([tokens, token_sampled.unsqueeze(1)], dim=1)
271
+ else:
272
+ tokens = token_sampled.unsqueeze(1)
273
+
274
+ cur_inputs_embeds = self.model.image_in_projector(tokens[:, -1:])
275
+ if cfg != 1.0 and cfg_img == 1.0:
276
+ cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds], dim=0)
277
+ elif cfg != 1.0 and cfg_img != 1.0:
278
+ cur_inputs_embeds = torch.cat([cur_inputs_embeds, cur_inputs_embeds, cur_inputs_embeds], dim=0)
279
+
280
+ attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
281
+ outputs = self.model.forward_model(
282
+ inputs_embeds=cur_inputs_embeds,
283
+ attention_mask=attention_mask,
284
+ past_key_values=past_key_values,
285
+ use_cache=True,
286
+ )
287
+ past_key_values = outputs.past_key_values
288
+ c = outputs.last_hidden_state[:, -1:]
289
+ if self.model.config.use_gen_pos_embed:
290
+ c = c + self.model.gen_pos_embed_with_ar(hw[0], hw[1])[:, step + 1 : step + 2, :]
291
+
292
+ return tokens
293
+
294
+ @torch.no_grad()
295
+ def generate_image(
296
+ self,
297
+ captions: str | list[str],
298
+ images: list[Image.Image] | None = None,
299
+ num_images_per_caption: int = 1,
300
+ positive_prompt: str | None = None,
301
+ negative_prompt: str | None = None,
302
+ hw: tuple[int, int] = (256, 256),
303
+ use_norm: bool = False,
304
+ cfg: float = 1.0,
305
+ cfg_img: float = 1.0,
306
+ cfg_schedule: Literal["linear", "constant"] = "constant",
307
+ num_sampling_steps: int = 20,
308
+ timesteps_shift: float = 1.0,
309
+ seed: int = 42,
310
+ progress: bool = True,
311
+ ) -> list[Image.Image]:
312
+ # 0. set seed
313
+ if seed is not None:
314
+ set_seed(seed)
315
+
316
+ # 1. check input
317
+ captions, images = self._check_input(captions, images)
318
+
319
+ # 2. build captions
320
+ captions, images = self._build_captions(
321
+ captions, images, num_images_per_caption, positive_prompt, negative_prompt, cfg, cfg_img
322
+ )
323
+
324
+ # 3. encode images
325
+ # `images` must be processed by `process_images` before calling this function
326
+ latents = None
327
+ if images is not None:
328
+ pixel_values = [self.pil2tensor(image) for image in images]
329
+ pixel_values = torch.stack(pixel_values).to(self.device)
330
+ with compile_manager.compile_disabled():
331
+ posterior = self.vae.encode(pixel_values.to(self.vae.dtype)).latent_dist
332
+ latents = (posterior.sample() - self.shift_factor) * self.scaling_factor
333
+ captions = [self.tokenizer.bos_token + caption if self.tokenizer.bos_token is not None else caption for caption in captions]
334
+
335
+ # 4. tokenize caption & add prefix ids
336
+ output = self.tokenizer(
337
+ captions,
338
+ padding="longest",
339
+ truncation=False,
340
+ add_special_tokens=True,
341
+ return_tensors="pt",
342
+ padding_side="left"
343
+ )
344
+ input_ids = output.input_ids.to(self.device)
345
+ attention_mask = output.attention_mask.to(self.device)
346
+ input_ids, attention_mask = self._add_prefix_ids(hw, input_ids, attention_mask)
347
+
348
+ # 5. LLM prefill
349
+ max_new_len = (hw[0] // self.down_factor) * (hw[1] // self.down_factor)
350
+ max_cache_len = input_ids.shape[1] + max_new_len
351
+ past_key_values = StaticCache(
352
+ config=self.model.config,
353
+ max_batch_size=input_ids.shape[0],
354
+ max_cache_len=max_cache_len,
355
+ device=self.device,
356
+ dtype=self.dtype,
357
+ )
358
+ inputs_embeds = self.model.prepare_inputs_embeds(input_ids, latents)
359
+ with compile_manager.compile_disabled():
360
+ outputs = self.model.forward_model(
361
+ inputs_embeds=inputs_embeds,
362
+ attention_mask=attention_mask,
363
+ past_key_values=past_key_values,
364
+ use_cache=True,
365
+ )
366
+ past_key_values = outputs.past_key_values
367
+ c = outputs.last_hidden_state[:, -1:]
368
+ if self.model.config.use_gen_pos_embed:
369
+ c = c + self.model.gen_pos_embed_with_ar(hw[0], hw[1])[:, 0:1, :]
370
+
371
+ # 6. decoding
372
+ tokens = self.decoding(
373
+ c=c,
374
+ attention_mask=attention_mask,
375
+ past_key_values=past_key_values,
376
+ max_new_len=max_new_len,
377
+ num_images_per_caption=num_images_per_caption,
378
+ use_norm=use_norm,
379
+ cfg=cfg,
380
+ cfg_img=cfg_img,
381
+ cfg_schedule=cfg_schedule,
382
+ timesteps_shift=timesteps_shift,
383
+ num_sampling_steps=num_sampling_steps,
384
+ progress=progress,
385
+ hw=hw,
386
+ )
387
+
388
+ # 7. unpatchify
389
+ latents = self.model.unpatchify(tokens)
390
+ latents = (latents / self.scaling_factor) + self.shift_factor
391
+
392
+ # 8. decode latents
393
+ with compile_manager.compile_disabled():
394
+ sampled_images = self.vae.decode(latents.to(self.vae.dtype)).sample
395
+ sampled_images = sampled_images.detach().cpu().to(torch.float32)
396
+ pil_images = [to_pil(img) for img in sampled_images]
397
+
398
+ return pil_images
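
`generate_image` above is the entry point exposed by the pipeline object: it validates the prompts, prefills the LLM, autoregressively samples image tokens with the flow-matching head, and decodes them through the VAE. A hedged usage sketch, assuming `pipeline`, `positive_prompt`, and `negative_prompt` are set up as in the README; the numeric values are placeholders, not tuned defaults:

```python
# Sketch only: cfg/step values are placeholders, not recommended settings.
images = pipeline.generate_image(
    captions="a watercolor painting of a lighthouse at dusk",
    num_images_per_caption=1,
    positive_prompt=positive_prompt,
    negative_prompt=negative_prompt,
    hw=(512, 512),            # must be divisible by the pipeline's down factor
    cfg=7.5,                  # text classifier-free guidance scale
    cfg_schedule="constant",
    num_sampling_steps=28,    # flow-matching steps per image token
    timesteps_shift=1.0,
    seed=42,
)
images[0].save("sample.png")
```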
models/heads.py ADDED
@@ -0,0 +1,283 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ from transformers.activations import ACT2FN
8
+
9
+ from models.config import LlamaConfig
10
+ from utils.misc import LargeInt
11
+ from utils.model_utils import expand_t, randn_tensor
12
+ from utils.compile_utils import smart_compile
13
+
14
+
15
+ class LlamaMLP(nn.Module):
16
+ def __init__(self, config: LlamaConfig):
17
+ super().__init__()
18
+ self.config = config
19
+ self.hidden_size = config.hidden_size
20
+ self.intermediate_size = config.intermediate_size
21
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
22
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
23
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
24
+ self.act_fn = ACT2FN[config.hidden_act]
25
+
26
+ def forward(self, x):
27
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
28
+ return down_proj
29
+
30
+
31
+
32
+
33
+ def modulate(x, shift, scale=None):
34
+ if shift is None:
35
+ return x * (1 + scale)
36
+ return x * (1 + scale) + shift
37
+
38
+
39
+ class ResBlock(nn.Module):
40
+ def __init__(self, channels, mlp_ratio=1.0):
41
+ super().__init__()
42
+ self.channels = channels
43
+ self.intermediate_size = int(channels * mlp_ratio)
44
+
45
+ self.in_ln = nn.LayerNorm(self.channels, eps=1e-6)
46
+ self.mlp = nn.Sequential(
47
+ nn.Linear(self.channels, self.intermediate_size),
48
+ nn.SiLU(),
49
+ nn.Linear(self.intermediate_size, self.channels),
50
+ )
51
+
52
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))
53
+
54
+ def forward(self, x, y):
55
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
56
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
57
+ h = self.mlp(h)
58
+ return x + gate_mlp * h
59
+
60
+
61
+ class FinalLayer(nn.Module):
62
+ def __init__(self, model_channels, out_channels):
63
+ super().__init__()
64
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
65
+ self.linear = nn.Linear(model_channels, out_channels, bias=True)
66
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True))
67
+
68
+ def forward(self, x, c):
69
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
70
+ x = modulate(self.norm_final(x), shift, scale)
71
+ x = self.linear(x)
72
+ return x
73
+
74
+
75
+ class TimestepEmbedder(nn.Module):
76
+ """
77
+ Embeds scalar timesteps into vector representations.
78
+ """
79
+
80
+ def __init__(self, hidden_size, frequency_embedding_size=256):
81
+ super().__init__()
82
+ self.mlp = nn.Sequential(
83
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
84
+ nn.SiLU(),
85
+ nn.Linear(hidden_size, hidden_size, bias=True),
86
+ )
87
+ self.frequency_embedding_size = frequency_embedding_size
88
+
89
+ @staticmethod
90
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0):
91
+ """
92
+ Create sinusoidal timestep embeddings.
93
+ :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
94
+ :param dim: the dimension of the output.
95
+ :param max_period: controls the minimum frequency of the embeddings.
96
+ :return: an (N, D) Tensor of positional embeddings.
97
+ """
98
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
99
+ half = dim // 2
100
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
101
+ device=t.device
102
+ )
103
+ args = t[:, None].float() * freqs[None]
104
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
105
+ if dim % 2:
106
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
107
+ return embedding
108
+
109
+ def forward(self, t):
110
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
111
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
112
+ return t_emb
113
+
114
+
115
+ class SimpleMLPAdaLN(nn.Module):
116
+ def __init__(self, input_dim, cond_dim, dim=1536, layers=12, mlp_ratio=1.0):
117
+ super().__init__()
118
+ self.input_dim = input_dim
119
+ self.cond_dim = cond_dim
120
+ self.dim = dim
121
+ self.layers = layers
122
+ self.mlp_ratio = mlp_ratio
123
+
124
+ self.time_embed = TimestepEmbedder(dim)
125
+ self.cond_embed = nn.Linear(cond_dim, dim)
126
+ self.input_proj = nn.Linear(input_dim, dim)
127
+
128
+ res_blocks = []
129
+ for _ in range(layers):
130
+ res_blocks.append(ResBlock(dim, mlp_ratio))
131
+ self.res_blocks = nn.ModuleList(res_blocks)
132
+
133
+ self.final_layer = FinalLayer(dim, input_dim)
134
+
135
+ self.grad_checkpointing = False
136
+
137
+ self.initialize_weights()
138
+
139
+ def initialize_weights(self):
140
+ def _basic_init(module):
141
+ if isinstance(module, nn.Linear):
142
+ torch.nn.init.xavier_uniform_(module.weight)
143
+ if module.bias is not None:
144
+ nn.init.constant_(module.bias, 0)
145
+
146
+ self.apply(_basic_init)
147
+
148
+ # Initialize timestep embedding MLP
149
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
150
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
151
+
152
+ # Zero-out adaLN modulation layers
153
+ for block in self.res_blocks:
154
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
155
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
156
+
157
+ # Zero-out output layers
158
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
159
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
160
+ nn.init.constant_(self.final_layer.linear.weight, 0)
161
+ nn.init.constant_(self.final_layer.linear.bias, 0)
162
+
163
+ @smart_compile()
164
+ def forward(self, x, t, c):
165
+ """
166
+ x.shape = (bsz, input_dim)
167
+ t.shape = (bsz,)
168
+ c.shape = (bsz, cond_dim)
169
+ """
170
+
171
+ x = self.input_proj(x)
172
+ t = self.time_embed(t)
173
+ c = self.cond_embed(c)
174
+
175
+ y = t + c
176
+
177
+ for block in self.res_blocks:
178
+ if self.grad_checkpointing and self.training:
179
+ x = checkpoint(block, x, y, use_reentrant=True)
180
+ else:
181
+ x = block(x, y)
182
+
183
+ return self.final_layer(x, y)
184
+
185
+
186
+ class FlowMatchingHead(nn.Module):
187
+
188
+ def __init__(self, input_dim, cond_dim, dim=1536, layers=12, mlp_ratio=1.0):
189
+ super(FlowMatchingHead, self).__init__()
190
+ self.input_dim = input_dim
191
+ self.net = SimpleMLPAdaLN(input_dim=input_dim, cond_dim=cond_dim, dim=dim, layers=layers, mlp_ratio=mlp_ratio)
192
+
193
+ @property
194
+ def dtype(self):
195
+ return self.net.input_proj.weight.dtype
196
+
197
+ @property
198
+ def device(self):
199
+ return self.net.input_proj.weight.device
200
+
201
+ @property
202
+ def trainable_params(self) -> float:
203
+ n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
204
+ return LargeInt(n_params)
205
+
206
+
207
+ def get_score_from_velocity(self, velocity, x, t):
208
+ """Wrapper function: transfrom velocity prediction model to score
209
+ Args:
210
+ velocity: [bsz, ...] shaped tensor; velocity model output
211
+ x: [bsz, ...] shaped tensor; x_t data point
212
+ t: [bsz,] time tensor
213
+ """
214
+ t = expand_t(t, x)
215
+ alpha_t, d_alpha_t = t, 1
216
+ sigma_t, d_sigma_t = 1 - t, -1
217
+ mean = x
218
+ reverse_alpha_ratio = alpha_t / d_alpha_t
219
+ var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
220
+ score = (reverse_alpha_ratio * velocity - mean) / var
221
+ return score
222
+
223
+ def get_velocity_from_cfg(self, velocity, cfg, cfg_img, cfg_mult):
224
+ if cfg_mult == 2:
225
+ cond_v, uncond_v = torch.chunk(velocity, 2, dim=0)
226
+ velocity = uncond_v + cfg * (cond_v - uncond_v)
227
+ elif cfg_mult == 3:
228
+ cond_v, uncond_v1, uncond_v2 = torch.chunk(velocity, 3, dim=0)
229
+ velocity = uncond_v2 + cfg_img * (uncond_v1 - uncond_v2) + cfg * (cond_v - uncond_v1)
230
+ return velocity
231
+
232
+ @smart_compile(options={"triton.cudagraphs": True}, fullgraph=True)
233
+ @torch.no_grad()
234
+ def sample(
235
+ self,
236
+ c: torch.Tensor,
237
+ cfg: float = 1.0,
238
+ cfg_img: float = 1.0,
239
+ timesteps_shift: float = 1.0,
240
+ num_sampling_steps: int = 20,
241
+ last_step_size: float = 0.0,
242
+ noise_repeat: int = 1,
243
+ ):
244
+ # """c.shape = (bsz, cond_dim)"""
245
+ cfg_mult = 1
246
+ if cfg > 1.0:
247
+ cfg_mult += 1
248
+ if cfg_img > 1.0:
249
+ cfg_mult += 1
250
+
251
+ noise = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device)
252
+
253
+ mean_x = noise
254
+ x = noise
255
+ xs = []
256
+
257
+ t0, t1 = 0, 1
258
+ timesteps = torch.linspace(t0, t1, num_sampling_steps + 1, device=c.device)[:-1]
259
+ timesteps = timesteps / (timesteps_shift - (timesteps_shift - 1) * timesteps)
260
+ timesteps = torch.cat([timesteps, torch.ones(1, device=c.device)])
261
+ for ti, tj in zip(timesteps[:-1], timesteps[1:]):
262
+ dt = tj - ti
263
+
264
+ combined = torch.cat([x] * cfg_mult, dim=0)
265
+ velocity = self.net(combined.to(c.dtype), ti.expand(c.shape[0]).to(c), c)
266
+ velocity = velocity.to(torch.float32)
267
+
268
+ velocity = self.get_velocity_from_cfg(velocity, cfg, cfg_img, cfg_mult)
269
+ score = self.get_score_from_velocity(velocity, x, ti.expand(x.shape[0]).to(x))
270
+ drift = velocity + (1 - expand_t(ti.expand(x.shape[0]).to(x), x)) * score
271
+
272
+ w_cur = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device)
273
+ dw = w_cur * torch.sqrt(dt)
274
+
275
+ mean_x = x + drift * dt
276
+ x = mean_x + torch.sqrt(2 * (1 - expand_t(ti.expand(x.shape[0]).to(x), x))) * dw
277
+ xs.append(x)
278
+
279
+
280
+ if len(xs) != num_sampling_steps:
281
+ raise ValueError(f"Samples ({len(xs)}) does not match the number of steps ({num_sampling_steps})")
282
+
283
+ return xs[-1].to(c.dtype)
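
Two small formulas inside `FlowMatchingHead.sample` above are easy to miss when skimming the diff. With s denoting `timesteps_shift`, the uniform grid over [0, 1) is warped before integration, and with both guidance scales active (`cfg_mult == 3`) the velocities of the conditional branch and the two negative-prompt branches are combined as in `get_velocity_from_cfg`:

```latex
% Timestep warp applied to the uniform grid (s = timesteps_shift; s = 1 leaves t unchanged):
t' = \frac{t}{s - (s - 1)\,t}

% Three-way classifier-free guidance (\lambda_{\mathrm{txt}} = cfg, \lambda_{\mathrm{img}} = cfg_img):
v = v_{\mathrm{uncond},2}
  + \lambda_{\mathrm{img}}\left(v_{\mathrm{uncond},1} - v_{\mathrm{uncond},2}\right)
  + \lambda_{\mathrm{txt}}\left(v_{\mathrm{cond}} - v_{\mathrm{uncond},1}\right)
```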
models/llama_model.py ADDED
@@ -0,0 +1,568 @@
1
+ from typing import Optional, Tuple
2
+ from loguru import logger
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from transformers.cache_utils import Cache, StaticCache
9
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
10
+ from transformers.utils import is_flash_attn_greater_or_equal_2_10
11
+ from transformers import ROPE_INIT_FUNCTIONS
12
+ from transformers.models.llama.configuration_llama import LlamaConfig
13
+
14
+ from models.heads import LlamaMLP
15
+ from utils.model_utils import apply_rotary_pos_emb, repeat_kv
16
+ from models.config import NextStepConfig
17
+
18
+
19
+ class LlamaRMSNorm(nn.Module):
20
+ """LlamaRMSNorm is equivalent to T5LayerNorm"""
21
+
22
+ def __init__(self, hidden_size, eps=1e-6):
23
+ super().__init__()
24
+ self.weight = nn.Parameter(torch.ones(hidden_size))
25
+ self.variance_epsilon = eps
26
+
27
+ def forward(self, hidden_states):
28
+ input_dtype = hidden_states.dtype
29
+ hidden_states = hidden_states.to(torch.float32)
30
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
31
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
32
+ return self.weight * hidden_states.to(input_dtype)
33
+
34
+ def extra_repr(self):
35
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
36
+
37
+
38
+ class LlamaRotaryEmbedding(nn.Module):
39
+ def __init__(self, device=None, config: Optional[LlamaConfig] = None):
40
+ super().__init__()
41
+ self.rope_type = "default"
42
+ self.config = config
43
+
44
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
45
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
46
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
47
+
48
+ @torch.no_grad()
49
+ def forward(self, x, position_ids):
50
+ # Core RoPE block
51
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
52
+ position_ids_expanded = position_ids[:, None, :].float()
53
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
54
+ device_type = x.device.type
55
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
56
+ with torch.autocast(device_type=device_type, enabled=False):
57
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
58
+ emb = torch.cat((freqs, freqs), dim=-1)
59
+ cos = emb.cos()
60
+ sin = emb.sin()
61
+
62
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
63
+ cos = cos * self.attention_scaling
64
+ sin = sin * self.attention_scaling
65
+
66
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
67
+
68
+
69
+ class LlamaAttention(nn.Module):
70
+ def __init__(self, config: NextStepConfig, layer_idx: Optional[int]):
71
+ super().__init__()
72
+ self.config = config
73
+ self.layer_idx = layer_idx
74
+
75
+ self.attention_dropout = config.attention_dropout
76
+ self.hidden_size = config.hidden_size
77
+ self.num_heads = config.num_attention_heads
78
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
79
+ self.num_key_value_heads = config.num_key_value_heads
80
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
81
+ self.max_position_embeddings = config.max_position_embeddings
82
+ self.rope_theta = config.rope_theta
83
+ self.is_causal = True
84
+
85
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
86
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
87
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
88
+ self.o_proj = nn.Linear(
89
+ self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, "o_attention_bias", config.attention_bias)
90
+ )
91
+ self._flash_attn_uses_top_left_mask = False
92
+
93
+ def forward_sdpa(
94
+ self,
95
+ hidden_states: torch.Tensor,
96
+ attention_mask: Optional[torch.Tensor] = None,
97
+ position_ids: Optional[torch.LongTensor] = None,
98
+ past_key_value: Optional[Cache] = None,
99
+ output_attentions: bool = False,
100
+ use_cache: bool = False,
101
+ cache_position: Optional[torch.LongTensor] = None,
102
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
103
+ **kwargs,
104
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
105
+ bsz, q_len, _ = hidden_states.size()
106
+
107
+ query_states = self.q_proj(hidden_states)
108
+ key_states = self.k_proj(hidden_states)
109
+ value_states = self.v_proj(hidden_states)
110
+
111
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
112
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
113
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
114
+
115
+ if position_embeddings is None:
116
+ logger.warning_once(
117
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
118
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
119
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
120
+ "removed and `position_embeddings` will be mandatory."
121
+ )
122
+ cos, sin = self.rotary_emb(value_states, position_ids)
123
+ else:
124
+ cos, sin = position_embeddings
125
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
126
+
127
+ if past_key_value is not None:
128
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
129
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
130
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
131
+
132
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
133
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
134
+
135
+ causal_mask = attention_mask
136
+ if attention_mask is not None:
137
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
138
+
139
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
140
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
141
+ if query_states.device.type == "cuda" and causal_mask is not None:
142
+ query_states = query_states.contiguous()
143
+ key_states = key_states.contiguous()
144
+ value_states = value_states.contiguous()
145
+
146
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
147
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
148
+ is_causal = True if causal_mask is None and q_len > 1 else False
149
+
150
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
151
+ query_states,
152
+ key_states,
153
+ value_states,
154
+ attn_mask=causal_mask,
155
+ dropout_p=self.attention_dropout if self.training else 0.0,
156
+ is_causal=is_causal,
157
+ )
158
+
159
+ attn_output = attn_output.transpose(1, 2).contiguous()
160
+ attn_output = attn_output.view(bsz, q_len, -1)
161
+
162
+ attn_output = self.o_proj(attn_output)
163
+
164
+ return attn_output, None, past_key_value
165
+
166
+ def forward_flash(
167
+ self,
168
+ hidden_states: torch.Tensor,
169
+ attention_mask: Optional[torch.LongTensor] = None,
170
+ position_ids: Optional[torch.LongTensor] = None,
171
+ past_key_value: Optional[Cache] = None,
172
+ output_attentions: bool = False,
173
+ use_cache: bool = False,
174
+ cache_position: Optional[torch.LongTensor] = None,
175
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
176
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
177
+ if isinstance(past_key_value, StaticCache):
178
+ raise ValueError(
179
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
180
+ "make sure to use `sdpa` in the mean time, and open an issue at GitHub - huggingface/transformers: 🤗 Transformers: the model-definition framework for state-of-the-a"
181
+ )
182
+
183
+ output_attentions = False
184
+
185
+ bsz, q_len, _ = hidden_states.size()
186
+
187
+ query_states = self.q_proj(hidden_states)
188
+ key_states = self.k_proj(hidden_states)
189
+ value_states = self.v_proj(hidden_states)
190
+
191
+ # Flash attention requires the input to have the shape
192
+ # batch_size x seq_length x head_dim x hidden_dim
193
+ # therefore we just need to keep the original shape
194
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
195
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
196
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
197
+
198
+ if position_embeddings is None:
199
+ logger.warning_once(
200
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
201
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
202
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
203
+ "removed and `position_embeddings` will be mandatory."
204
+ )
205
+ cos, sin = self.rotary_emb(value_states, position_ids)
206
+ else:
207
+ cos, sin = position_embeddings
208
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
209
+
210
+ if past_key_value is not None:
211
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
212
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
213
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
214
+
215
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
216
+ # to be able to avoid many of these transpose/reshape/view.
217
+ query_states = query_states.transpose(1, 2)
218
+ key_states = key_states.transpose(1, 2)
219
+ value_states = value_states.transpose(1, 2)
220
+
221
+ dropout_rate = self.attention_dropout if self.training else 0.0
222
+
223
+ input_dtype = query_states.dtype
224
+ if input_dtype == torch.float32:
225
+ if torch.is_autocast_enabled():
226
+ target_dtype = torch.get_autocast_gpu_dtype()
227
+ # Handle the case where the model is quantized
228
+ elif hasattr(self.config, "_pre_quantization_dtype"):
229
+ target_dtype = self.config._pre_quantization_dtype
230
+ else:
231
+ target_dtype = self.q_proj.weight.dtype
232
+
233
+ logger.warning_once(
234
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
235
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
236
+ f" {target_dtype}."
237
+ )
238
+
239
+ query_states = query_states.to(target_dtype)
240
+ key_states = key_states.to(target_dtype)
241
+ value_states = value_states.to(target_dtype)
242
+
243
+ attn_output = _flash_attention_forward(
244
+ query_states,
245
+ key_states,
246
+ value_states,
247
+ attention_mask,
248
+ q_len,
249
+ position_ids=position_ids,
250
+ dropout=dropout_rate,
251
+ sliding_window=getattr(self, "sliding_window", None),
252
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
253
+ is_causal=self.is_causal,
254
+ )
255
+
256
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
257
+ attn_output = self.o_proj(attn_output)
258
+
259
+ if not output_attentions:
260
+ attn_weights = None
261
+
262
+ return attn_output, attn_weights, past_key_value
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ attention_mask: Optional[torch.Tensor] = None,
268
+ position_ids: Optional[torch.LongTensor] = None,
269
+ past_key_value: Optional[Cache] = None,
270
+ output_attentions: bool = False,
271
+ use_cache: bool = False,
272
+ cache_position: Optional[torch.LongTensor] = None,
273
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
274
+ **kwargs,
275
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
276
+ bsz, q_len, _ = hidden_states.size()
277
+
278
+ query_states = self.q_proj(hidden_states)
279
+ key_states = self.k_proj(hidden_states)
280
+ value_states = self.v_proj(hidden_states)
281
+
282
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
283
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
284
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
285
+
286
+ if position_embeddings is None:
287
+ logger.warning_once(
288
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
289
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
290
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
291
+ "removed and `position_embeddings` will be mandatory."
292
+ )
293
+ cos, sin = self.rotary_emb(value_states, position_ids)
294
+ else:
295
+ cos, sin = position_embeddings
296
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
297
+
298
+ if past_key_value is not None:
299
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
300
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
301
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
302
+
303
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
304
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
305
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
306
+
307
+ if attention_mask is not None: # no matter the length, we just slice it
308
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
309
+ attn_weights = attn_weights + causal_mask
310
+
311
+ # upcast attention to fp32
312
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
313
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
314
+ attn_output = torch.matmul(attn_weights, value_states)
315
+
316
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
317
+ raise ValueError(
318
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
319
+ f" {attn_output.size()}"
320
+ )
321
+
322
+ attn_output = attn_output.transpose(1, 2).contiguous()
323
+
324
+ attn_output = attn_output.reshape(bsz, q_len, -1)
325
+
326
+ attn_output = self.o_proj(attn_output)
327
+
328
+ if not output_attentions:
329
+ attn_weights = None
330
+
331
+ return attn_output, attn_weights, past_key_value
332
+
333
+
334
+ class LlamaFlashAttention2(LlamaAttention):
335
+ """
336
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
337
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
338
+ flash attention and deal with padding tokens in case the input contains any of them.
339
+ """
340
+
341
+ def __init__(self, *args, **kwargs):
342
+ super().__init__(*args, **kwargs)
343
+
344
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
345
+
346
+ def forward(
347
+ self,
348
+ hidden_states: torch.Tensor,
349
+ attention_mask: Optional[torch.LongTensor] = None,
350
+ past_key_value: Optional[Cache] = None,
351
+ output_attentions: bool = False,
352
+ use_cache: bool = False,
353
+ cache_position: Optional[torch.LongTensor] = None,
354
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
355
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
356
+ if isinstance(past_key_value, StaticCache):
357
+ raise ValueError(
358
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
359
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
360
+ )
361
+
362
+ output_attentions = False
363
+
364
+ bsz, q_len, _ = hidden_states.size()
365
+
366
+ query_states = self.q_proj(hidden_states)
367
+ key_states = self.k_proj(hidden_states)
368
+ value_states = self.v_proj(hidden_states)
369
+
370
+ # Flash attention requires the input to have the shape
371
+ # batch_size x seq_length x num_heads x head_dim
372
+ # therefore we just need to keep the original shape
373
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
374
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
375
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
376
+
377
+ cos, sin = position_embeddings
378
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
379
+
380
+ if past_key_value is not None:
381
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
382
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
383
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
384
+
385
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
386
+ # to be able to avoid many of these transpose/reshape/view.
387
+ query_states = query_states.transpose(1, 2)
388
+ key_states = key_states.transpose(1, 2)
389
+ value_states = value_states.transpose(1, 2)
390
+
391
+ dropout_rate = self.attention_dropout if self.training else 0.0
392
+
393
+ input_dtype = query_states.dtype
394
+ if input_dtype == torch.float32:
395
+ if torch.is_autocast_enabled():
396
+ target_dtype = torch.get_autocast_gpu_dtype()
397
+ # Handle the case where the model is quantized
398
+ elif hasattr(self.config, "_pre_quantization_dtype"):
399
+ target_dtype = self.config._pre_quantization_dtype
400
+ else:
401
+ target_dtype = self.q_proj.weight.dtype
402
+
403
+ logger.warning_once(
404
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
405
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
406
+ f" {target_dtype}."
407
+ )
408
+
409
+ query_states = query_states.to(target_dtype)
410
+ key_states = key_states.to(target_dtype)
411
+ value_states = value_states.to(target_dtype)
412
+
413
+ attn_output = _flash_attention_forward(
414
+ query_states,
415
+ key_states,
416
+ value_states,
417
+ attention_mask,
418
+ q_len,
419
+ position_ids=None,
420
+ dropout=dropout_rate,
421
+ sliding_window=getattr(self, "sliding_window", None),
422
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
423
+ is_causal=self.is_causal,
424
+ )
425
+
426
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
427
+ attn_output = self.o_proj(attn_output)
428
+
429
+ if not output_attentions:
430
+ attn_weights = None
431
+
432
+ return attn_output, attn_weights, past_key_value
433
+
434
+
435
+ class LlamaSdpaAttention(LlamaAttention):
436
+ """
437
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
438
+ `LlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
439
+ SDPA API.
440
+ """
441
+
442
+ # Adapted from LlamaAttention.forward
443
+ def forward(
444
+ self,
445
+ hidden_states: torch.Tensor,
446
+ attention_mask: Optional[torch.Tensor] = None,
447
+ past_key_value: Optional[Cache] = None,
448
+ output_attentions: bool = False,
449
+ use_cache: bool = False,
450
+ cache_position: Optional[torch.LongTensor] = None,
451
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
452
+ **kwargs,
453
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
454
+
455
+ bsz, q_len, _ = hidden_states.size()
456
+
457
+ query_states = self.q_proj(hidden_states)
458
+ key_states = self.k_proj(hidden_states)
459
+ value_states = self.v_proj(hidden_states)
460
+
461
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
462
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
463
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
464
+
465
+ cos, sin = position_embeddings
466
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
467
+
468
+ if past_key_value is not None:
469
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
470
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
471
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
472
+
473
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
474
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
475
+
476
+ causal_mask = attention_mask
477
+ if attention_mask is not None:
478
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
479
+
480
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
481
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
482
+ if query_states.device.type == "cuda" and causal_mask is not None:
483
+ query_states = query_states.contiguous()
484
+ key_states = key_states.contiguous()
485
+ value_states = value_states.contiguous()
486
+
487
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
488
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
489
+ is_causal = True if causal_mask is None and q_len > 1 else False
490
+
491
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
492
+ query_states,
493
+ key_states,
494
+ value_states,
495
+ attn_mask=causal_mask,
496
+ dropout_p=self.attention_dropout if self.training else 0.0,
497
+ is_causal=is_causal,
498
+ )
499
+
500
+ attn_output = attn_output.transpose(1, 2).contiguous()
501
+ attn_output = attn_output.view(bsz, q_len, -1)
502
+
503
+ attn_output = self.o_proj(attn_output)
504
+
505
+ return attn_output, None, past_key_value
506
+
507
+
508
+ LLAMA_ATTENTION_CLASSES = {
509
+ "eager": LlamaAttention,
510
+ "flash_attention_2": LlamaFlashAttention2,
511
+ "sdpa": LlamaSdpaAttention,
512
+ }
513
+
514
+
515
+ class LlamaDecoderLayer(nn.Module):
516
+ def __init__(self, config: LlamaConfig, layer_idx: int):
517
+ super().__init__()
518
+ self.hidden_size = config.hidden_size
519
+
520
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
521
+
522
+ self.mlp = LlamaMLP(config)
523
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
524
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
525
+
526
+ def forward(
527
+ self,
528
+ hidden_states: torch.Tensor,
529
+ attention_mask: Optional[torch.Tensor] = None,
530
+ past_key_value: Optional[Cache] = None,
531
+ output_attentions: Optional[bool] = False,
532
+ use_cache: Optional[bool] = False,
533
+ cache_position: Optional[torch.LongTensor] = None,
534
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
535
+ **kwargs,
536
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
537
+ residual = hidden_states
538
+
539
+ hidden_states = self.input_layernorm(hidden_states)
540
+
541
+ # Self Attention
542
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
543
+ hidden_states=hidden_states,
544
+ attention_mask=attention_mask,
545
+ past_key_value=past_key_value,
546
+ output_attentions=output_attentions,
547
+ use_cache=use_cache,
548
+ cache_position=cache_position,
549
+ position_embeddings=position_embeddings,
550
+ **kwargs,
551
+ )
552
+ hidden_states = residual + hidden_states
553
+
554
+ # Fully Connected
555
+ residual = hidden_states
556
+ hidden_states = self.post_attention_layernorm(hidden_states)
557
+ hidden_states = self.mlp(hidden_states)
558
+ hidden_states = residual + hidden_states
559
+
560
+ outputs = (hidden_states,)
561
+
562
+ if output_attentions:
563
+ outputs += (self_attn_weights,)
564
+
565
+ if use_cache:
566
+ outputs += (present_key_value,)
567
+
568
+ return outputs
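The three attention variants above implement the same math with different kernels. A minimal standalone sketch (random tensors, no model weights involved) showing that the eager path's explicit `softmax(QK^T / sqrt(d))` with an additive causal mask matches the `scaled_dot_product_attention` kernel used by `LlamaSdpaAttention`:

```python
# Hedged sketch: eager attention vs. torch SDPA on random tensors.
# Shapes are illustrative; no checkpoint weights are used.
import math
import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

# eager path: explicit scores, additive causal mask, fp32 softmax
scores = q @ k.transpose(2, 3) / math.sqrt(head_dim)
causal = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
weights = F.softmax(scores + causal, dim=-1, dtype=torch.float32).to(q.dtype)
eager_out = weights @ v

# SDPA path: let the kernel apply the causal mask itself
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True
```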
models/nextstep_model.py ADDED
@@ -0,0 +1,553 @@
1
+ import os
2
+ import json
3
+ import inspect
4
+ from loguru import logger
5
+ from dataclasses import dataclass
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import CrossEntropyLoss
10
+
11
+ from safetensors.torch import safe_open
12
+ from transformers.modeling_utils import PreTrainedModel
13
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
14
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ from models.config import NextStepConfig
18
+ from models.llama_model import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
19
+ from models.heads import FlowMatchingHead
20
+ from utils.misc import LargeInt
21
+ from utils.compile_utils import smart_compile
22
+ from utils.model_utils import get_2d_sincos_pos_embed
23
+
24
+
25
+ @dataclass
26
+ class NextStepOutputWithPast(CausalLMOutputWithPast):
27
+ lm_loss: torch.FloatTensor | None = None
28
+ im_loss: torch.FloatTensor | None = None
29
+
30
+
31
+ class NextStepPreTrainedModel(PreTrainedModel):
32
+ config_class = NextStepConfig
33
+ supports_gradient_checkpointing = True
34
+ _no_split_modules = ["LlamaDecoderLayer"]
35
+ _skip_keys_device_placement = ["past_key_values"]
36
+ _supports_flash_attn_2 = True
37
+ _supports_sdpa = True
38
+ _supports_cache_class = True
39
+ _supports_quantized_cache = True
40
+ _supports_static_cache = True
41
+
42
+ def _init_weights(self, module):
43
+ std = self.config.initializer_range
44
+ if isinstance(module, nn.Linear):
45
+ module.weight.data.normal_(mean=0.0, std=std)
46
+ if module.bias is not None:
47
+ module.bias.data.zero_()
48
+ elif isinstance(module, nn.Embedding):
49
+ module.weight.data.normal_(mean=0.0, std=std)
50
+ if module.padding_idx is not None:
51
+ module.weight.data[module.padding_idx].zero_()
52
+
53
+ @property
54
+ def trainable_params(self) -> float:
55
+ n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
56
+ return LargeInt(n_params)
57
+
58
+
59
+ class NextStep(NextStepPreTrainedModel):
60
+
61
+ def __init__(self, config: NextStepConfig):
62
+ super().__init__(config)
63
+ self.padding_idx = config.pad_token_id
64
+ self.vocab_size = config.vocab_size
65
+
66
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
67
+
68
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
69
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
70
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
71
+
72
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
73
+
74
+ self.gradient_checkpointing = False
75
+
76
+ # Initialize weights and apply final processing
77
+ self.post_init()
78
+
79
+ token_dim = self.config.latent_channels * self.config.latent_patch_size**2
80
+
81
+ self.image_in_projector = nn.Linear(token_dim, config.hidden_size)
82
+ self.image_in_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
83
+ self.image_in_projector.bias.data.zero_()
84
+
85
+ self.image_out_projector = nn.Linear(config.hidden_size, config.hidden_size)
86
+ self.image_out_projector.weight.data.normal_(mean=0.0, std=config.initializer_range)
87
+ self.image_out_projector.bias.data.zero_()
88
+
89
+ self.image_head = FlowMatchingHead(
90
+ input_dim=token_dim,
91
+ cond_dim=config.hidden_size,
92
+ dim=config.fm_head_dim,
93
+ layers=config.fm_head_layers,
94
+ )
95
+
96
+ if config.use_gen_pos_embed:
97
+ self.init_gen_pos_embed()
98
+
99
+ def init_gen_pos_embed(self):
100
+ self.register_buffer(
101
+ "gen_pos_embed",
102
+ torch.from_numpy(
103
+ get_2d_sincos_pos_embed(
104
+ self.config.hidden_size, self.config.base_image_grid_size
105
+ )
106
+ ).float().unsqueeze(0),
107
+ )
108
+
109
+ def gen_pos_embed_with_ar(self, h, w):
110
+ bsz, hw, dim = self.gen_pos_embed.shape
111
+ gen_pos_embed = self.gen_pos_embed.reshape(bsz, int(hw**0.5), int(hw**0.5), dim)
112
+ gen_pos_embed = gen_pos_embed[:, :h, :w, :]
113
+ gen_pos_embed = gen_pos_embed.reshape(bsz, -1, dim)
114
+ return gen_pos_embed
115
+
116
+ @property
117
+ def image_size(self):
118
+ return self.config.image_size
119
+
120
+ @property
121
+ def image_patch_size(self):
122
+ return self.config.patch_size
123
+
124
+ @property
125
+ def image_grid_size(self):
126
+ return round(self.image_size / self.image_patch_size)
127
+
128
+ def get_input_embeddings(self):
129
+ return self.embed_tokens
130
+
131
+ def set_input_embeddings(self, value):
132
+ self.embed_tokens = value
133
+
134
+ def get_output_embeddings(self):
135
+ return self.lm_head
136
+
137
+ def set_output_embeddings(self, new_embeddings):
138
+ self.lm_head = new_embeddings
139
+
140
+ def load_lm_head(self, lm_head_dir: str | None = None):
141
+ index_json_file = os.path.join(lm_head_dir, "model.safetensors.index.json")
142
+ head_weight_name = "lm_head.weight" if not self.config.tie_word_embeddings else "model.embed_tokens.weight"
143
+ if os.path.exists(index_json_file):
144
+ with open(index_json_file, "r") as f:
145
+ index = json.load(f)
146
+ model_name = index["weight_map"][head_weight_name]
147
+ else:
148
+ model_name = "model.safetensors"
149
+ with safe_open(os.path.join(lm_head_dir, model_name), framework="pt") as f:
150
+ loaded_weight = f.get_tensor(head_weight_name)
151
+ loaded_weight = loaded_weight.to(dtype=self.lm_head.weight.dtype, device=self.lm_head.weight.device)
152
+ self.lm_head.weight.data.copy_(loaded_weight)
153
+
154
+ def patchify(self, img: torch.Tensor):
155
+ """
156
+ img: (bsz, C, H, W)
157
+ x: (bsz, H * W / patch_size**2, patch_size**2 * C)
158
+ """
159
+ bsz, c, h, w = img.shape
160
+ p = self.config.latent_patch_size
161
+ h_, w_ = h // p, w // p
162
+
163
+ img = img.reshape(bsz, c, h_, p, w_, p)
164
+ img = torch.einsum("nchpwq->nhwcpq", img)
165
+ x = img.reshape(bsz, h_ * w_, c * p**2)
166
+ return x
167
+
168
+ def unpatchify(self, x: torch.Tensor, h: int = None, w: int = None):
169
+ """
170
+ x: (bsz, H * W / patch_size**2, patch_size**2 * C)
171
+ img: (bsz, C, H, W)
172
+ """
173
+ bsz = x.shape[0]
174
+ p = self.config.latent_patch_size
175
+ c = self.config.latent_channels
176
+ if h is None and w is None:
177
+ h_ = w_ = int(x.shape[1] ** 0.5)
178
+ else:
179
+ h_, w_ = h, w
180
+ assert h_ * w_ == x.shape[1], f"Invalid sequence length {x.shape[1]}."
181
+
182
+ x = x.reshape(bsz, h_, w_, c, p, p)
183
+ x = torch.einsum("nhwcpq->nchpwq", x)
184
+ img = x.reshape(bsz, c, h_ * p, w_ * p)
185
+ return img
186
+
187
+ def prepare_inputs_embeds(self, input_ids: torch.LongTensor | None = None, latents: torch.FloatTensor | None = None):
188
+ if latents is None:
189
+ if not self.training:
190
+ return self.embed_tokens(input_ids)
191
+ else: # dummy forward for the image pass, to keep gradient shapes consistent.
192
+ raise NotImplementedError("Dummy forward for image pass is not implemented.")
193
+ else:
194
+ bs, seq_length = input_ids.shape
195
+ inputs_embeds = torch.zeros(
196
+ (bs, seq_length, self.config.hidden_size),
197
+ device=self.embed_tokens.weight.device,
198
+ dtype=self.embed_tokens.weight.dtype,
199
+ )
200
+ im_indices = input_ids == self.config.image_placeholder_id
201
+ lm_indices = ~im_indices
202
+
203
+ if isinstance(latents, list):
204
+ tokens = torch.cat([self.patchify(latent) for latent in latents], dim=1)
205
+ else:
206
+ tokens = self.patchify(latents)
207
+ # tokens = tokens.reshape(1, -1, tokens.shape[-1])
208
+
209
+ image_embeds = self.image_in_projector(tokens)
210
+ image_embeds = image_embeds.view(-1, self.config.hidden_size)
211
+
212
+ token_embeds = self.embed_tokens(input_ids[lm_indices])
213
+
214
+ inputs_embeds[im_indices] = image_embeds.to(inputs_embeds.dtype)
215
+ inputs_embeds[lm_indices] = token_embeds
216
+
217
+ return inputs_embeds
218
+
219
+ def _update_causal_mask(
220
+ self,
221
+ attention_mask: torch.Tensor,
222
+ input_tensor: torch.Tensor,
223
+ cache_position: torch.Tensor,
224
+ past_key_values: Cache,
225
+ output_attentions: bool,
226
+ ):
227
+ if self.config._attn_implementation == "flash_attention_2":
228
+ if attention_mask is not None and (attention_mask == 0.0).any():
229
+ return attention_mask
230
+ return None
231
+
232
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
233
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
234
+ # to infer the attention mask.
235
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
236
+ using_static_cache = isinstance(past_key_values, StaticCache)
237
+
238
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
239
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
240
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
241
+ attention_mask,
242
+ inputs_embeds=input_tensor,
243
+ past_key_values_length=past_seen_tokens,
244
+ is_training=self.training,
245
+ ):
246
+ return None
247
+
248
+ dtype, device = input_tensor.dtype, input_tensor.device
249
+ sequence_length = input_tensor.shape[1]
250
+ if using_static_cache:
251
+ target_length = past_key_values.get_max_cache_shape()
252
+ else:
253
+ target_length = (
254
+ attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1
255
+ )
256
+
257
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
258
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
259
+ attention_mask,
260
+ sequence_length=sequence_length,
261
+ target_length=target_length,
262
+ dtype=dtype,
263
+ device=device,
264
+ cache_position=cache_position,
265
+ batch_size=input_tensor.shape[0],
266
+ )
267
+
268
+ if (
269
+ self.config._attn_implementation == "sdpa"
270
+ and attention_mask is not None
271
+ and attention_mask.device.type == "cuda"
272
+ and not output_attentions
273
+ ):
274
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
275
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
276
+ # Details: https://github.com/pytorch/pytorch/issues/110213
277
+ min_dtype = torch.finfo(dtype).min
278
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
279
+
280
+ return causal_mask
281
+
282
+ @staticmethod
283
+ def _prepare_4d_causal_attention_mask_with_cache_position(
284
+ attention_mask: torch.Tensor,
285
+ sequence_length: int,
286
+ target_length: int,
287
+ dtype: torch.dtype,
288
+ device: torch.device,
289
+ cache_position: torch.Tensor,
290
+ batch_size: int,
291
+ **kwargs,
292
+ ):
293
+ """
294
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
295
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
296
+
297
+ Args:
298
+ attention_mask (`torch.Tensor`):
299
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
300
+ `(batch_size, 1, query_length, key_value_length)`.
301
+ sequence_length (`int`):
302
+ The sequence length being processed.
303
+ target_length (`int`):
304
+ The target length: when generating with static cache, the mask should be as long as the static cache,
305
+ to account for the 0 padding, the part of the cache that is not filled yet.
306
+ dtype (`torch.dtype`):
307
+ The dtype to use for the 4D attention mask.
308
+ device (`torch.device`):
309
+ The device to place the 4D attention mask on.
310
+ cache_position (`torch.Tensor`):
311
+ Indices depicting the position of the input sequence tokens in the sequence.
312
+ batch_size (`torch.Tensor`):
313
+ Batch size.
314
+ """
315
+ if attention_mask is not None and attention_mask.dim() == 4:
316
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
317
+ causal_mask = attention_mask
318
+ else:
319
+ min_dtype = torch.finfo(dtype).min
320
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
321
+ if sequence_length != 1:
322
+ causal_mask = torch.triu(causal_mask, diagonal=1)
323
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
324
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
325
+ if attention_mask is not None:
326
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
327
+ mask_length = attention_mask.shape[-1]
328
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
329
+ padding_mask = padding_mask == 0
330
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
331
+
332
+ return causal_mask
333
+
334
+ @smart_compile()
335
+ def forward_model(
336
+ self,
337
+ inputs_embeds: torch.FloatTensor | None = None,
338
+ attention_mask: torch.Tensor | None = None,
339
+ past_key_values: Cache | list[torch.FloatTensor] | None = None,
340
+ use_cache: bool | None = None,
341
+ output_attentions: bool | None = None,
342
+ output_hidden_states: bool | None = None,
343
+ cache_position: torch.LongTensor | None = None,
344
+ ) -> tuple | BaseModelOutputWithPast:
345
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
346
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
347
+
348
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
349
+ if self.gradient_checkpointing and self.training and use_cache:
350
+ use_cache = False
351
+
352
+ if use_cache and past_key_values is None:
353
+ past_key_values = DynamicCache()
354
+
355
+ if cache_position is None:
356
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
357
+ cache_position = torch.arange(
358
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
359
+ )
360
+ position_ids = cache_position.unsqueeze(0)
361
+
362
+ causal_mask = self._update_causal_mask(
363
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
364
+ )
365
+ hidden_states = inputs_embeds
366
+
367
+ # create position embeddings to be shared across the decoder layers
368
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
369
+
370
+ # decoder layers
371
+ all_hidden_states = () if output_hidden_states else None
372
+ all_self_attns = () if output_attentions else None
373
+
374
+ for decoder_layer in self.layers:
375
+ if output_hidden_states:
376
+ all_hidden_states += (hidden_states,)
377
+
378
+ if self.gradient_checkpointing and self.training:
379
+ layer_outputs = self._gradient_checkpointing_func(
380
+ decoder_layer.__call__,
381
+ hidden_states,
382
+ causal_mask,
383
+ past_key_values,
384
+ output_attentions,
385
+ use_cache,
386
+ cache_position,
387
+ position_embeddings,
388
+ )
389
+ else:
390
+ layer_outputs = decoder_layer(
391
+ hidden_states,
392
+ attention_mask=causal_mask,
393
+ past_key_value=past_key_values,
394
+ output_attentions=output_attentions,
395
+ use_cache=use_cache,
396
+ cache_position=cache_position,
397
+ position_embeddings=position_embeddings,
398
+ )
399
+
400
+ hidden_states = layer_outputs[0]
401
+
402
+ if output_attentions:
403
+ all_self_attns += (layer_outputs[1],)
404
+
405
+ hidden_states = self.norm(hidden_states)
406
+
407
+ # add hidden states from the last decoder layer
408
+ if output_hidden_states:
409
+ all_hidden_states += (hidden_states,)
410
+
411
+ return BaseModelOutputWithPast(
412
+ last_hidden_state=hidden_states,
413
+ past_key_values=past_key_values if use_cache else None,
414
+ hidden_states=all_hidden_states,
415
+ attentions=all_self_attns,
416
+ )
417
+
418
+
419
+
420
+ def prepare_inputs_for_generation(
421
+ self,
422
+ input_ids: torch.LongTensor,
423
+ past_key_values: Cache | None = None,
424
+ attention_mask: torch.LongTensor | None = None,
425
+ inputs_embeds: torch.FloatTensor | None = None,
426
+ cache_position: torch.LongTensor | None = None,
427
+ **kwargs,
428
+ ):
429
+ """
430
+ Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
431
+ slicing inputs given the existing cache.
432
+
433
+ See the forward pass in the model documentation for expected arguments (different models might have different
434
+ requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
435
+ """
436
+
437
+ # 1. Handle BC:
438
+ model_inputs = {}
439
+ # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`)
440
+ if self._supports_cache_class:
441
+ model_inputs["cache_position"] = cache_position
442
+ # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this
443
+ # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly
444
+ # (this alternative is not as robust as calling `generate` and letting it create `cache_position`)
445
+ elif cache_position is None:
446
+ past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
447
+ cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
448
+
449
+ # 2. Generic cache-dependent input preparation
450
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
451
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
452
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
453
+ # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
454
+ if past_key_values is not None:
455
+ model_inputs["past_key_values"] = past_key_values
456
+ if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
457
+ input_ids = input_ids[:, -cache_position.shape[0] :]
458
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
459
+ input_ids = input_ids[:, cache_position]
460
+
461
+ # 3. Prepare base model inputs
462
+ input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
463
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
464
+ if not self.config.is_encoder_decoder:
465
+ if inputs_embeds is not None and cache_position[0] == 0:
466
+ model_inputs[input_ids_key] = None
467
+ model_inputs["inputs_embeds"] = inputs_embeds
468
+ else:
469
+ # `clone` calls in this function ensure a consistent stride. See #32227
470
+ model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
471
+ model_inputs["inputs_embeds"] = None
472
+ else:
473
+ model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
474
+
475
+ # 4. Create missing `position_ids` on the fly
476
+ if (
477
+ attention_mask is not None
478
+ and kwargs.get("position_ids") is None
479
+ and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
480
+ ):
481
+ position_ids = attention_mask.long().cumsum(-1) - 1
482
+ position_ids.masked_fill_(attention_mask == 0, 1)
483
+ kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below)
484
+
485
+ # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
486
+ for model_input_name in ["position_ids", "token_type_ids"]:
487
+ model_input = kwargs.get(model_input_name)
488
+ if model_input is not None:
489
+ if past_key_values:
490
+ model_input = model_input[:, -input_ids.shape[1] :]
491
+ model_input = model_input.clone(memory_format=torch.contiguous_format)
492
+ model_inputs[model_input_name] = model_input
493
+
494
+ # 6. Create 4D attention mask if we are using a `StaticCache` (important for performant compiled forward pass)
495
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
496
+ if model_inputs["inputs_embeds"] is not None:
497
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
498
+ device = model_inputs["inputs_embeds"].device
499
+ else:
500
+ batch_size, sequence_length = model_inputs[input_ids_key].shape
501
+ device = model_inputs[input_ids_key].device
502
+
503
+ # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
504
+ # the 4D causal mask exists, it should be present in the base model (XXXModel class).
505
+ base_model = getattr(self, self.base_model_prefix, None)
506
+ if base_model is None:
507
+ causal_mask_creation_function = getattr(self, "_prepare_4d_causal_attention_mask_with_cache_position", None)
508
+ else:
509
+ causal_mask_creation_function = getattr(
510
+ base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
511
+ )
512
+ if causal_mask_creation_function is None:
513
+ logger.warning_once(
514
+ f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
515
+ "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
516
+ "writing code, see Llama for an example implementation. If you're a user, please report this "
517
+ "issue on GitHub."
518
+ )
519
+ else:
520
+ attention_mask = causal_mask_creation_function(
521
+ attention_mask,
522
+ sequence_length=sequence_length,
523
+ target_length=past_key_values.get_max_cache_shape(),
524
+ dtype=self.dtype,
525
+ device=device,
526
+ cache_position=cache_position,
527
+ batch_size=batch_size,
528
+ config=self.config,
529
+ past_key_values=past_key_values,
530
+ )
531
+ if attention_mask is not None:
532
+ model_inputs["attention_mask"] = attention_mask
533
+
534
+ # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
535
+ for key, value in kwargs.items():
536
+ if key not in model_inputs:
537
+ model_inputs[key] = value
538
+
539
+ # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
540
+ model_inputs.pop("labels", None)
541
+ return model_inputs
542
+
543
+ @torch.no_grad()
544
+ def generate(self, inputs: torch.LongTensor = None, **kwargs):
545
+ input_ids = kwargs.pop("input_ids")
546
+ latents = kwargs.pop("latents", None)
547
+ inputs_embeds = self.prepare_inputs_embeds(input_ids, latents)
548
+ return super().generate(inputs=inputs, input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
549
+
550
+ def gradient_checkpointing_enable(self, **kwargs):
551
+ super().gradient_checkpointing_enable(**kwargs)
552
+
553
+ self.image_head.net.grad_checkpointing = True
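As a sanity check on the `patchify`/`unpatchify` pair defined above, here is a standalone sketch of the same einsum-based reshaping. The values `latent_patch_size=2` and `latent_channels=16` are assumptions for illustration; the real values come from `NextStepConfig`.

```python
# Standalone sketch of the patchify/unpatchify round trip used by NextStep.
import torch

p, c = 2, 16                       # assumed latent_patch_size, latent_channels
latents = torch.randn(1, c, 8, 8)  # (bsz, C, H, W)

def patchify(img, p):
    bsz, c, h, w = img.shape
    h_, w_ = h // p, w // p
    img = img.reshape(bsz, c, h_, p, w_, p)
    img = torch.einsum("nchpwq->nhwcpq", img)
    return img.reshape(bsz, h_ * w_, c * p**2)

def unpatchify(x, p, c):
    bsz = x.shape[0]
    h_ = w_ = int(x.shape[1] ** 0.5)
    x = x.reshape(bsz, h_, w_, c, p, p)
    x = torch.einsum("nhwcpq->nchpwq", x)
    return x.reshape(bsz, c, h_ * p, w_ * p)

tokens = patchify(latents, p)           # (1, 16, 64): 4x4 patches, each 2*2*16 values
recon = unpatchify(tokens, p, c)
print(tokens.shape, torch.equal(recon, latents))  # torch.Size([1, 16, 64]) True
```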
pytorch-model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:909083ea5deb26b37d66d4966869b0c3310c06aa72a88974c293d0fe62a489b9
3
+ size 9962132680
pytorch-model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:630af8946d2326406253ede6e3cb143d935ba19595059ce56af86fd50442e2d3
3
+ size 9909693448
pytorch-model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47a71e20a18992b9008bae6fa523b69824c1e67b5656bbd8fa6b442bf7405c72
3
+ size 8478742432
pytorch-model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b5f5fe250f5cbce219aadb61c9a44903739e510a66184cc960bfce87175bc34
3
+ size 1557135464
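The four `pytorch-model-*.safetensors` entries above are Git LFS pointer stubs (oid + size), not the tensors themselves. Once the shards have actually been fetched (for example with `git lfs pull` or the `hf` CLI), their contents can be listed with `safetensors`; a minimal sketch, assuming the shard sits in the working directory:

```python
# Minimal sketch: listing tensor names in one downloaded weight shard.
# Assumes the LFS pointer has been resolved to the real file locally.
from safetensors import safe_open

with safe_open("pytorch-model-00001-of-00004.safetensors", framework="pt", device="cpu") as f:
    names = list(f.keys())
    print(len(names), names[:3])
```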
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ diffusers==0.34.0
2
+ einops==0.8.1
3
+ gradio==5.42.0
4
+ loguru==0.7.3
5
+ numpy==1.26.4
6
+ omegaconf==2.3.0
7
+ Pillow==11.0.0
8
+ Requests==2.32.4
9
+ safetensors==0.5.3
10
+ tabulate==0.9.0
11
+ torch==2.5.1
12
+ torchvision==0.20.1
13
+ tqdm==4.67.1
14
+ transformers==4.55.0
special_tokens_map.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|image_area|>",
4
+ "<|begin_of_image|>",
5
+ "<|end_of_image|>",
6
+ "<|image_placeholder|>",
7
+ "<|begin_of_prompt_refinement|>",
8
+ "<|end_of_prompt_refinement|>",
9
+ "<|begin_of_thinking|>",
10
+ "<|end_of_thinking|>",
11
+ "<|beginoftext|>"
12
+ ],
13
+ "eos_token": {
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "pad_token": {
21
+ "content": "[PAD]",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ }
27
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:310b48c809fba04c32e7f7cdac4d0fb1c00140d8914e0b0163307f64e5330a92
3
+ size 11423853
tokenizer_config.json ADDED
@@ -0,0 +1,285 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "[PAD]",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<|image_area|>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "151667": {
198
+ "content": "<|begin_of_image|>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": true
204
+ },
205
+ "151668": {
206
+ "content": "<|end_of_image|>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": true
212
+ },
213
+ "151669": {
214
+ "content": "<|image_placeholder|>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "151670": {
222
+ "content": "<|begin_of_prompt_refinement|>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "151671": {
230
+ "content": "<|end_of_prompt_refinement|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": true
236
+ },
237
+ "151672": {
238
+ "content": "<|begin_of_thinking|>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": false,
242
+ "single_word": false,
243
+ "special": true
244
+ },
245
+ "151673": {
246
+ "content": "<|end_of_thinking|>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": false,
250
+ "single_word": false,
251
+ "special": true
252
+ },
253
+ "151674": {
254
+ "content": "<|beginoftext|>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": false,
258
+ "single_word": false,
259
+ "special": true
260
+ }
261
+ },
262
+ "additional_special_tokens": [
263
+ "<|image_area|>",
264
+ "<|begin_of_image|>",
265
+ "<|end_of_image|>",
266
+ "<|image_placeholder|>",
267
+ "<|begin_of_prompt_refinement|>",
268
+ "<|end_of_prompt_refinement|>",
269
+ "<|begin_of_thinking|>",
270
+ "<|end_of_thinking|>",
271
+ "<|beginoftext|>"
272
+ ],
273
+ "bos_token": null,
274
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
275
+ "clean_up_tokenization_spaces": false,
276
+ "eos_token": "<|endoftext|>",
277
+ "errors": "replace",
278
+ "extra_special_tokens": {},
279
+ "model_max_length": 8192,
280
+ "pad_token": "[PAD]",
281
+ "padding_side": "right",
282
+ "split_special_tokens": false,
283
+ "tokenizer_class": "Qwen2Tokenizer",
284
+ "unk_token": null
285
+ }
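The table above pins the extra image/thinking tokens to fixed ids and defines a standard Qwen2-style `chat_template`. A hedged sketch of loading this tokenizer and rendering the template (purely illustrative; the loading kwargs are assumptions, and the printed id follows the mapping shown above):

```python
# Hedged sketch: inspecting the special tokens and chat template configured above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("stepfun-ai/NextStep-1.1-Pretrain", trust_remote_code=True)

print(tokenizer.eos_token, tokenizer.pad_token)                  # <|endoftext|> [PAD]
print(tokenizer.convert_tokens_to_ids("<|image_placeholder|>"))  # 151669 per the table above

messages = [{"role": "user", "content": "A watercolor painting of a lighthouse at dawn"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# A watercolor painting of a lighthouse at dawn<|im_end|>
# <|im_start|>assistant
```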
utils/__pycache__/compile_utils.cpython-310.pyc ADDED
Binary file (2.72 kB)
utils/__pycache__/image_utils.cpython-310.pyc ADDED
Binary file (8.46 kB)
utils/__pycache__/misc.cpython-310.pyc ADDED
Binary file (2.03 kB)
utils/__pycache__/model_utils.cpython-310.pyc ADDED
Binary file (4.38 kB)
utils/aspect_ratio.py ADDED
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import PIL.Image
3
+
4
+ ANY_ASPECT_RATIO = (0, 0)
5
+
6
+ HW_ASPECT_RATIOS = [
7
+ (8, 32), # 256
8
+ (9, 28), # 252
9
+ (10, 25), # 250
10
+ (11, 23), # 253
11
+ (12, 21), # 252
12
+ (13, 19), # 247
13
+ (14, 18), # 252
14
+ (15, 17), # 255
15
+ (16, 16), # 256
16
+ (17, 15), # 255
17
+ (18, 14), # 252
18
+ (19, 13), # 247
19
+ (21, 12), # 252
20
+ (23, 11), # 253
21
+ (25, 10), # 250
22
+ (28, 9), # 252
23
+ (32, 8), # 256
24
+ ]
25
+
26
+
27
+ def get_ar_base(ars: list[tuple[int, int]] = HW_ASPECT_RATIOS):
28
+ sqrt_products = [round(np.sqrt(h * w)) for h, w in ars]
29
+ return round(np.mean(sqrt_products))
30
+
31
+
32
+ def ar2str(h: int, w: int) -> str:
33
+ return f"{h}*{w}"
34
+
35
+
36
+ def str2ar(s: str) -> tuple[int, int]:
37
+ return tuple(map(int, s.split("*")))
38
+
39
+ def center_crop_arr_with_buckets(pil_image, ars: list[tuple[int, int]] = HW_ASPECT_RATIOS, crop=True, buckets: list[int] = [256, 512, 768, 1024]):
40
+ """
41
+ Center crop the image to match the closest aspect ratio from the provided list.
42
+
43
+ Args:
44
+ pil_image: PIL Image to be cropped
45
+ image_size: Target size for the smaller dimension
46
+ ars: List of aspect ratios as (height, width) tuples
47
+
48
+ Returns:
49
+ PIL Image cropped to the closest aspect ratio
50
+ """
51
+ # ar_base = get_ar_base(ars)
52
+ # Get current image dimensions
53
+ width, height = pil_image.size
54
+
55
+ buckets = sorted(buckets, reverse=True)
56
+ image_size = buckets[-1]
57
+
58
+ for bucket in buckets:
59
+ if width * height >= bucket * bucket:
60
+ image_size = bucket
61
+ break
62
+
63
+ return center_crop_arr_with_ar(pil_image, image_size, ars, crop)
64
+
65
+ def center_crop_arr_with_ar(pil_image, image_size: int, ars: list[tuple[int, int]] = HW_ASPECT_RATIOS, crop=True):
66
+ """
67
+ Center crop the image to match the closest aspect ratio from the provided list.
68
+
69
+ Args:
70
+ pil_image: PIL Image to be cropped
71
+ image_sizes: Target size for the smaller dimension
72
+ ars: List of aspect ratios as (height, width) tuples
73
+
74
+ Returns:
75
+ PIL Image cropped to the closest aspect ratio
76
+ """
77
+
78
+ ar_base = get_ar_base(ars)
79
+ assert image_size % ar_base == 0, f"image_size must be divisible by {ar_base}"
80
+
81
+ # Get current image dimensions
82
+ width, height = pil_image.size
83
+
84
+ current_ar = height / width
85
+
86
+ # Find the closest aspect ratio
87
+ closest_ar_idx = np.argmin([abs(current_ar - (h / w)) for h, w in ars])
88
+ target_h, target_w = ars[closest_ar_idx]
89
+
90
+ if crop:
91
+ target_h, target_w = round(image_size / ar_base * target_h), round(image_size / ar_base * target_w)
92
+
93
+ # First, resize the image while maintaining aspect ratio to ensure the smaller dimension is at least the target size
94
+ scale = max(target_h / height, target_w / width)
95
+ new_height = round(height * scale)
96
+ new_width = round(width * scale)
97
+ pil_image = pil_image.resize((new_width, new_height), resample=PIL.Image.LANCZOS)
98
+
99
+ arr = np.array(pil_image)
100
+ # Then perform center crop to the target dimensions
101
+ crop_y = (new_height - target_h) // 2
102
+ crop_x = (new_width - target_w) // 2
103
+
104
+ return PIL.Image.fromarray(arr[crop_y : crop_y + target_h, crop_x : crop_x + target_w])
105
+ else:
106
+ scale = image_size // ar_base
107
+ return pil_image.resize((round(target_w * scale), round(target_h * scale)), resample=PIL.Image.LANCZOS)
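A usage sketch for the bucketing helpers above. The 1200x800 input is an arbitrary example, and the import assumes the repo root is on `PYTHONPATH`; the expected outputs follow from the ratio table and the default 256/512/768/1024 buckets.

```python
# Hedged usage sketch for the aspect-ratio bucketing helpers above.
import PIL.Image
from utils.aspect_ratio import HW_ASPECT_RATIOS, center_crop_arr_with_buckets, get_ar_base

print(get_ar_base(HW_ASPECT_RATIOS))     # 16: each (h, w) pair multiplies to roughly 16*16

img = PIL.Image.new("RGB", (1200, 800))  # PIL size is (width, height), so h/w = 2/3
out = center_crop_arr_with_buckets(img)  # area >= 768^2, so the 768 bucket is chosen
print(out.size)                          # (912, 624): the (13, 19) ratio scaled by 768 / 16
```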
utils/compile_utils.py ADDED
@@ -0,0 +1,122 @@
1
+ import contextlib
2
+ import functools
3
+ import os
4
+ from typing import Callable, Dict, Optional
5
+
6
+ import torch
7
+
8
+ from loguru import logger
9
+
10
+ """
11
+ Usage:
12
+
13
+ 1. Control through environment variable (at startup):
14
+ export ENABLE_TORCH_COMPILE=true
15
+ python your_script.py
16
+
17
+ 2. Control through environment variable (disable):
18
+ export ENABLE_TORCH_COMPILE=false # or not set
19
+ python your_script.py
20
+
21
+ 3. Dynamically control in code:
22
+ compile_manager.set_compile_enabled(True) # enable
23
+ compile_manager.set_compile_enabled(False) # disable
24
+
25
+ 4. Select version at runtime:
26
+ # use the version configured
27
+ result = my_function(args)
28
+
29
+ # force use the original version
30
+ result = my_function.original(args)
31
+
32
+ # force use the compiled version
33
+ result = my_function.compiled(args)
34
+ """
35
+
36
+ # Global configuration: control whether to enable compile through environment variables
37
+ # Compilation is disabled by default; set the env var to "true" to enable it
38
+ ENABLE_TORCH_COMPILE = os.getenv("ENABLE_TORCH_COMPILE", "false").lower() == "true"
39
+
40
+
41
+ class CompileManager:
42
+ """Global controller for torch.compile"""
43
+
44
+ def __init__(self):
45
+ self.compile_enabled = ENABLE_TORCH_COMPILE
46
+ self.compiled_functions: Dict[str, Callable] = {}
47
+ self.original_functions: Dict[str, Callable] = {}
48
+
49
+ def set_compile_enabled(self, enabled: bool):
50
+ """Dynamic setting of whether to enable compile"""
51
+ self.compile_enabled = enabled
52
+
53
+ def get_compile_status(self):
54
+ """Get the current compile status"""
55
+ return self.compile_enabled
56
+
57
+ @contextlib.contextmanager
58
+ def compile_disabled(self):
59
+ """Temporarily disable compile within the context"""
60
+ original_status = self.compile_enabled
61
+ try:
62
+ self.compile_enabled = False
63
+ yield
64
+ finally:
65
+ self.compile_enabled = original_status
66
+
67
+
68
+ # global instance
69
+ compile_manager = CompileManager()
70
+
71
+
72
+ def smart_compile(func: Optional[Callable] = None, **compile_kwargs):
73
+ """
74
+ Smart compile decorator
75
+
76
+ Args:
77
+ func: The function to decorate
78
+ **compile_kwargs: Other compile parameters, see https://pytorch.org/docs/stable/generated/torch.compile.html
79
+ """
80
+
81
+ def decorator(fn: Callable) -> Callable:
82
+ # save the original function
83
+ original_func = fn
84
+ # Use qualified name to handle functions with same name in different classes
85
+ # Include module name to handle functions with same name in different files
86
+ func_name = f"{fn.__module__}.{fn.__qualname__}"
87
+ compile_manager.original_functions[func_name] = original_func
88
+
89
+ # if compile is disabled, return the original function
90
+ if not compile_manager.compile_enabled:
91
+ # add attributes to the original function for later access
92
+ original_func.original = original_func
93
+ original_func.compiled = original_func # point to itself
94
+ return original_func
95
+
96
+ # create the compiled function
97
+ try:
98
+ compiled_func = torch.compile(original_func, **compile_kwargs)
99
+ compile_manager.compiled_functions[func_name] = compiled_func
100
+ except Exception as e:
101
+ logger.warning(f"[WARNING] Failed to compile function {func_name}: {e}")
102
+ # if compile fails, revert to the original function
103
+ compiled_func = original_func
104
+
105
+ @functools.wraps(original_func)
106
+ def wrapper(*args, **kwargs):
107
+ # check whether to use the compiled version at runtime
108
+ if compile_manager.compile_enabled:
109
+ return compiled_func(*args, **kwargs)
110
+ else:
111
+ return original_func(*args, **kwargs)
112
+
113
+ # add attributes to the wrapper for later access
114
+ wrapper.original = original_func
115
+ wrapper.compiled = compiled_func
116
+
117
+ return wrapper
118
+
119
+ # support direct use of @smart_compile or @smart_compile(...)
120
+ if func is not None:
121
+ return decorator(func)
122
+ return decorator
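A small sketch of the decorator in action. `scaled_residual` is a made-up example function, and whether compilation actually happens depends on `ENABLE_TORCH_COMPILE` at import time.

```python
# Hedged sketch of smart_compile usage; scaled_residual is an illustrative function, not repo code.
import torch
from utils.compile_utils import compile_manager, smart_compile

@smart_compile(dynamic=True)
def scaled_residual(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + 0.5 * y

x, y = torch.randn(4, 8), torch.randn(4, 8)
out = scaled_residual(x, y)                 # compiled only if ENABLE_TORCH_COMPILE=true was set
out_eager = scaled_residual.original(x, y)  # always the uncompiled version
print(torch.allclose(out, out_eager))       # True

with compile_manager.compile_disabled():
    _ = scaled_residual(x, y)               # forced eager inside this context
```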
utils/image_utils.py ADDED
@@ -0,0 +1,314 @@
1
+ import io
2
+ import os
3
+ from typing import Literal, TypeAlias
4
+
5
+ import numpy as np
6
+ import PIL.Image
7
+ import PIL.ImageOps
8
+ import requests
9
+ import torch
10
+
11
+ """
12
+ - pil: `PIL.Image.Image`, size (w, h), seamless conversion between `uint8`
13
+ - np: `np.ndarray`, shape (h, w, c), default `np.uint8`
14
+ - pt: `torch.Tensor`, shape (c, h, w), default `torch.uint8`
15
+ """
16
+ ImageType: TypeAlias = PIL.Image.Image | np.ndarray | torch.Tensor
17
+ ImageTypeStr: TypeAlias = Literal["pil", "np", "pt"]
18
+ ImageFormat: TypeAlias = Literal["JPEG", "PNG"]
19
+ DataFormat: TypeAlias = Literal["255", "01", "11"]
20
+
21
+
22
+ IMG_SUPPORT_MODE = ["L", "LA", "RGB", "RGBA", "CMYK", "P", "1"]
23
+ IMAGE_EXT_LOWER = ["png", "jpeg", "jpg", "webp"]
24
+ IMAGE_EXT = IMAGE_EXT_LOWER + [_ext.upper() for _ext in IMAGE_EXT_LOWER]
25
+
26
+
27
+ def check_image_type(image: ImageType):
28
+ if not (isinstance(image, PIL.Image.Image) or isinstance(image, np.ndarray) or isinstance(image, torch.Tensor)):
29
+ raise TypeError(f"`image` should be PIL Image, ndarray or Tensor. Got `{type(image)}`.")
30
+
31
+
32
+ def to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
33
+ # Automatically adjust the orientation of the image to match the direction it was taken.
34
+ image = PIL.ImageOps.exif_transpose(image)
35
+
36
+ if image.mode not in IMG_SUPPORT_MODE:
37
+ raise ValueError(f"Only support mode in `{IMG_SUPPORT_MODE}`, got `{image.mode}`")
38
+
39
+ if image.mode == "LA":
40
+ image = image.convert("RGBA")
41
+
42
+ # add white background for RGBA images, and convert to RGB
43
+ if image.mode == "RGBA":
44
+ background = PIL.Image.new("RGBA", image.size, "white")
45
+ image = PIL.Image.alpha_composite(background, image).convert("RGB")
46
+
47
+ # then convert to RGB
48
+ image = image.convert("RGB")
49
+
50
+ return image
51
+
52
+
53
+ def load_image(
54
+ image: str | os.PathLike | PIL.Image.Image | bytes,
55
+ *,
56
+ output_type: ImageTypeStr = "pil",
57
+ ) -> ImageType:
58
+ """
59
+ Loads `image` to a PIL Image, NumPy array or PyTorch tensor.
60
+
61
+ Args:
62
+ image (str | os.PathLike | PIL.Image.Image | bytes): The path or URL of the image, a PIL Image, or raw image bytes.
64
+ Inputs are EXIF-transposed and converted to RGB (LA/RGBA images are alpha-composited
65
+ over a white background) before the output-type conversion.
65
+ output_type (ImageTypeStr, optional): The type of the output image. Defaults to "pil".
66
+ The current version supports "pil", "np", "pt".
67
+
68
+ Returns:
69
+ ImageType: The loaded image in the given type.
70
+ """
71
+ timeout = 10
72
+ # Load the `image` into a PIL Image.
73
+ if isinstance(image, str) or isinstance(image, os.PathLike):
74
+ if str(image).startswith(("http://", "https://")):  # str() so os.PathLike inputs are handled too
75
+ try:
76
+ image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
77
+ except requests.exceptions.Timeout:
78
+ raise ValueError(f"HTTP request timed out after {timeout} seconds")
79
+ elif os.path.isfile(image):
80
+ image = PIL.Image.open(image)
81
+ else:
82
+ raise ValueError(
83
+ f"Incorrect path or URL: URLs must start with `http://` or `https://`, and `{image}` is not a valid local file path."
84
+ )
85
+ elif isinstance(image, PIL.Image.Image):
86
+ image = image
87
+ elif isinstance(image, bytes):
88
+ image = PIL.Image.open(io.BytesIO(image))
89
+ else:
90
+ raise ValueError(f"`image` must be a path or PIL Image, got `{type(image)}`")
91
+
92
+ image = to_rgb(image)
93
+
94
+ if output_type == "pil":
95
+ image = image
96
+ elif output_type == "np":
97
+ image = to_np(image)
98
+ elif output_type == "pt":
99
+ image = to_pt(image)
100
+ else:
101
+ raise ValueError(f"`output_type` must be one of `{ImageTypeStr}`, got `{output_type}`")
102
+
103
+ return image
104
+
105
+
106
+ def to_pil(image: ImageType, image_mode: DataFormat | None = None) -> PIL.Image.Image:
107
+ """
108
+ Convert a NumPy array or a PyTorch tensor to a PIL image.
109
+ """
110
+ check_image_type(image)
111
+
112
+ if isinstance(image, PIL.Image.Image):
113
+ return image
114
+
115
+ elif isinstance(image, np.ndarray):
116
+ image = normalize_np(image, image_mode)
117
+
118
+ elif isinstance(image, torch.Tensor):
119
+ image = normalize_pt(image, image_mode)
120
+
121
+ image = image.cpu().permute(1, 2, 0).numpy()
122
+ assert image.dtype == np.uint8, f"Supposed to convert `torch.uint8` to `np.uint8`, but got `{image.dtype}`"
123
+
124
+ mode_map = {1: "L", 3: "RGB"}
125
+ mode = mode_map[image.shape[-1]]
126
+
127
+ if image.shape[-1] == 1:
128
+ image = image[:, :, 0]
129
+
130
+ return PIL.Image.fromarray(image, mode=mode)
131
+
132
+
133
+ def to_np(image: ImageType, image_mode: DataFormat | None = None) -> np.ndarray:
134
+ """
135
+ Convert a PIL image or a PyTorch tensor to a NumPy array.
136
+ """
137
+ check_image_type(image)
138
+
139
+ if isinstance(image, PIL.Image.Image):
140
+ image = np.array(image, np.uint8, copy=True)
141
+
142
+ if isinstance(image, np.ndarray):
143
+ image = normalize_np(image, image_mode)
144
+
145
+ elif isinstance(image, torch.Tensor):
146
+ image = normalize_pt(image, image_mode)
147
+
148
+ image = image.cpu().permute(1, 2, 0).numpy()
149
+ assert image.dtype == np.uint8, f"Supposed to convert `torch.uint8` to `np.uint8`, but got `{image.dtype}`"
150
+
151
+ return image
152
+
153
+
154
+ def to_pt(image: ImageType, image_mode: DataFormat | None = None) -> torch.Tensor:
155
+ """
156
+ Convert a PIL image or a NumPy array to a PyTorch tensor.
157
+ """
158
+ check_image_type(image)
159
+
160
+ if isinstance(image, torch.Tensor):
161
+ image = normalize_pt(image, image_mode)
162
+ return image
163
+
164
+ # convert PIL Image to NumPy array
165
+ if isinstance(image, PIL.Image.Image):
166
+ image = np.array(image, np.uint8, copy=True)
167
+
168
+ image = normalize_np(image, image_mode)
169
+
170
+ image = torch.from_numpy(image.transpose((2, 0, 1))).contiguous()
171
+ assert image.dtype == torch.uint8, f"Supposed to convert `np.uint8` to `torch.uint8`, but got `{image.dtype}`"
172
+ return image
173
+
174
+
175
+ def normalize_np(image: np.ndarray, image_mode: DataFormat | None = None) -> np.ndarray:
176
+ """
177
+ Normalize a NumPy array to the standard format of shape (h, w, c) and uint8.
178
+ """
179
+ if image.ndim not in {2, 3}:
180
+ raise ValueError(f"`image` should be 2 or 3 dimensions. Got {image.ndim} dimensions.")
181
+
182
+ elif image.ndim == 2:
183
+ # if 2D image, add channel dimension (HWC)
184
+ image = np.expand_dims(image, 2)
185
+
186
+ if image.shape[-1] not in {1, 3}:
187
+ raise ValueError(f"`image` should have 1 (`L`) or 3 (`RGB`) channels. Got {image.shape[-1]} channels.")
188
+
189
+ image = to_dataformat(image, image_mode=image_mode, mode="255")
190
+
191
+ return image
192
+
193
+
194
+ def normalize_pt(image: torch.Tensor, image_mode: DataFormat | None = None) -> torch.Tensor:
195
+ """
196
+ Normalize a PyTorch tensor to the standard format of shape (c, h, w) and uint8.
197
+ """
198
+ if image.ndimension() not in {2, 3}:
199
+ raise ValueError(f"`image` should be 2 or 3 dimensions. Got {image.ndimension()} dimensions.")
200
+
201
+ elif image.ndimension() == 2:
202
+ # if 2D image, add channel dimension (CHW)
203
+ image = image.unsqueeze(0)
204
+
205
+ # check number of channels
206
+ if image.shape[-3] not in {1, 3}:
207
+ raise ValueError(f"`image` should have 1 (`L`) or 3 (`RGB`) channels. Got {image.shape[-3]} channels.")
208
+
209
+ image = to_dataformat(image, image_mode=image_mode, mode="255")
210
+
211
+ return image
212
+
213
+
214
+ def to_dataformat(
215
+ image: ImageType,
216
+ *,
217
+ image_mode: DataFormat | None = None,
218
+ mode: DataFormat = "255",
219
+ ) -> np.ndarray | torch.Tensor:
220
+ check_image_type(image)
221
+
222
+ # convert PIL Image to NumPy array
223
+ if isinstance(image, PIL.Image.Image):
224
+ image = np.array(image, np.uint8, copy=True)
225
+ image_mode = "255"
226
+
227
+ # guess image mode
228
+ if image.dtype == np.uint8 or image.dtype == torch.uint8:
229
+ guess_image_mode = "255"
230
+ elif image.dtype == np.float32 or image.dtype == np.float16 or image.dtype == torch.float32 or image.dtype == torch.float16:
231
+ if image.min() < 0.0:
232
+ guess_image_mode = "11"
233
+ else:
234
+ guess_image_mode = "01"
235
+ else:
236
+ raise ValueError(f"Unsupported dtype `{image.dtype}`")
237
+
238
+ if image_mode is None:
239
+ image_mode = guess_image_mode
240
+ else:
241
+ if guess_image_mode != image_mode:
242
+ print(f"Guess image mode is `{guess_image_mode}`, but image mode is `{image_mode}`")
243
+
244
+ if isinstance(image, np.ndarray):
245
+ if image_mode == "255" and mode != "255":
246
+ image = np.clip(image.astype(np.float32) / 255, 0, 1)  # assign; clipping into the uint8 buffer would fail
247
+ if mode == "11":
248
+ image = np.clip(image * 2 - 1, -1, 1)
249
+
250
+ elif image_mode == "01" and mode != "01":
251
+ if mode == "255":
252
+ np.clip(image, 0, 1, out=image)
253
+ image = (image * 255).round().astype(np.uint8)
254
+ elif mode == "11":
255
+ np.clip((image * 2 - 1), -1, 1, out=image)
256
+
257
+ elif image_mode == "11" and mode != "11":
258
+ np.clip((image / 2 + 0.5), 0, 1, out=image)
259
+ if mode == "255":
260
+ image = (image * 255).round().astype(np.uint8)
261
+
262
+ elif isinstance(image, torch.Tensor):
263
+ if image_mode == "255" and mode != "255":
264
+ image = image.to(dtype=torch.float32).div(255).clamp(0, 1)
265
+ if mode == "11":
266
+ image = (image * 2 - 1).clamp(-1, 1)
267
+
268
+ elif image_mode == "01" and mode != "01":
269
+ if mode == "255":
270
+ image = image.clamp(0, 1)
271
+ image = (image * 255).round().to(dtype=torch.uint8)
272
+ elif mode == "11":
273
+ image = (image * 2 - 1).clamp(-1, 1)
274
+
275
+ elif image_mode == "11" and mode != "11":
276
+ image = (image / 2 + 0.5).clamp(0, 1)
277
+ if mode == "255":
278
+ image = image.mul(255).round().to(dtype=torch.uint8)
279
+
280
+ return image
281
+
282
+
283
+ def resize_image(pil_image, image_size):
284
+ while min(*pil_image.size) >= 2 * image_size:
285
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=PIL.Image.BOX)
286
+
287
+ scale = image_size / min(*pil_image.size)
288
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=PIL.Image.BICUBIC)
289
+ return pil_image
290
+
291
+
292
+ def center_crop_arr(pil_image, image_size, crop=True):
293
+ """
294
+ Center cropping implementation from ADM.
295
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
296
+ """
297
+ if crop:
298
+ pil_image = resize_image(pil_image, image_size)
299
+ arr = np.array(pil_image)
300
+ crop_y = (arr.shape[0] - image_size) // 2
301
+ crop_x = (arr.shape[1] - image_size) // 2
302
+ return PIL.Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
303
+ else:
304
+ # Pad the image to a square
305
+ width, height = pil_image.size
306
+ if width != height:
307
+ # Create a square canvas whose side is the larger of width/height
308
+ max_dim = max(width, height)
309
+ padded_img = PIL.Image.new(pil_image.mode, (max_dim, max_dim), (0, 0, 0))
310
+ # Paste the original image centered on the square canvas
311
+ padded_img.paste(pil_image, ((max_dim - width) // 2, (max_dim - height) // 2))
312
+ pil_image = padded_img
313
+ pil_image = resize_image(pil_image, image_size)
314
+ return pil_image
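
A quick round-trip through the three supported representations, as a sanity check of the helpers above; the synthetic image and sizes are arbitrary examples:

```python
import numpy as np
import PIL.Image

from utils.image_utils import center_crop_arr, load_image, to_np, to_pil, to_pt

# A synthetic 300x200 RGB image stands in for a real file, URL, or bytes object.
rgb = PIL.Image.fromarray(np.random.randint(0, 256, (200, 300, 3), dtype=np.uint8))

pil_img = load_image(rgb)              # EXIF-transposed, forced to RGB
np_img = to_np(pil_img)                # (h, w, c), np.uint8
pt_img = to_pt(np_img)                 # (c, h, w), torch.uint8
print(np_img.shape, pt_img.shape)      # (200, 300, 3) torch.Size([3, 200, 300])

# Back to PIL, then center-crop to a 256x256 patch.
restored = to_pil(pt_img)
cropped = center_crop_arr(restored, 256)
print(cropped.size)                    # (256, 256)
```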
utils/misc.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ import numpy as np
3
+ import random
4
+
5
+ import torch
6
+
7
+
8
+ def set_seed(seed: int, rank: int = 0):
9
+ random.seed(seed + rank)
10
+ np.random.seed(seed + rank)
11
+ torch.manual_seed(seed + rank)
12
+ torch.cuda.manual_seed_all(seed + rank)
13
+ torch.backends.cudnn.deterministic = True
14
+ os.environ["PYTHONHASHSEED"] = str(seed + rank)
15
+
16
+ class LargeInt(int):
17
+ def __new__(cls, value):
18
+ if isinstance(value, str):
19
+ units = {"K": 1e3, "M": 1e6, "B": 1e9, "T": 1e12}
20
+ last_char = value[-1].upper()
21
+ if last_char in units:
22
+ num = float(value[:-1]) * units[last_char]
23
+ return super(LargeInt, cls).__new__(cls, int(num))
24
+ else:
25
+ return super(LargeInt, cls).__new__(cls, int(value))
26
+ else:
27
+ return super(LargeInt, cls).__new__(cls, value)
28
+
29
+ def __str__(self):
30
+ value = int(self)
31
+ if abs(value) < 1000:
32
+ return f"{value}"
33
+ for unit in ["", "K", "M", "B", "T"]:
34
+ if abs(value) < 1000:
35
+ return f"{value:.1f}{unit}"
36
+ value /= 1000
37
+ return f"{value:.1f}P" # P stands for Peta, or 10^15
38
+
39
+ def __repr__(self):
40
+ return f'"{self.__str__()}"' # Ensure repr also returns the string with quotes
41
+
42
+ def __json__(self):
43
+ return f'"{self.__str__()}"'
44
+
45
+ def __add__(self, other):
46
+ if isinstance(other, int):
47
+ return LargeInt(super().__add__(other))
48
+ return NotImplemented
49
+
50
+ def __radd__(self, other):
51
+ return self.__add__(other) # This ensures commutativity
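
`LargeInt` parses human-readable magnitudes such as "1.4B" into exact integers and prints them back in the same compact form; it is what `trainable_params` and `params_info` in `vae/nextstep_ae.py` (below) use to report parameter counts. A small illustration:

```python
from utils.misc import LargeInt, set_seed

set_seed(42)            # seeds python `random`, numpy, and torch in one call

n = LargeInt("1.4B")
print(int(n))           # 1400000000
print(n)                # 1.4B
print(n + 600_000_000)  # 2.0B -- addition with plain ints returns another LargeInt
```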
utils/model_utils.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0):
6
+ """
7
+ grid_size: int of the grid height and width
8
+ return:
9
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
10
+ """
11
+ grid_h = np.arange(grid_size, dtype=np.float32) / pe_interpolation
12
+ grid_w = np.arange(grid_size, dtype=np.float32) / pe_interpolation
13
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
14
+ grid = np.stack(grid, axis=0)
15
+
16
+ grid = grid.reshape([2, 1, grid_size, grid_size])
17
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
18
+ if cls_token and extra_tokens > 0:
19
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
20
+ return pos_embed
21
+
22
+
23
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
24
+ assert embed_dim % 2 == 0
25
+
26
+ # use half of dimensions to encode grid_h
27
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
28
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
29
+
30
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
31
+ return emb
32
+
33
+
34
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
35
+ """
36
+ embed_dim: output dimension for each position
37
+ pos: a list of positions to be encoded: size (M,)
38
+ out: (M, D)
39
+ """
40
+ assert embed_dim % 2 == 0
41
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
42
+ omega /= embed_dim / 2.0
43
+ omega = 1.0 / 10000**omega # (D/2,)
44
+
45
+ pos = pos.reshape(-1) # (M,)
46
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
47
+
48
+ emb_sin = np.sin(out) # (M, D/2)
49
+ emb_cos = np.cos(out) # (M, D/2)
50
+
51
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
52
+ return emb
53
+
54
+
55
+ def expand_t(t, x):
56
+ """Function to reshape time t to broadcastable dimension of x
57
+ Args:
58
+ t: [bsz,], time vector
59
+ x: [bsz,...], data point
60
+ """
61
+ dims = [1] * (len(x.size()) - 1)
62
+ t = t.view(t.size(0), *dims)
63
+ return t
64
+
65
+
66
+ def randn_tensor(shape, noise_repeat, device, dtype=torch.float32):
67
+ bsz = shape[0]
68
+ if bsz % noise_repeat != 0:
69
+ raise ValueError(f"Batch size ({bsz}) must be divisible by noise repeat ({noise_repeat})")
70
+ _shape = (noise_repeat,) + shape[1:]
71
+ _tensor = torch.randn(_shape, device=device, dtype=dtype).repeat(bsz // noise_repeat, *([1] * (len(shape) - 1)))  # repeat along the batch dim only, so N-D shapes work
72
+ return _tensor
73
+
74
+
75
+ def rotate_half(x):
76
+ """Rotates half the hidden dims of the input."""
77
+ x1 = x[..., : x.shape[-1] // 2]
78
+ x2 = x[..., x.shape[-1] // 2 :]
79
+ return torch.cat((-x2, x1), dim=-1)
80
+
81
+
82
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
83
+ cos = cos.unsqueeze(unsqueeze_dim)
84
+ sin = sin.unsqueeze(unsqueeze_dim)
85
+ q_embed = (q * cos) + (rotate_half(q) * sin)
86
+ k_embed = (k * cos) + (rotate_half(k) * sin)
87
+ return q_embed, k_embed
88
+
89
+
90
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
91
+ """
92
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
93
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
94
+ """
95
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
96
+ if n_rep == 1:
97
+ return hidden_states
98
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
99
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
100
+
101
+
102
+ def identity(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:
103
+ return input
104
+
105
+
106
+ def rms_norm(
107
+ input: torch.Tensor,
108
+ normalized_shape: torch.Size,
109
+ eps: float = 1e-6,
110
+ ) -> torch.Tensor:
111
+ dtype = input.dtype
112
+ input = input.to(torch.float32)
113
+ variance = input.flatten(-len(normalized_shape)).pow(2).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
114
+ input = input * torch.rsqrt(variance + eps)
115
+ return input.to(dtype)
116
+
117
+
118
+ def layer_norm(
119
+ input: torch.Tensor,
120
+ normalized_shape: torch.Size,
121
+ eps: float = 1e-6,
122
+ ) -> torch.Tensor:
123
+ dtype = input.dtype
124
+ input = input.to(torch.float32)
125
+ mean = input.flatten(-len(normalized_shape)).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
126
+ variance = (input - mean).flatten(-len(normalized_shape)).pow(2).mean(dim=-1)[(...,) + (None,) * len(normalized_shape)]
127
+ input = (input - mean) * torch.rsqrt(variance + eps)
128
+ return input.to(dtype)
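
These helpers are small, shape-oriented building blocks; the snippet below only confirms the shapes they produce (grid size, head counts, and dimensions are arbitrary examples):

```python
import torch

from utils.model_utils import get_2d_sincos_pos_embed, repeat_kv, rms_norm

# 2D sin-cos positional embedding for a 16x16 token grid with 64 dims per token.
pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16)
print(pos.shape)                        # (256, 64)

# Grouped-query attention: expand 2 KV heads to match 8 query heads.
kv = torch.randn(1, 2, 10, 32)          # (batch, kv_heads, seq_len, head_dim)
print(repeat_kv(kv, n_rep=4).shape)     # torch.Size([1, 8, 10, 32])

# Functional RMSNorm over the last dimension, computed in float32.
x = torch.randn(4, 128)
print(rms_norm(x, x.shape[-1:]).shape)  # torch.Size([4, 128])
```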
vae/__pycache__/nextstep_ae.cpython-310.pyc ADDED
Binary file (15.1 kB). View file
 
vae/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99293255229a29297e2851858db3794497d1b0b09b20c308c1062636ea4bcdd9
3
+ size 335365010
vae/config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "resolution": 256,
3
+ "in_channels": 3,
4
+ "ch": 128,
5
+ "out_ch": 3,
6
+ "ch_mult": [1, 2, 4, 4],
7
+ "num_res_blocks": 2,
8
+ "z_channels": 16,
9
+ "shift_factor": 0,
10
+ "scaling_factor": 1,
11
+ "deterministic": true,
12
+ "encoder_norm": true,
13
+ "psz": 1
14
+ }
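
These values are merged over the `AutoEncoderParams` defaults by `AutoencoderKL.from_pretrained` in `vae/nextstep_ae.py` (below), so the released VAE runs with a deterministic posterior, unit scaling, zero shift, and a normalized encoder. Roughly, the merge amounts to the following sketch (it assumes the repository root is on `PYTHONPATH` and uses `dataclasses.fields` in place of the signature inspection done in the actual loader):

```python
import json
from dataclasses import fields

from vae.nextstep_ae import AutoEncoderParams

with open("vae/config.json") as f:
    cfg = json.load(f)

# Keep only keys that AutoEncoderParams actually defines, as from_pretrained does.
known = {f.name for f in fields(AutoEncoderParams)}
params = AutoEncoderParams(**{k: v for k, v in cfg.items() if k in known})
print(params.deterministic, params.scaling_factor, params.shift_factor)  # True 1 0
```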
vae/nextstep_ae.py ADDED
@@ -0,0 +1,494 @@
1
+ import os
2
+ import json
3
+ import inspect
4
+ from dataclasses import dataclass, field, asdict
5
+ from loguru import logger
6
+ from omegaconf import OmegaConf
7
+ from tabulate import tabulate
8
+ from einops import rearrange
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
17
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
18
+
19
+ from utils.misc import LargeInt
20
+ from utils.model_utils import randn_tensor
21
+ from utils.compile_utils import smart_compile
22
+
23
+
24
+ @dataclass
25
+ class AutoEncoderParams:
26
+ resolution: int = 256
27
+ in_channels: int = 3
28
+ ch: int = 128
29
+ out_ch: int = 3
30
+ ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
31
+ num_res_blocks: int = 2
32
+ z_channels: int = 16
33
+ scaling_factor: float = 0.3611
34
+ shift_factor: float = 0.1159
35
+ deterministic: bool = False
36
+ encoder_norm: bool = False
37
+ psz: int | None = None
38
+
39
+
40
+ def swish(x: Tensor) -> Tensor:
41
+ return x * torch.sigmoid(x)
42
+
43
+
44
+ class AttnBlock(nn.Module):
45
+ def __init__(self, in_channels: int):
46
+ super().__init__()
47
+ self.in_channels = in_channels
48
+
49
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
50
+
51
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
52
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
53
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
54
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
55
+
56
+ def attention(self, h_: Tensor) -> Tensor:
57
+ h_ = self.norm(h_)
58
+ q = self.q(h_)
59
+ k = self.k(h_)
60
+ v = self.v(h_)
61
+
62
+ b, c, h, w = q.shape
63
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
64
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
65
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
66
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
67
+
68
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
69
+
70
+ def forward(self, x: Tensor) -> Tensor:
71
+ return x + self.proj_out(self.attention(x))
72
+
73
+
74
+ class ResnetBlock(nn.Module):
75
+ def __init__(self, in_channels: int, out_channels: int):
76
+ super().__init__()
77
+ self.in_channels = in_channels
78
+ out_channels = in_channels if out_channels is None else out_channels
79
+ self.out_channels = out_channels
80
+
81
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
82
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
83
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
84
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
85
+ if self.in_channels != self.out_channels:
86
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
87
+
88
+ def forward(self, x):
89
+ h = x
90
+ h = self.norm1(h)
91
+ h = swish(h)
92
+ h = self.conv1(h)
93
+
94
+ h = self.norm2(h)
95
+ h = swish(h)
96
+ h = self.conv2(h)
97
+
98
+ if self.in_channels != self.out_channels:
99
+ x = self.nin_shortcut(x)
100
+
101
+ return x + h
102
+
103
+
104
+ class Downsample(nn.Module):
105
+ def __init__(self, in_channels: int):
106
+ super().__init__()
107
+ # no asymmetric padding in torch conv, must do it ourselves
108
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
109
+
110
+ def forward(self, x: Tensor):
111
+ pad = (0, 1, 0, 1)
112
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
113
+ x = self.conv(x)
114
+ return x
115
+
116
+
117
+ class Upsample(nn.Module):
118
+ def __init__(self, in_channels: int):
119
+ super().__init__()
120
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
121
+
122
+ def forward(self, x: Tensor):
123
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
124
+ x = self.conv(x)
125
+ return x
126
+
127
+
128
+ class Encoder(nn.Module):
129
+ def __init__(
130
+ self,
131
+ resolution: int,
132
+ in_channels: int,
133
+ ch: int,
134
+ ch_mult: list[int],
135
+ num_res_blocks: int,
136
+ z_channels: int,
137
+ ):
138
+ super().__init__()
139
+ self.ch = ch
140
+ self.num_resolutions = len(ch_mult)
141
+ self.num_res_blocks = num_res_blocks
142
+ self.resolution = resolution
143
+ self.in_channels = in_channels
144
+ # downsampling
145
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
146
+
147
+ curr_res = resolution
148
+ in_ch_mult = (1,) + tuple(ch_mult)
149
+ self.in_ch_mult = in_ch_mult
150
+ self.down = nn.ModuleList()
151
+ block_in = self.ch
152
+ for i_level in range(self.num_resolutions):
153
+ block = nn.ModuleList()
154
+ attn = nn.ModuleList()
155
+ block_in = ch * in_ch_mult[i_level]
156
+ block_out = ch * ch_mult[i_level]
157
+ for _ in range(self.num_res_blocks):
158
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
159
+ block_in = block_out
160
+ down = nn.Module()
161
+ down.block = block
162
+ down.attn = attn
163
+ if i_level != self.num_resolutions - 1:
164
+ down.downsample = Downsample(block_in)
165
+ curr_res = curr_res // 2
166
+ self.down.append(down)
167
+
168
+ # middle
169
+ self.mid = nn.Module()
170
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
171
+ self.mid.attn_1 = AttnBlock(block_in)
172
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
173
+
174
+ # end
175
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
176
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
177
+
178
+ self.grad_checkpointing = False
179
+
180
+ @smart_compile()
181
+ def forward(self, x: Tensor) -> Tensor:
182
+ # downsampling
183
+ hs = [self.conv_in(x)]
184
+ for i_level in range(self.num_resolutions):
185
+ for i_block in range(self.num_res_blocks):
186
+ block_fn = self.down[i_level].block[i_block]
187
+ if self.grad_checkpointing:
188
+ h = checkpoint(block_fn, hs[-1])
189
+ else:
190
+ h = block_fn(hs[-1])
191
+ if len(self.down[i_level].attn) > 0:
192
+ attn_fn = self.down[i_level].attn[i_block]
193
+ if self.grad_checkpointing:
194
+ h = checkpoint(attn_fn, h)
195
+ else:
196
+ h = attn_fn(h)
197
+ hs.append(h)
198
+ if i_level != self.num_resolutions - 1:
199
+ hs.append(self.down[i_level].downsample(hs[-1]))
200
+
201
+ # middle
202
+ h = hs[-1]
203
+ h = self.mid.block_1(h)
204
+ h = self.mid.attn_1(h)
205
+ h = self.mid.block_2(h)
206
+ # end
207
+ h = self.norm_out(h)
208
+ h = swish(h)
209
+ h = self.conv_out(h)
210
+ return h
211
+
212
+
213
+ class Decoder(nn.Module):
214
+ def __init__(
215
+ self,
216
+ ch: int,
217
+ out_ch: int,
218
+ ch_mult: list[int],
219
+ num_res_blocks: int,
220
+ in_channels: int,
221
+ resolution: int,
222
+ z_channels: int,
223
+ ):
224
+ super().__init__()
225
+ self.ch = ch
226
+ self.num_resolutions = len(ch_mult)
227
+ self.num_res_blocks = num_res_blocks
228
+ self.resolution = resolution
229
+ self.in_channels = in_channels
230
+ self.ffactor = 2 ** (self.num_resolutions - 1)
231
+
232
+ # compute in_ch_mult, block_in and curr_res at lowest res
233
+ block_in = ch * ch_mult[self.num_resolutions - 1]
234
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
235
+ self.z_shape = (1, z_channels, curr_res, curr_res)
236
+
237
+ # z to block_in
238
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
239
+
240
+ # middle
241
+ self.mid = nn.Module()
242
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
243
+ self.mid.attn_1 = AttnBlock(block_in)
244
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
245
+
246
+ # upsampling
247
+ self.up = nn.ModuleList()
248
+ for i_level in reversed(range(self.num_resolutions)):
249
+ block = nn.ModuleList()
250
+ attn = nn.ModuleList()
251
+ block_out = ch * ch_mult[i_level]
252
+ for _ in range(self.num_res_blocks + 1):
253
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
254
+ block_in = block_out
255
+ up = nn.Module()
256
+ up.block = block
257
+ up.attn = attn
258
+ if i_level != 0:
259
+ up.upsample = Upsample(block_in)
260
+ curr_res = curr_res * 2
261
+ self.up.insert(0, up) # prepend to get consistent order
262
+
263
+ # end
264
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
265
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
266
+
267
+ self.grad_checkpointing = False
268
+
269
+ @smart_compile()
270
+ def forward(self, z: Tensor) -> Tensor:
271
+ # get dtype for proper tracing
272
+ upscale_dtype = next(self.up.parameters()).dtype
273
+
274
+ # z to block_in
275
+ h = self.conv_in(z)
276
+
277
+ # middle
278
+ h = self.mid.block_1(h)
279
+ h = self.mid.attn_1(h)
280
+ h = self.mid.block_2(h)
281
+
282
+ # cast to proper dtype
283
+ h = h.to(upscale_dtype)
284
+ # upsampling
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ for i_block in range(self.num_res_blocks + 1):
287
+ block_fn = self.up[i_level].block[i_block]
288
+ if self.grad_checkpointing:
289
+ h = checkpoint(block_fn, h)
290
+ else:
291
+ h = block_fn(h)
292
+ if len(self.up[i_level].attn) > 0:
293
+ attn_fn = self.up[i_level].attn[i_block]
294
+ if self.grad_checkpointing:
295
+ h = checkpoint(attn_fn, h)
296
+ else:
297
+ h = attn_fn(h)
298
+ if i_level != 0:
299
+ h = self.up[i_level].upsample(h)
300
+
301
+ # end
302
+ h = self.norm_out(h)
303
+ h = swish(h)
304
+ h = self.conv_out(h)
305
+ return h
306
+
307
+
308
+ def layer_norm_2d(input: torch.Tensor, normalized_shape: torch.Size, eps: float = 1e-6) -> torch.Tensor:
309
+ # input.shape = (bsz, c, h, w)
310
+ _input = input.permute(0, 2, 3, 1)
311
+ _input = F.layer_norm(_input, normalized_shape, None, None, eps)
312
+ _input = _input.permute(0, 3, 1, 2)
313
+ return _input
314
+
315
+
316
+ class AutoencoderKL(nn.Module):
317
+ def __init__(self, params: AutoEncoderParams):
318
+ super().__init__()
319
+ self.config = params
320
+ self.config = OmegaConf.create(asdict(self.config))
321
+ self.config.latent_channels = params.z_channels
322
+ self.config.block_out_channels = params.ch_mult
323
+
324
+ self.params = params
325
+ self.encoder = Encoder(
326
+ resolution=params.resolution,
327
+ in_channels=params.in_channels,
328
+ ch=params.ch,
329
+ ch_mult=params.ch_mult,
330
+ num_res_blocks=params.num_res_blocks,
331
+ z_channels=params.z_channels,
332
+ )
333
+ self.decoder = Decoder(
334
+ resolution=params.resolution,
335
+ in_channels=params.in_channels,
336
+ ch=params.ch,
337
+ out_ch=params.out_ch,
338
+ ch_mult=params.ch_mult,
339
+ num_res_blocks=params.num_res_blocks,
340
+ z_channels=params.z_channels,
341
+ )
342
+
343
+ self.encoder_norm = params.encoder_norm
344
+ self.psz = params.psz
345
+
346
+ self.apply(self._init_weights)
347
+
348
+ def _init_weights(self, module):
349
+ std = 0.02
350
+ if isinstance(module, (nn.Conv2d, nn.Linear)):
351
+ module.weight.data.normal_(mean=0.0, std=std)
352
+ if module.bias is not None:
353
+ module.bias.data.zero_()
354
+ elif isinstance(module, nn.GroupNorm):
355
+ if module.weight is not None:
356
+ module.weight.data.fill_(1.0)
357
+ if module.bias is not None:
358
+ module.bias.data.zero_()
359
+
360
+ def gradient_checkpointing_enable(self):
361
+ self.encoder.grad_checkpointing = True
362
+ self.decoder.grad_checkpointing = True
363
+
364
+ @property
365
+ def dtype(self):
366
+ return self.encoder.conv_in.weight.dtype
367
+
368
+ @property
369
+ def device(self):
370
+ return self.encoder.conv_in.weight.device
371
+
372
+ @property
373
+ def trainable_params(self) -> float:
374
+ n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
375
+ return LargeInt(n_params)
376
+
377
+ @property
378
+ def params_info(self) -> str:
379
+ encoder_params = str(LargeInt(sum(p.numel() for p in self.encoder.parameters())))
380
+ decoder_params = str(LargeInt(sum(p.numel() for p in self.decoder.parameters())))
381
+ table = [["encoder", encoder_params], ["decoder", decoder_params]]
382
+ return tabulate(table, headers=["Module", "Params"], tablefmt="grid")
383
+
384
+ def get_last_layer(self):
385
+ return self.decoder.conv_out.weight
386
+
387
+ def patchify(self, img: torch.Tensor):
388
+ """
389
+ img: (bsz, C, H, W)
390
+ x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size)
391
+ """
392
+ bsz, c, h, w = img.shape
393
+ p = self.psz
394
+ h_, w_ = h // p, w // p
395
+
396
+ img = img.reshape(bsz, c, h_, p, w_, p)
397
+ img = torch.einsum("nchpwq->ncpqhw", img)
398
+ x = img.reshape(bsz, c * p**2, h_, w_)
399
+ return x
400
+
401
+ def unpatchify(self, x: torch.Tensor):
402
+ """
403
+ x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size)
404
+ img: (bsz, C, H, W)
405
+ """
406
+ bsz = x.shape[0]
407
+ p = self.psz
408
+ c = self.config.latent_channels
409
+ h_, w_ = x.shape[2], x.shape[3]
410
+
411
+ x = x.reshape(bsz, c, p, p, h_, w_)
412
+ x = torch.einsum("ncpqhw->nchpwq", x)
413
+ img = x.reshape(bsz, c, h_ * p, w_ * p)
414
+ return img
415
+
416
+ def encode(self, x: torch.Tensor, return_dict: bool = True):
417
+ moments = self.encoder(x)
418
+
419
+ mean, logvar = torch.chunk(moments, 2, dim=1)
420
+ if self.psz is not None:
421
+ mean = self.patchify(mean)
422
+
423
+ if self.encoder_norm:
424
+ mean = layer_norm_2d(mean, mean.size()[1:2])  # normalize over the channel dim (channels-last inside layer_norm_2d)
425
+
426
+ if self.psz is not None:
427
+ mean = self.unpatchify(mean)
428
+
429
+ moments = torch.cat([mean, logvar], dim=1).contiguous()
430
+
431
+ posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic)
432
+
433
+ if not return_dict:
434
+ return (posterior,)
435
+
436
+ return AutoencoderKLOutput(latent_dist=posterior)
437
+
438
+ def decode(self, z: torch.Tensor, return_dict: bool = True):
439
+ dec = self.decoder(z)
440
+
441
+ if not return_dict:
442
+ return (dec,)
443
+
444
+ return DecoderOutput(sample=dec)
445
+
446
+ def forward(self, input, sample_posterior=True, noise_strength=0.0):
447
+ posterior = self.encode(input).latent_dist
448
+ z = posterior.sample() if sample_posterior else posterior.mode()
449
+ if noise_strength > 0.0:
450
+ p = torch.distributions.Uniform(0, noise_strength)
451
+ z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
452
+ z.shape, 1, device=z.device, dtype=z.dtype
453
+ )
454
+ dec = self.decode(z).sample
455
+ return dec, posterior
456
+
457
+ @classmethod
458
+ def from_pretrained(cls, model_path, **kwargs):
459
+ config_path = os.path.join(model_path, "config.json")
460
+ ckpt_path = os.path.join(model_path, "checkpoint.pt")
461
+
462
+ if not os.path.isdir(model_path) or not os.path.isfile(config_path) or not os.path.isfile(ckpt_path):
463
+ raise ValueError(
464
+ f"Invalid model path: {model_path}. The path should contain both config.json and checkpoint.pt files."
465
+ )
466
+
467
+ state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
468
+
469
+ with open(config_path, "r") as f:
470
+ config: dict = json.load(f)
471
+ config.update(kwargs)
472
+ kwargs = config
473
+
474
+ # Filter out kwargs that are not in AutoEncoderParams
475
+ # This ensures we only pass parameters that the model can accept
476
+ valid_kwargs = {}
477
+ param_signature = inspect.signature(AutoEncoderParams.__init__).parameters
478
+ for key, value in kwargs.items():
479
+ if key in param_signature:
480
+ valid_kwargs[key] = value
481
+ else:
482
+ logger.info(f"Ignoring parameter '{key}' as it's not defined in AutoEncoderParams")
483
+
484
+ params = AutoEncoderParams(**valid_kwargs)
485
+ model = cls(params)
486
+ try:
487
+ msg = model.load_state_dict(state_dict, strict=False)
488
+ logger.info(f"Loaded state_dict from {ckpt_path}")
489
+ logger.info(f"Missing keys:\n{msg.missing_keys}")
490
+ logger.info(f"Unexpected keys:\n{msg.unexpected_keys}")
491
+ except Exception as e:
492
+ logger.error(e)
493
+ logger.warning(f"Failed to load state_dict from {ckpt_path}, using random initialization")
494
+ return model
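
Putting the pieces together, here is a minimal sketch that loads the VAE from the `vae/` folder fetched during setup and round-trips a single image through `encode`/`decode`; the image path is a placeholder, and because the config sets `deterministic: true`, sampling the posterior reduces to its mean.

```python
import torch

from utils.image_utils import center_crop_arr, load_image, to_pil, to_pt
from vae.nextstep_ae import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"

# "vae/" must contain config.json and checkpoint.pt (downloaded in the setup step).
vae = AutoencoderKL.from_pretrained("vae").to(device).eval()

# Placeholder path; any RGB image works. Center-crop to the native 256 resolution.
img = to_pt(center_crop_arr(load_image("example.jpg"), 256))  # (3, 256, 256), uint8
x = (img.float() / 255 * 2 - 1).unsqueeze(0).to(device)       # (1, 3, 256, 256) in [-1, 1]

with torch.no_grad():
    z = vae.encode(x).latent_dist.sample()                    # (1, 16, 32, 32)
    rec = vae.decode(z).sample                                 # (1, 3, 256, 256)

rec_u8 = ((rec[0].clamp(-1, 1) + 1) / 2 * 255).round().to(torch.uint8)
to_pil(rec_u8.cpu()).save("reconstruction.png")
```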
vocab.json ADDED
The diff for this file is too large to render. See raw diff