huangrh9 committed
Commit 16ac099 · verified · 1 Parent(s): 52c0b51

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
aspect_ratio_utils.py ADDED
@@ -0,0 +1,196 @@
1
+ from torchvision import transforms
2
+ import numpy as np
3
+ import math
4
+ import torch
5
+ from PIL import Image, ImageOps
6
+
7
+
8
+ RATIOS = [
9
+ (512, 512),
10
+ (384, 512),
11
+ (512, 384),
12
+ (384, 768),
13
+ (768, 384),
14
+ (384, 576),
15
+ (576, 384),
16
+ (320, 960),
17
+ (960, 320),
18
+ (256, 1024),
19
+ (1024, 256),
20
+ ]
21
+
22
+ RATIO_TYPES = [
23
+ 'ratio_h512_w512',
24
+ 'ratio_h384_w512',
25
+ 'ratio_h512_w384',
26
+ 'ratio_h384_w768',
27
+ 'ratio_h768_w384',
28
+ 'ratio_h384_w576',
29
+ 'ratio_h576_w384',
30
+ 'ratio_h320_w960',
31
+ 'ratio_h960_w320',
32
+ 'ratio_h256_w1024',
33
+ 'ratio_h1024_w256',
34
+ ]
35
+
36
+
37
+ def center_crop_and_resize(img, output_size=(256, 256)):
38
+ target_h, target_w = output_size
39
+ img_w, img_h = img.size
40
+
41
+ scale_w, scale_h = img_w / target_w, img_h / target_h
42
+ if scale_h > scale_w:
43
+ new_w, new_h = target_w, int(target_w / img_w * img_h)
44
+ else:
45
+ new_w, new_h = int(target_h / img_h * img_w), target_h
46
+
47
+ # Resize the image, keeping the aspect ratio
48
+ img = img.resize((new_w, new_h), Image.LANCZOS)
49
+
50
+ # Calculate the center cropping area
51
+ left = (new_w - target_w) // 2
52
+ top = (new_h - target_h) // 2
53
+ right = left + target_w
54
+ bottom = top + target_h
55
+
56
+ # Crop the extra part
57
+ img = img.crop((left, top, right, bottom))
58
+
59
+ return img
60
+
61
+ def resize_with_padding(img, output_size=(256, 256), fill_color=(0, 0, 0)):
62
+ target_height, target_width = output_size
63
+
64
+ # Step 1: Resize with aspect ratio preserved
65
+ original_width, original_height = img.size
66
+ ratio = min(target_width / original_width, target_height / original_height)
67
+ new_size = (int(original_width * ratio), int(original_height * ratio))
68
+ resized_image = img.resize(new_size, Image.LANCZOS)
69
+
70
+ # Step 2: Add padding to reach target size
71
+ delta_w = target_width - new_size[0]
72
+ delta_h = target_height - new_size[1]
73
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
74
+ padded_image = ImageOps.expand(resized_image, padding, fill=fill_color)
75
+
76
+ return padded_image
77
+
78
+
79
+ def unpad_and_resize_back(padded_image, original_width, original_height):
80
+ """
81
+ Revert the padded+resized image back to original size.
82
+
83
+ Args:
84
+ padded_image (PIL.Image): Image after padding.
85
+ original_width (int): Original image width before resize & pad.
86
+ original_height (int): Original image height before resize & pad.
87
+
88
+ Returns:
89
+ PIL.Image: Image resized back to original resolution.
90
+ """
91
+ # Compute the scale factor used during the first resize
92
+ target_width, target_height = padded_image.size
93
+ ratio = min(target_width / original_width, target_height / original_height)
94
+ resized_w = int(original_width * ratio)
95
+ resized_h = int(original_height * ratio)
96
+
97
+ # Compute cropping box on padded image
98
+ left = (target_width - resized_w) // 2
99
+ upper = (target_height - resized_h) // 2
100
+ right = left + resized_w
101
+ lower = upper + resized_h
102
+
103
+ # Crop out the resized region (before padding)
104
+ cropped_image = padded_image.crop((left, upper, right, lower))
105
+
106
+ # Resize back to original resolution
107
+ recovered_image = cropped_image.resize((original_width, original_height), Image.LANCZOS)
108
+ return recovered_image
109
+
110
+
111
+ def calculate_ratio():
112
+ max_area = 512 * 512
113
+ ratios = [(2, 2), (3, 4), (4, 3), (2, 4), (4, 2), (1, 4), (4, 1), (2, 3), (3, 2), (1, 3), (3, 1)]
114
+ ratio_candidates = []
115
+ for ratio in ratios:
116
+ x = math.sqrt(max_area / ratio[0] / ratio[1])
117
+ x = round(x / 64) * 64
118
+ tmp = (x*ratio[0], x*ratio[1])
119
+ # print(ratio, x, tmp)
120
+ ratio_candidates.append(tmp)
121
+
122
+ print("ratio_candicates", ratio_candicates)
123
+ return ratio_candidates
124
+
125
+
126
+ class AspectRatioCrop(object):
127
+ """
128
+ Aspect Ratio Crop transform.
129
+ For a given image, find the corresponding aspect ratio and
130
+ resize / resize + crop to the corresponding base sizes
131
+
132
+ Args:
133
+ base_sizes: list[tuple], the base sizes of final output.
134
+ For example, [(512, 512), (512, 768), (768, 512)]
135
+
136
+ crop_percent_thresh: float. If the fraction of image area that must be cropped to reach the base size exceeds this threshold, the sample is flagged as unmatched.
137
+ """
138
+
139
+ def __init__(self, base_sizes, crop_percent_thresh=0.2):
140
+ self.base_sizes = [(math.floor(h), math.floor(w)) for (h, w) in base_sizes]
141
+ self.aspect_ratios = [x[1] / x[0] for x in self.base_sizes] # w / h
142
+ self.crop_percent_thresh = crop_percent_thresh
143
+
144
+ def _find_size(self, w, h):
145
+ base_size_indexes = list(range(len(self.base_sizes)))
146
+ aspect_ratios = [self.aspect_ratios[i] for i in base_size_indexes]
147
+ aspect_ratio = w / h
148
+ ratio_diff = [abs(ratio - aspect_ratio) for ratio in aspect_ratios]
149
+ min_diff = np.min(ratio_diff)
150
+ match_diff_indexes = [j for j in range(len(ratio_diff)) if ratio_diff[j] == min_diff]
151
+ match_diff_indexes = sorted(match_diff_indexes, key=lambda x: (h-self.base_sizes[base_size_indexes[x]][0])**2
152
+ + (w-self.base_sizes[base_size_indexes[x]][1])**2) # pick the one whose area matches best
153
+ corr_index = base_size_indexes[match_diff_indexes[0]]
154
+ return corr_index
155
+
156
+ def get_pred_target_w_h(self, w, h):
157
+ aspect_ratio = w / h
158
+ aspect_index = self._find_size(w, h)
159
+ pred_h, pred_w = self.base_sizes[aspect_index]
160
+
161
+ solutions = [
162
+ (pred_w, int(pred_w / aspect_ratio)),
163
+ (int(pred_h * aspect_ratio), pred_h),
164
+ ]
165
+ w_tar = None
166
+ h_tar = None
167
+ for solution in solutions:
168
+ w_s, h_s = solution
169
+ if w_s >= pred_w and h_s >= pred_h:
170
+ w_tar = w_s
171
+ h_tar = h_s
172
+
173
+ return pred_w, pred_h, w_tar, h_tar, aspect_index
174
+
175
+ def __call__(self, image, is_inference=False):
176
+ ## step 1: find the closest aspect ratio
177
+ flag_matched = True
178
+ w, h = image.size
179
+ pred_w, pred_h, w_tar, h_tar, aspect_index = self.get_pred_target_w_h(w, h)
180
+
181
+ crop_percent = 1 - pred_w * pred_h / (w_tar * h_tar)
182
+ if self.crop_percent_thresh > 0 and crop_percent > self.crop_percent_thresh:
183
+ flag_matched = False # filter data
184
+
185
+ if not is_inference:
186
+ ## step 2: train: crop and resize
187
+ image = center_crop_and_resize(image, output_size=(pred_h, pred_w))
188
+ else:
189
+ ## step 2: inference: resize and padding
190
+ image = resize_with_padding(image, output_size=(pred_h, pred_w))
191
+
192
+ original_size = [h, w]
193
+ target_size = [pred_h, pred_w]
194
+
195
+ return image, original_size, target_size, flag_matched
196
+
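The transform above is self-contained, so a short usage sketch may help. The snippet below is not part of the commit; it assumes aspect_ratio_utils.py is importable from the current directory and that torch/torchvision are installed, since the module imports them.

```python
# Minimal sketch: exercise AspectRatioCrop from aspect_ratio_utils.py on a synthetic image.
from PIL import Image

from aspect_ratio_utils import RATIOS, AspectRatioCrop, unpad_and_resize_back

# An 800x600 (w x h) dummy image; its 4:3 aspect ratio should map to the (384, 512) base size.
img = Image.new("RGB", (800, 600), color=(127, 127, 127))

arc = AspectRatioCrop(base_sizes=RATIOS, crop_percent_thresh=0.2)

# Training path: center crop + resize to the matched base size.
train_img, original_size, target_size, matched = arc(img, is_inference=False)
print(train_img.size, original_size, target_size, matched)  # (512, 384) [600, 800] [384, 512] True

# Inference path: resize + pad, then undo the padding afterwards.
infer_img, _, _, _ = arc(img, is_inference=True)
restored = unpad_and_resize_back(infer_img, original_width=800, original_height=600)
print(infer_img.size, restored.size)  # (512, 384) (800, 600)
```

The training path crops away at most roughly `crop_percent_thresh` of the area (otherwise the sample is flagged as unmatched), while the inference path pads instead, so `unpad_and_resize_back` can recover the original resolution.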
config.json ADDED
@@ -0,0 +1,191 @@
1
+ {
2
+ "architectures": [
3
+ "ILLUMEForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_illume.ILLUMEConfig",
7
+ "AutoModel": "modeling_illume.ILLUMEForConditionalGeneration",
8
+ "AutoModelForCausalLM": "modeling_illume.ILLUMEForConditionalGeneration"
9
+ },
10
+ "ignore_index": -100,
11
+ "image_out_token_index": 282777,
12
+ "image_token_index": 282776,
13
+ "mm_projector_config": {
14
+ "hidden_size": 3584,
15
+ "mm_hidden_size": [
16
+ 3584,
17
+ 32
18
+ ],
19
+ "projector_cfg1": {
20
+ "mlp_depth": 2,
21
+ "type": "MLPProjector"
22
+ },
23
+ "projector_cfg2": {
24
+ "mlp_depth": 2,
25
+ "type": "MLPProjector"
26
+ },
27
+ "trainable": true,
28
+ "type": "MixedProjector"
29
+ },
30
+ "model_type": "illume",
31
+ "special_tokens_ids": {
32
+ "<end_of_image>": 151666,
33
+ "<end_of_level0>": 151669,
34
+ "<end_of_level1>": 151671,
35
+ "<end_of_line>": 151667,
36
+ "<start_of_image>": 151665,
37
+ "<start_of_level0>": 151668,
38
+ "<start_of_level1>": 151670
39
+ },
40
+ "text_config": {
41
+ "_name_or_path": "./Qwen2.5-7B-Instruct-with-vision-tokenizer-32k-96k-level2",
42
+ "architectures": [
43
+ "Qwen2ForCausalLM"
44
+ ],
45
+ "bos_token_id": 151643,
46
+ "eos_token_id": 151645,
47
+ "hidden_size": 3584,
48
+ "intermediate_size": 18944,
49
+ "model_type": "qwen2",
50
+ "num_attention_heads": 28,
51
+ "num_hidden_layers": 28,
52
+ "num_key_value_heads": 4,
53
+ "rope_theta": 1000000.0,
54
+ "torch_dtype": "bfloat16",
55
+ "vocab_size": 283175
56
+ },
57
+ "tie_word_embeddings": false,
58
+ "torch_dtype": "bfloat16",
59
+ "transformers_version": "4.44.2",
60
+ "vision_config": {
61
+ "_name_or_path": "./dualvitok/",
62
+ "architectures": [
63
+ "DualViTok"
64
+ ],
65
+ "auto_map": {
66
+ "AutoConfig": "configuration_dualvitok.DualViTokConfig",
67
+ "AutoModel": "modeling_dualvitok.DualViTok"
68
+ },
69
+ "model_type": "DualViTok",
70
+ "pixel_decoder": {
71
+ "attn_resolutions": [
72
+ 4
73
+ ],
74
+ "ch": 384,
75
+ "ch_mult": [
76
+ 1,
77
+ 1,
78
+ 2,
79
+ 2,
80
+ 4
81
+ ],
82
+ "codebook_size": 98304,
83
+ "embed_dim": 64,
84
+ "use_dc_up_down_blocks": true,
85
+ "z_channels": 64
86
+ },
87
+ "pixel_encoder": {
88
+ "attn_resolutions": [
89
+ 4
90
+ ],
91
+ "ch": 128,
92
+ "ch_mult": [
93
+ 1,
94
+ 1,
95
+ 2,
96
+ 2,
97
+ 4
98
+ ],
99
+ "codebook_size": 98304,
100
+ "embed_dim": 32,
101
+ "use_dc_up_down_blocks": true,
102
+ "z_channels": 32
103
+ },
104
+ "semantic_decoder": {
105
+ "pretrained_semantic_encoder": "Emova-ollm/qwen2vit600m",
106
+ "target_mlp": "norm"
107
+ },
108
+ "semantic_encoder": {
109
+ "pretrained_semantic_encoder": {
110
+ "_name_or_path": "Emova-ollm/qwen2vit600m",
111
+ "add_cross_attention": false,
112
+ "architectures": [
113
+ "Qwen2VisionTransformerPretrainedModel"
114
+ ],
115
+ "bad_words_ids": null,
116
+ "begin_suppress_tokens": null,
117
+ "bos_token_id": null,
118
+ "chunk_size_feed_forward": 0,
119
+ "cross_attention_hidden_size": null,
120
+ "decoder_start_token_id": null,
121
+ "depth": 32,
122
+ "diversity_penalty": 0.0,
123
+ "do_sample": false,
124
+ "early_stopping": false,
125
+ "embed_dim": 1280,
126
+ "encoder_no_repeat_ngram_size": 0,
127
+ "eos_token_id": null,
128
+ "exponential_decay_length_penalty": null,
129
+ "finetuning_task": null,
130
+ "forced_bos_token_id": null,
131
+ "forced_eos_token_id": null,
132
+ "hidden_act": "quick_gelu",
133
+ "hidden_size": 3584,
134
+ "id2label": {
135
+ "0": "LABEL_0",
136
+ "1": "LABEL_1"
137
+ },
138
+ "in_channels": 3,
139
+ "in_chans": 3,
140
+ "initializer_range": 0.02,
141
+ "is_decoder": false,
142
+ "is_encoder_decoder": false,
143
+ "label2id": {
144
+ "LABEL_0": 0,
145
+ "LABEL_1": 1
146
+ },
147
+ "length_penalty": 1.0,
148
+ "max_length": 20,
149
+ "min_length": 0,
150
+ "mlp_ratio": 4,
151
+ "model_type": "qwen2_vl",
152
+ "no_repeat_ngram_size": 0,
153
+ "num_beam_groups": 1,
154
+ "num_beams": 1,
155
+ "num_heads": 16,
156
+ "num_return_sequences": 1,
157
+ "output_attentions": false,
158
+ "output_hidden_states": false,
159
+ "output_scores": false,
160
+ "pad_token_id": null,
161
+ "patch_size": 14,
162
+ "prefix": null,
163
+ "problem_type": null,
164
+ "pruned_heads": {},
165
+ "remove_invalid_values": false,
166
+ "repetition_penalty": 1.0,
167
+ "return_dict": true,
168
+ "return_dict_in_generate": false,
169
+ "sep_token_id": null,
170
+ "spatial_merge_size": 2,
171
+ "spatial_patch_size": 14,
172
+ "suppress_tokens": null,
173
+ "task_specific_params": null,
174
+ "temperature": 1.0,
175
+ "temporal_patch_size": 2,
176
+ "tf_legacy_loss": false,
177
+ "tie_encoder_decoder": false,
178
+ "tie_word_embeddings": true,
179
+ "tokenizer_class": null,
180
+ "top_k": 50,
181
+ "top_p": 1.0,
182
+ "torch_dtype": "float32",
183
+ "torchscript": false,
184
+ "transformers_version": "4.44.2",
185
+ "typical_p": 1.0,
186
+ "use_bfloat16": false
187
+ }
188
+ },
189
+ "torch_dtype": "float16"
190
+ }
191
+ }
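Since config.json registers the custom classes through `auto_map`, the whole configuration can be loaded with `AutoConfig` and `trust_remote_code=True`. A minimal sketch, where `"huangrh9/<this-repo>"` is a placeholder for the actual repository id:

```python
# Sketch only: load the config shipped in this commit through the auto_map above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("huangrh9/<this-repo>", trust_remote_code=True)

print(config.model_type)              # "illume"
print(config.text_config.model_type)  # "qwen2"
print(config.vision_config.model_type)  # "DualViTok"
print(config.image_token_index)       # 282776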
configuration_dualvitok.py ADDED
@@ -0,0 +1,154 @@
1
+ """ DualViTok model configuration """
2
+ import os
3
+ from typing import List, Union
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+ from .configuration_movqgan import MoVQConfig
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class SemanticEncoderConfig(PretrainedConfig):
15
+ model_type = "DualViTokSemanticEncoder"
16
+
17
+ def __init__(
18
+ self,
19
+ pretrained_semantic_encoder='Emova-ollm/qwen2vit600m',
20
+ z_channels=32,
21
+ num_blocks=4,
22
+ embed_dim=1280,
23
+ out_layer='linear',
24
+ target_mlp='norm',
25
+ **kwargs
26
+ ):
27
+ super().__init__(**kwargs)
28
+ self.pretrained_semantic_encoder = pretrained_semantic_encoder
29
+ self.z_channels = z_channels
30
+ self.num_blocks = num_blocks
31
+ self.out_layer = out_layer
32
+ self.embed_dim = embed_dim
33
+ self.target_mlp = target_mlp
34
+
35
+
36
+ class SemanticDecoderConfig(PretrainedConfig):
37
+ model_type = "DualViTokSemanticDecoder"
38
+
39
+ def __init__(
40
+ self,
41
+ z_channels=32,
42
+ num_blocks=4,
43
+ embed_dim=1280,
44
+ out_layer='linear_norm',
45
+ out_channels=3584,
46
+ **kwargs
47
+ ):
48
+ super().__init__(**kwargs)
49
+ self.z_channels = z_channels
50
+ self.num_blocks = num_blocks
51
+ self.embed_dim = embed_dim
52
+ self.out_layer = out_layer
53
+ self.out_channels = out_channels
54
+
55
+
56
+ class DualViTokConfig(PretrainedConfig):
57
+ r"""
58
+ This is the configuration class to store the configuration of a [`DualViTok`]. It is used to instantiate a DualViTok
59
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
60
+ defaults will yield a configuration similar to that of the VQ model presented in the paper.
61
+
62
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
63
+ documentation from [`PretrainedConfig`] for more information.
64
+
65
+
66
+ Args:
67
+ semantic_encoder (`dict`, *optional*):
68
+ Configuration dictionary for the semantic encoder. If `None`, defaults to `SemanticEncoderConfig()`.
69
+ The provided dictionary will be unpacked to initialize a `SemanticEncoderConfig` instance.
70
+ semantic_decoder (`dict`, *optional*):
71
+ Configuration dictionary for the semantic decoder. If `None`, defaults to `SemanticDecoderConfig()`.
72
+ The provided dictionary will be unpacked to initialize a `SemanticDecoderConfig` instance.
73
+ pixel_encoder (`dict`, *optional*):
74
+ Configuration dictionary for the pixel pathway's VQ-VAE model (e.g., `MoVQConfig`). If `None`, defaults to `MoVQConfig()`.
75
+ The provided dictionary will be unpacked to initialize a `MoVQConfig` instance, which defines both encoder and decoder for pixel-level features.
76
+ semantic_quantizer_type (`str`, *optional*, defaults to `'simvq'`):
77
+ Type of the quantizer for semantic tokens (e.g., `'simvq'`, `'ema_simvq'`).
78
+ pixel_quantizer_type (`str`, *optional*, defaults to `'simvq'`):
79
+ Type of the quantizer for pixel tokens (e.g., `'simvq'`, `'ema_simvq'`).
80
+ semantic_quantizer_codebook_size (`int`, *optional*, defaults to 32768):
81
+ Number of entries in the codebook for the semantic quantizer.
82
+ pixel_quantizer_codebook_size (`int`, *optional*, defaults to 98304):
83
+ Number of entries in the codebook for the pixel quantizer.
84
+ attn_implementation (`str`, *optional*, defaults to `'sdpa'`):
85
+ The attention implementation to use (e.g., `'sdpa'`, `'flash_attention_2'`, `'eager'`).
86
+ Can be `'sdpa'` (scaled dot product attention), `'flash_attention_2'` (if available and installed),
87
+ or `'eager'` (the default PyTorch attention implementation).
88
+
89
+ ```python
90
+ >>> from transformers import DualViTok, DualViTokConfig
91
+
92
+ >>> # Initializing a video VQ model of configuration
93
+ >>> configuration = DualViTokConfig()
94
+
95
+ >>> # Initializing a model from the VQ model style configuration
96
+ >>> model = DualViTok(configuration)
97
+
98
+ >>> # Accessing the model configuration
99
+ >>> configuration = model.config
100
+ ```"""
101
+
102
+ model_type = "DualViTok"
103
+
104
+ def __init__(
105
+ self,
106
+ semantic_encoder=None,
107
+ semantic_decoder=None,
108
+ pixel_encoder=None,
109
+ pixel_decoder=None,
110
+ semantic_quantizer_type='simvq',
111
+ pixel_quantizer_type='simvq',
112
+ semantic_quantizer_codebook_size=32768,
113
+ pixel_quantizer_codebook_size=98304,
114
+ attn_implementation='sdpa',
115
+ **kwargs,
116
+ ):
117
+ super().__init__(**kwargs)
118
+ if semantic_encoder is None:
119
+ self.semantic_encoder = SemanticEncoderConfig()
120
+ else:
121
+ self.semantic_encoder = SemanticEncoderConfig(**semantic_encoder)
122
+
123
+ if semantic_decoder is None:
124
+ self.semantic_decoder = SemanticDecoderConfig()
125
+ else:
126
+ self.semantic_decoder = SemanticDecoderConfig(**semantic_decoder)
127
+
128
+ self.semantic_quantizer_type = semantic_quantizer_type
129
+ self.pixel_quantizer_type = pixel_quantizer_type
130
+ self.semantic_quantizer_codebook_size = semantic_quantizer_codebook_size
131
+ self.pixel_quantizer_codebook_size = pixel_quantizer_codebook_size
132
+
133
+ if pixel_encoder is None:
134
+ self.pixel_encoder = MoVQConfig()
135
+ else:
136
+ self.pixel_encoder = MoVQConfig(**pixel_encoder)
137
+
138
+ self.pixel_decoder = self.pixel_encoder if pixel_decoder is None else MoVQConfig(**pixel_decoder)
139
+
140
+ self.attn_implementation = attn_implementation
141
+
142
+ @classmethod
143
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], attn_implementation='sdpa', **kwargs) -> "PretrainedConfig":
144
+ cls._set_token_in_kwargs(kwargs)
145
+
146
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
147
+
148
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
149
+ logger.warning(
150
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
151
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
152
+ )
153
+
154
+ return cls.from_dict(config_dict, attn_implementation=attn_implementation, **kwargs)
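The constructor promotes nested dicts into sub-configs and lets the pixel decoder fall back to the pixel encoder config. A sketch of that behaviour, assuming the commit's modules are importable as a local package (the hypothetical name `illume_repo` stands for this folder; the relative imports require package context):

```python
# Sketch: how DualViTokConfig promotes nested dicts into sub-configs.
from illume_repo.configuration_dualvitok import DualViTokConfig

cfg = DualViTokConfig(
    pixel_encoder={"ch": 128, "embed_dim": 32, "z_channels": 32, "codebook_size": 98304},
    semantic_encoder={"pretrained_semantic_encoder": "Emova-ollm/qwen2vit600m"},
)

print(type(cfg.pixel_encoder).__name__)     # MoVQConfig
print(type(cfg.semantic_encoder).__name__)  # SemanticEncoderConfig
# pixel_decoder was not given, so it reuses the pixel_encoder config object.
print(cfg.pixel_decoder is cfg.pixel_encoder)  # True
```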
configuration_illume.py ADDED
@@ -0,0 +1,140 @@
1
+
2
+ """ILLUME model configuration"""
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+ from transformers.models.auto import CONFIG_MAPPING
6
+
7
+ # import the first three to make sure the last one recognizes them.
8
+ from .modeling_rope_utils import rope_config_validation
9
+ from .configuration_movqgan import MoVQConfig
10
+ from .configuration_qwen2vit import Qwen2VLVisionConfig
11
+ from .configuration_dualvitok import DualViTokConfig
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class ILLUMEConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`ILLUMEForConditionalGeneration`]. It is used to instantiate an
19
+ ILLUME model according to the specified arguments, defining the model architecture.
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+ Args:
25
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `DualViTokConfig`):
26
+ The config object or dictionary of the vision backbone.
27
+ mm_projector_config (`dict`, *optional*, defaults to `None`):
28
+ Configuration for the multimodal projector.
29
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
30
+ The config object or dictionary of the text backbone.
31
+ ignore_index (`int`, *optional*, defaults to -100):
32
+ The ignore index for the loss function.
33
+ image_token_index (`int`, *optional*, defaults to 32000):
34
+ The image token index to encode the image prompt.
35
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
36
+ Whether the model's input and output word embeddings should be tied.
37
+
38
+ Example:
39
+
40
+ ```python
41
+ >>> from transformers import ILLUMEForConditionalGeneration, ILLUMEConfig, CLIPVisionConfig, LlamaConfig
42
+
43
+ >>> # Initializing a CLIP-vision config
44
+ >>> vision_config = CLIPVisionConfig()
45
+
46
+ >>> # Initializing a Llama config
47
+ >>> text_config = LlamaConfig()
48
+
49
+ >>> # Initializing a ILLUME style configuration
50
+ >>> configuration = ILLUMEConfig(vision_config, text_config)
51
+
52
+ >>> # Initializing a model from the style configuration
53
+ >>> model = ILLUMEForConditionalGeneration(configuration)
54
+
55
+ >>> # Accessing the model configuration
56
+ >>> configuration = model.config
57
+ ```"""
58
+
59
+ model_type = "illume"
60
+ is_composition = False
61
+
62
+ def __init__(
63
+ self,
64
+ vision_config=None,
65
+ mm_projector_config=None,
66
+ text_config=None,
67
+ ignore_index=-100,
68
+ image_token_index=32000,
69
+ tie_word_embeddings=False,
70
+ **kwargs,
71
+ ):
72
+ self.ignore_index = ignore_index
73
+ self.image_token_index = image_token_index
74
+
75
+ if isinstance(vision_config, dict):
76
+ vision_config = DualViTokConfig(**vision_config)
77
+ elif vision_config is None:
78
+ vision_config = DualViTokConfig({
79
+ "semantic_encoder": {
80
+ "pretrained_semantic_encoder":
81
+ "Emova-ollm/qwen2vit600m",
82
+ "z_channels": 32,
83
+ "num_blocks": 4,
84
+ "out_layer": "linear",
85
+ "embed_dim": 1280,
86
+ "target_mlp": "norm"
87
+ },
88
+ "semantic_decoder": {
89
+ "z_channels": 32,
90
+ "num_blocks": 4,
91
+ "embed_dim": 1280,
92
+ "out_layer": "linear_norm",
93
+ "out_channels": 3584
94
+ },
95
+ "semantic_quantizer_type": "simvq",
96
+ "pixel_quantizer_type": "simvq",
97
+ "semantic_quantizer_codebook_size": 32768,
98
+ "pixel_quantizer_codebook_size": 98304,
99
+ "attn_implementation": "sdpa",
100
+ "pixel_encoder": {
101
+ "codebook_size": 98304,
102
+ "embed_dim": 32,
103
+ "z_channels": 32,
104
+ "double_z": False,
105
+ "in_channels": 3,
106
+ "out_channels": 3,
107
+ "ch": 128,
108
+ "ch_mult": [ 1, 1, 2, 2, 4 ],
109
+ "num_res_blocks": 2,
110
+ "attn_resolutions": [ 4 ],
111
+ "dropout": 0.0,
112
+ "use_dc_up_down_blocks": True
113
+ },
114
+ "pixel_decoder": {
115
+ "codebook_size": 98304,
116
+ "embed_dim": 64,
117
+ "z_channels": 64,
118
+ "double_z": False,
119
+ "in_channels": 3,
120
+ "out_channels": 3,
121
+ "ch": 384,
122
+ "ch_mult": [ 1, 1, 2, 2, 4 ],
123
+ "num_res_blocks": 2,
124
+ "attn_resolutions": [4],
125
+ "dropout": 0.0,
126
+ "use_dc_up_down_blocks": True
127
+ },
128
+ }
129
+ )
130
+
131
+ self.vision_config = vision_config
132
+ self.mm_projector_config = mm_projector_config
133
+ if isinstance(text_config, dict):
134
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
135
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
136
+ elif text_config is None:
137
+ text_config = CONFIG_MAPPING["qwen2"]()
138
+
139
+ self.text_config = text_config
140
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
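A sketch of how `ILLUMEConfig` resolves its sub-configs when given plain dicts, mirroring what happens when config.json is loaded. It again assumes the hypothetical local package name `illume_repo` and that all sibling modules (including modeling_rope_utils.py) are present:

```python
# Sketch: dict sub-configs are promoted to config objects.
from illume_repo.configuration_illume import ILLUMEConfig

cfg = ILLUMEConfig(
    vision_config={"semantic_quantizer_codebook_size": 32768},
    text_config={"model_type": "qwen2", "hidden_size": 3584, "num_hidden_layers": 28},
    image_token_index=282776,
)

print(type(cfg.vision_config).__name__)  # DualViTokConfig
print(type(cfg.text_config).__name__)    # Qwen2Config
print(cfg.image_token_index)             # 282776
```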
configuration_movqgan.py ADDED
@@ -0,0 +1,92 @@
1
+ """ MoVQ model configuration """
2
+
3
+ from typing import List
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+ from transformers.utils import logging
7
+
8
+
9
+ logger = logging.get_logger(__name__)
10
+
11
+
12
+ class MoVQConfig(PretrainedConfig):
13
+ r"""
14
+ This is the configuration class to store the configuration of a [`MoVQ`]. It is used to instantiate a MoVQ
15
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
16
+ defaults will yield a configuration similar to that of the VQ model presented in the paper.
17
+
18
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
19
+ documentation from [`PretrainedConfig`] for more information.
20
+
21
+
22
+ Args:
23
+ codebook_size (`int`, *optional*, defaults to 32768):
24
+ Codebook size of the VQ model.
25
+ embed_dim (`int`, *optional*, defaults to 4):
26
+ Dimension of the quantized vector in codebook.
27
+ z_channels (`int`, *optional*, defaults to 4):
28
+ Dimension of the output channels of the encoder and the input channels of the decoder.
29
+ double_z (`bool`, *optional*, defaults to False):
30
+ Whether to double the output dim of the encoder.
31
+ in_channels (`int`, *optional*, defaults to 3):
32
+ Input channel of encoder.
33
+ out_channels (`int`, *optional*, defaults to 3):
34
+ Output channel of decoder.
35
+ ch (`int`, *optional*, defaults to 256):
36
+ Basic channel number of the intermediate blocks.
37
+ ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
38
+ Channel scaling factor of the intermediate blocks.
39
+ num_res_blocks (`int`, *optional*, defaults to 2):
40
+ Residual block number in each stage.
41
+ attn_resolutions (`List[int]`, *optional*, defaults to `[3]`):
42
+ Stage indices to apply attention.
43
+ dropout (`float`, *optional*, defaults to 0.0):
44
+ Dropout probability.
45
+ use_dc_up_down_blocks (`bool`, *optional*, defaults to `False`):
46
+ Whether to use the DC up-down blocks.
47
+
48
+ ```python
49
+ >>> from transformers import MoVQ, MoVQConfig
50
+
51
+ >>> # Initializing a video VQ model of configuration
52
+ >>> configuration = MoVQConfig()
53
+
54
+ >>> # Initializing a model from the VQ model style configuration
55
+ >>> model = MoVQ(configuration)
56
+
57
+ >>> # Accessing the model configuration
58
+ >>> configuration = model.config
59
+ ```"""
60
+
61
+ model_type = "MoVQ"
62
+
63
+ def __init__(
64
+ self,
65
+ codebook_size: int = 32768,
66
+ embed_dim: int = 4,
67
+ z_channels: int = 4,
68
+ double_z: bool = False,
69
+ in_channels: int = 3,
70
+ out_channels: int = 3,
71
+ ch: int = 256,
72
+ ch_mult: List[int] = [1, 2, 2, 4],
73
+ num_res_blocks: int = 2,
74
+ attn_resolutions: List[int] = [3],
75
+ dropout: float = 0.0,
76
+ use_dc_up_down_blocks=False,
77
+ **kwargs,
78
+ ):
79
+ super().__init__(**kwargs)
80
+
81
+ self.codebook_size = codebook_size
82
+ self.embed_dim = embed_dim
83
+ self.z_channels = z_channels
84
+ self.double_z = double_z
85
+ self.in_channels = in_channels
86
+ self.out_channels = out_channels
87
+ self.ch = ch
88
+ self.ch_mult = ch_mult
89
+ self.num_res_blocks = num_res_blocks
90
+ self.attn_resolutions = attn_resolutions
91
+ self.dropout = dropout
92
+ self.use_dc_up_down_blocks = use_dc_up_down_blocks
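For orientation, a sketch of instantiating `MoVQConfig` with the `pixel_encoder` values from config.json above (hypothetical local package name `illume_repo`); the comment about the downsampling factor is an inference based on how VQGAN-style encoders typically use `ch_mult`, not something stated in this commit:

```python
# Sketch: MoVQConfig with the pixel_encoder values from config.json.
from illume_repo.configuration_movqgan import MoVQConfig

pixel_encoder_cfg = MoVQConfig(
    codebook_size=98304,
    embed_dim=32,
    z_channels=32,
    ch=128,
    ch_mult=[1, 1, 2, 2, 4],  # 5 resolution stages
    num_res_blocks=2,
    attn_resolutions=[4],
    use_dc_up_down_blocks=True,
)

# In VQGAN-style encoders, each ch_mult entry after the first usually adds a 2x downsample,
# so 5 stages would give a 2**4 = 16x spatial reduction -- consistent with the
# spatial_factor=16 used by DualViTokImageProcessor later in this commit.
print(len(pixel_encoder_cfg.ch_mult))   # 5
print(pixel_encoder_cfg.codebook_size)  # 98304
```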
configuration_qwen2vit.py ADDED
@@ -0,0 +1,249 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen2VL model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+ from .modeling_rope_utils import rope_config_validation
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class Qwen2VLVisionConfig(PretrainedConfig):
28
+ model_type = "qwen2_vl"
29
+
30
+ def __init__(
31
+ self,
32
+ depth=32,
33
+ embed_dim=1280,
34
+ hidden_size=3584,
35
+ hidden_act="quick_gelu",
36
+ mlp_ratio=4,
37
+ num_heads=16,
38
+ in_channels=3,
39
+ patch_size=14,
40
+ spatial_merge_size=2,
41
+ temporal_patch_size=2,
42
+ attn_implementation='eager',
43
+ init_weights=False,
44
+ **kwargs,
45
+ ):
46
+ super().__init__(**kwargs)
47
+
48
+ self.depth = depth
49
+ self.embed_dim = embed_dim
50
+ self.hidden_size = hidden_size
51
+ self.hidden_act = hidden_act
52
+ self.mlp_ratio = mlp_ratio
53
+ self.num_heads = num_heads
54
+ self.in_channels = in_channels
55
+ self.patch_size = patch_size
56
+ self.spatial_merge_size = spatial_merge_size
57
+ self.temporal_patch_size = temporal_patch_size
58
+ self.attn_implementation = attn_implementation if attn_implementation else 'eager'
59
+
60
+ self.init_weights = init_weights
61
+
62
+ @classmethod
63
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
64
+ cls._set_token_in_kwargs(kwargs)
65
+
66
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
67
+
68
+ # if config_dict.get("model_type") == "qwen2_vl":
69
+ # config_dict = config_dict["vision_config"]
70
+
71
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
72
+ logger.warning(
73
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
74
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
75
+ )
76
+
77
+ return cls.from_dict(config_dict, **kwargs)
78
+
79
+
80
+ class Qwen2VLConfig(PretrainedConfig):
81
+ r"""
82
+ This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
83
+ Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
84
+ with the defaults will yield a similar configuration to that of
85
+ Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
86
+
87
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
88
+ documentation from [`PretrainedConfig`] for more information.
89
+
90
+
91
+ Args:
92
+ vocab_size (`int`, *optional*, defaults to 152064):
93
+ Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
94
+ `inputs_ids` passed when calling [`Qwen2VLModel`]
95
+ hidden_size (`int`, *optional*, defaults to 8192):
96
+ Dimension of the hidden representations.
97
+ intermediate_size (`int`, *optional*, defaults to 29568):
98
+ Dimension of the MLP representations.
99
+ num_hidden_layers (`int`, *optional*, defaults to 80):
100
+ Number of hidden layers in the Transformer encoder.
101
+ num_attention_heads (`int`, *optional*, defaults to 64):
102
+ Number of attention heads for each attention layer in the Transformer encoder.
103
+ num_key_value_heads (`int`, *optional*, defaults to 8):
104
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
105
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
106
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
107
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
108
+ by meanpooling all the original heads within that group. For more details checkout [this
109
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
110
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
111
+ The non-linear activation function (function or string) in the decoder.
112
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
113
+ The maximum sequence length that this model might ever be used with.
114
+ initializer_range (`float`, *optional*, defaults to 0.02):
115
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
116
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
117
+ The epsilon used by the rms normalization layers.
118
+ use_cache (`bool`, *optional*, defaults to `True`):
119
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
120
+ relevant if `config.is_decoder=True`.
121
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
122
+ Whether the model's input and output word embeddings should be tied.
123
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
124
+ The base period of the RoPE embeddings.
125
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
126
+ Whether to use sliding window attention.
127
+ sliding_window (`int`, *optional*, defaults to 4096):
128
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
129
+ max_window_layers (`int`, *optional*, defaults to 80):
130
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
131
+ attention_dropout (`float`, *optional*, defaults to 0.0):
132
+ The dropout ratio for the attention probabilities.
133
+ vision_config (`Dict`, *optional*):
134
+ The config for the visual encoder initialization.
135
+ rope_scaling (`Dict`, *optional*):
136
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
137
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
138
+ accordingly.
139
+ Expected contents:
140
+ `rope_type` (`str`):
141
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
142
+ 'llama3'], with 'default' being the original RoPE implementation.
143
+ `factor` (`float`, *optional*):
144
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
145
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
146
+ original maximum pre-trained length.
147
+ `original_max_position_embeddings` (`int`, *optional*):
148
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
149
+ pretraining.
150
+ `attention_factor` (`float`, *optional*):
151
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
152
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
153
+ `factor` field to infer the suggested value.
154
+ `beta_fast` (`float`, *optional*):
155
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
156
+ ramp function. If unspecified, it defaults to 32.
157
+ `beta_slow` (`float`, *optional*):
158
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
159
+ ramp function. If unspecified, it defaults to 1.
160
+ `short_factor` (`List[float]`, *optional*):
161
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
162
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
163
+ size divided by the number of attention heads divided by 2
164
+ `long_factor` (`List[float]`, *optional*):
165
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
166
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
167
+ size divided by the number of attention heads divided by 2
168
+ `low_freq_factor` (`float`, *optional*):
169
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
170
+ `high_freq_factor` (`float`, *optional*):
171
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
172
+
173
+ ```python
174
+ >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
175
+
176
+ >>> # Initializing a Qwen2VL style configuration
177
+ >>> configuration = Qwen2VLConfig()
178
+
179
+ >>> # Initializing a model from the Qwen2-VL-7B style configuration
180
+ >>> model = Qwen2VLForConditionalGeneration(configuration)
181
+
182
+ >>> # Accessing the model configuration
183
+ >>> configuration = model.config
184
+ ```"""
185
+
186
+ model_type = "qwen2_vl"
187
+ keys_to_ignore_at_inference = ["past_key_values"]
188
+
189
+ def __init__(
190
+ self,
191
+ vocab_size=152064,
192
+ hidden_size=8192,
193
+ intermediate_size=29568,
194
+ num_hidden_layers=80,
195
+ num_attention_heads=64,
196
+ num_key_value_heads=8,
197
+ hidden_act="silu",
198
+ max_position_embeddings=32768,
199
+ initializer_range=0.02,
200
+ rms_norm_eps=1e-05,
201
+ use_cache=True,
202
+ tie_word_embeddings=False,
203
+ rope_theta=1000000.0,
204
+ use_sliding_window=False,
205
+ sliding_window=4096,
206
+ max_window_layers=80,
207
+ attention_dropout=0.0,
208
+ vision_config=None,
209
+ rope_scaling=None,
210
+ **kwargs,
211
+ ):
212
+ if isinstance(vision_config, dict):
213
+ self.vision_config = Qwen2VLVisionConfig(**vision_config)
214
+ elif vision_config is None:
215
+ self.vision_config = Qwen2VLVisionConfig()
216
+
217
+ self.vocab_size = vocab_size
218
+ self.max_position_embeddings = max_position_embeddings
219
+ self.hidden_size = hidden_size
220
+ self.intermediate_size = intermediate_size
221
+ self.num_hidden_layers = num_hidden_layers
222
+ self.num_attention_heads = num_attention_heads
223
+ self.use_sliding_window = use_sliding_window
224
+ self.sliding_window = sliding_window
225
+ self.max_window_layers = max_window_layers
226
+
227
+ # for backward compatibility
228
+ if num_key_value_heads is None:
229
+ num_key_value_heads = num_attention_heads
230
+
231
+ self.num_key_value_heads = num_key_value_heads
232
+ self.hidden_act = hidden_act
233
+ self.initializer_range = initializer_range
234
+ self.rms_norm_eps = rms_norm_eps
235
+ self.use_cache = use_cache
236
+ self.rope_theta = rope_theta
237
+ self.attention_dropout = attention_dropout
238
+ self.rope_scaling = rope_scaling
239
+
240
+ # Validate the correctness of rotary position embeddings parameters
241
+ # BC: if there is a 'type' field, move it to 'rope_type'.
242
+ # and change type from 'mrope' to 'default'
243
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
244
+ if self.rope_scaling["type"] == "mrope":
245
+ self.rope_scaling["type"] = "default"
246
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
247
+ rope_config_validation(self)
248
+
249
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
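A small arithmetic sketch of the vision-tower geometry implied by the `Qwen2VLVisionConfig` defaults above (again using the hypothetical package name `illume_repo`; the 448-pixel crop is just an example input size):

```python
# Sketch: patch and merge geometry from the Qwen2VLVisionConfig defaults.
from illume_repo.configuration_qwen2vit import Qwen2VLVisionConfig

vit_cfg = Qwen2VLVisionConfig()  # depth=32, embed_dim=1280, patch_size=14, spatial_merge_size=2

# A 448x448 crop would be split into (448/14)**2 = 1024 patches; after the 2x2
# spatial merge, that corresponds to 1024 / 4 = 256 merged positions.
patches_per_side = 448 // vit_cfg.patch_size
print(patches_per_side ** 2)                                      # 1024
print(patches_per_side ** 2 // vit_cfg.spatial_merge_size ** 2)   # 256
```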
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 151643,
4
+ "eos_token_id": 151645,
5
+ "transformers_version": "4.44.2"
6
+ }
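generation_config.json only pins the bos/eos token ids (matching `text_config` in config.json above) and the transformers version; everything else falls back to library defaults. A tiny sketch of reading it back, using the same placeholder repo id as earlier:

```python
# Sketch: the generation defaults mirror the bos/eos ids of the Qwen2 text backbone.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("huangrh9/<this-repo>")  # placeholder repo id
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id)  # 151643 151645
```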
image_processing_dualvitok.py ADDED
@@ -0,0 +1,24 @@
1
+ """Image processor class for DualViTok."""
2
+
3
+ from transformers.utils import TensorType, is_vision_available, logging
4
+
5
+ from .image_processing_movqgan import MoVQImageProcessor
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class DualViTokImageProcessor(MoVQImageProcessor):
11
+ r"""
12
+ Constructs a DualViTok image processor that dynamically resizes images based on the original images.
13
+ This image processor is based on MoVQImageProcessor with spatial_factor of 16.
14
+ """
15
+
16
+ model_input_names = ["pixel_values"]
17
+
18
+ def __init__(
19
+ self,
20
+ *args,
21
+ spatial_factor: int = 16,
22
+ **kwargs,
23
+ ) -> None:
24
+ super().__init__(*args, spatial_factor=spatial_factor, **kwargs)
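The only change relative to the MoVQ processor is `spatial_factor=16`, which controls the rounding done by `smart_resize` (defined in image_processing_movqgan.py below). A sketch of the effect, assuming the hypothetical local package name `illume_repo`:

```python
# Sketch: what spatial_factor=16 means for DualViTokImageProcessor.
import numpy as np

from illume_repo.image_processing_dualvitok import DualViTokImageProcessor
from illume_repo.image_processing_movqgan import smart_resize

# smart_resize rounds each side to a multiple of `factor` and keeps the pixel
# count inside [min_pixels, max_pixels]; e.g. a 600x800 input with factor=16:
print(smart_resize(600, 800, factor=16, min_pixels=32 * 32, max_pixels=1024 * 1024))  # (608, 800)

processor = DualViTokImageProcessor()  # spatial_factor=16, mean/std = 0.5
dummy = (np.random.rand(600, 800, 3) * 255).astype(np.uint8)
out = processor(images=dummy, return_tensors="np")
print(out["pixel_values"].shape)  # (1, 3, 608, 800) -- both sides divisible by 16
```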
image_processing_movqgan.py ADDED
@@ -0,0 +1,429 @@
1
+ """Image processor class for MoVQ."""
2
+
3
+
4
+ import math
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+
9
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
10
+ from transformers.image_transforms import (
11
+ convert_to_rgb,
12
+ resize,
13
+ to_channel_dimension_format,
14
+ )
15
+ from transformers.image_utils import (
16
+ IMAGENET_STANDARD_MEAN,
17
+ IMAGENET_STANDARD_STD,
18
+ ChannelDimension,
19
+ ImageInput,
20
+ PILImageResampling,
21
+ get_image_size,
22
+ infer_channel_dimension_format,
23
+ is_scaled_image,
24
+ make_list_of_images,
25
+ to_numpy_array,
26
+ valid_images,
27
+ validate_preprocess_arguments,
28
+ )
29
+ from transformers.utils import TensorType, is_vision_available, logging
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ if is_vision_available():
36
+ from PIL import Image
37
+
38
+
39
+ def smart_resize(
40
+ height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024
41
+ ):
42
+ """Rescales the image so that the following conditions are met:
43
+
44
+ 1. Both dimensions (height and width) are divisible by 'factor'.
45
+
46
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
47
+
48
+ 3. The aspect ratio of the image is maintained as closely as possible.
49
+
50
+ """
51
+ # if height < factor or width < factor:
52
+ # raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
53
+ # elif max(height, width) / min(height, width) > 5:
54
+ # raise ValueError(
55
+ # f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}"
56
+ # )
57
+
58
+ h_bar = round(height / factor) * factor
59
+ w_bar = round(width / factor) * factor
60
+ if h_bar * w_bar > max_pixels:
61
+ beta = math.sqrt((height * width) / max_pixels)
62
+ h_bar = math.floor(height / beta / factor) * factor
63
+ w_bar = math.floor(width / beta / factor) * factor
64
+ elif h_bar * w_bar < min_pixels:
65
+ beta = math.sqrt(min_pixels / (height * width))
66
+ h_bar = math.ceil(height * beta / factor) * factor
67
+ w_bar = math.ceil(width * beta / factor) * factor
68
+
69
+ return max(h_bar, factor), max(w_bar, factor)
70
+
71
+
72
+ class MoVQImageProcessor(BaseImageProcessor):
73
+ r"""
74
+ Constructs a MoVQ image processor that dynamically resizes images based on the original images.
75
+
76
+ Args:
77
+ do_resize (`bool`, *optional*, defaults to `True`):
78
+ Whether to resize the image's (height, width) dimensions.
79
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
80
+ Resampling filter to use when resizing the image.
81
+ do_rescale (`bool`, *optional*, defaults to `True`):
82
+ Whether to rescale the image by the specified scale `rescale_factor`.
83
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
84
+ Scale factor to use if rescaling the image.
85
+ do_normalize (`bool`, *optional*, defaults to `True`):
86
+ Whether to normalize the image.
87
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
88
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
89
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
90
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
91
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
92
+ Whether to convert the image to RGB.
93
+ min_pixels (`int`, *optional*, defaults to `32 * 32`):
94
+ The min pixels of the image to resize the image.
95
+ max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
96
+ The max pixels of the image to resize the image.
97
+ spatial_factor (`int`, *optional*, defaults to 8):
98
+ The spatial downsampling factor by which the image will be downsampled in the feature extraction phase.
99
+ """
100
+
101
+ model_input_names = ["pixel_values"]
102
+
103
+ def __init__(
104
+ self,
105
+ do_resize: bool = True,
106
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
107
+ do_rescale: bool = True,
108
+ rescale_factor: Union[int, float] = 1 / 255,
109
+ do_normalize: bool = True,
110
+ image_mean: Optional[Union[float, List[float]]] = None,
111
+ image_std: Optional[Union[float, List[float]]] = None,
112
+ do_convert_rgb: bool = True,
113
+ min_pixels: int = 32 * 32,
114
+ max_pixels: int = 1024 * 1024,
115
+ spatial_factor: int = 8,
116
+ **kwargs,
117
+ ) -> None:
118
+ super().__init__(**kwargs)
119
+ self.do_resize = do_resize
120
+ self.resample = resample
121
+ self.do_rescale = do_rescale
122
+ self.rescale_factor = rescale_factor
123
+ self.do_normalize = do_normalize
124
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
125
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
126
+ self.min_pixels = min_pixels
127
+ self.max_pixels = max_pixels
128
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
129
+ self.do_convert_rgb = do_convert_rgb
130
+ self.spatial_factor = spatial_factor
131
+
132
+ def _preprocess(
133
+ self,
134
+ images: ImageInput,
135
+ do_resize: Optional[bool] = None,
136
+ resample: PILImageResampling = None,
137
+ do_rescale: Optional[bool] = None,
138
+ rescale_factor: Optional[float] = None,
139
+ do_normalize: Optional[bool] = None,
140
+ image_mean: Optional[Union[float, List[float]]] = None,
141
+ image_std: Optional[Union[float, List[float]]] = None,
142
+ do_convert_rgb: Optional[bool] = None,
143
+ spatial_factor: Optional[int] = None,
144
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
145
+ output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
146
+ ):
147
+ """
148
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
149
+
150
+ Args:
151
+ images (`ImageInput`):
152
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
153
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
154
+ Whether to resize the image.
155
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
156
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
157
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
158
+ Whether to rescale the image.
159
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
160
+ Scale factor to use if rescaling the image.
161
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
162
+ Whether to normalize the image.
163
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
164
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
165
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
166
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
167
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
168
+ Whether to convert the image to RGB.
169
+ spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
170
+ The spatial downsampling factor by which the image will be downsampled in the feature extraction phase.
171
+ input_data_format (`ChannelDimension` or `str`, *optional*):
172
+ The channel dimension format for the input image. Can be one of:
173
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
174
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
175
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
176
+ output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
177
+ The channel dimension format for the output image. Can be one of:
178
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
179
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
180
+ - Unset: Use the channel dimension format of the input image.
181
+ """
182
+ spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
183
+
184
+ images = make_list_of_images(images)
185
+ if do_convert_rgb:
186
+ images = [convert_to_rgb(image) for image in images]
187
+
188
+ # All transformations expect numpy arrays.
189
+ images = [to_numpy_array(image) for image in images]
190
+
191
+ if is_scaled_image(images[0]) and do_rescale:
192
+ logger.warning_once(
193
+ "It looks like you are trying to rescale already rescaled images. If the input"
194
+ "pixel_values.append()images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
195
+ )
196
+
197
+ if input_data_format is None:
198
+ # We assume that all images have the same channel dimension format.
199
+ input_data_format = infer_channel_dimension_format(images[0])
200
+
201
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
202
+ resized_height, resized_width = height, width
203
+ processed_images = []
204
+ for image in images:
205
+ if do_resize:
206
+ resized_height, resized_width = smart_resize(
207
+ height,
208
+ width,
209
+ factor=spatial_factor,
210
+ min_pixels=self.min_pixels,
211
+ max_pixels=self.max_pixels,
212
+ )
213
+ image = resize(
214
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
215
+ )
216
+
217
+ if do_rescale:
218
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
219
+
220
+ if do_normalize:
221
+ image = self.normalize(
222
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
223
+ )
224
+
225
+ image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format)
226
+ processed_images.append(image)
227
+
228
+ image = np.array(processed_images)
229
+ return image
230
+
231
+ def preprocess(
232
+ self,
233
+ images: ImageInput,
234
+ do_resize: Optional[bool] = None,
235
+ resample: PILImageResampling = None,
236
+ do_rescale: Optional[bool] = None,
237
+ rescale_factor: Optional[float] = None,
238
+ do_normalize: Optional[bool] = None,
239
+ image_mean: Optional[Union[float, List[float]]] = None,
240
+ image_std: Optional[Union[float, List[float]]] = None,
241
+ do_convert_rgb: Optional[bool] = None,
242
+ spatial_factor: Optional[int] = None,
243
+ return_tensors: Optional[Union[str, TensorType]] = None,
244
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
245
+ output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
246
+ ):
247
+ """
248
+ Args:
249
+ images (`ImageInput`):
250
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
251
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
252
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
253
+ Whether to resize the image.
254
+ resample (`int`, *optional*, defaults to `self.resample`):
255
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
256
+ has an effect if `do_resize` is set to `True`.
257
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
258
+ Whether to rescale the image.
259
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
260
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
261
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
262
+ Whether to normalize the image.
263
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
264
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
265
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
266
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
267
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
268
+ Whether to convert the image to RGB.
269
+ spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
270
+ The spatial downsampling factor by which the image will be downsampled in the feature extraction phase.
271
+ return_tensors (`str` or `TensorType`, *optional*):
272
+ The type of tensors to return. Can be one of:
273
+ - Unset: Return a list of `np.ndarray`.
274
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
275
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
276
+ input_data_format (`ChannelDimension` or `str`, *optional*):
277
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
278
+ from the input image. Can be one of:
279
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
280
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
281
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
282
+ output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
283
+ The channel dimension format for the output image. Can be one of:
284
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
285
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
286
+ - Unset: Use the channel dimension format of the input image.
287
+ """
288
+ do_resize = do_resize if do_resize is not None else self.do_resize
289
+ resample = resample if resample is not None else self.resample
290
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
291
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
292
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
293
+ image_mean = image_mean if image_mean is not None else self.image_mean
294
+ image_std = image_std if image_std is not None else self.image_std
295
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
296
+ spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
297
+
298
+ images = make_list_of_images(images)
299
+ if images is None or not valid_images(images):
300
+ raise ValueError(
301
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
302
+ "torch.Tensor, tf.Tensor or jax.ndarray."
303
+ )
304
+
305
+ validate_preprocess_arguments(
306
+ rescale_factor=rescale_factor,
307
+ do_normalize=do_normalize,
308
+ image_mean=image_mean,
309
+ image_std=image_std,
310
+ do_resize=do_resize,
311
+ size=self.size,
312
+ resample=resample,
313
+ )
314
+
315
+ pixel_values = []
316
+ for image in images:
317
+ norm_image = self._preprocess(
318
+ image,
319
+ do_resize=do_resize,
320
+ resample=resample,
321
+ do_rescale=do_rescale,
322
+ rescale_factor=rescale_factor,
323
+ do_normalize=do_normalize,
324
+ image_mean=image_mean,
325
+ image_std=image_std,
326
+ do_convert_rgb=do_convert_rgb,
327
+ spatial_factor=spatial_factor,
328
+ input_data_format=input_data_format,
329
+ output_data_format=output_data_format,
330
+ )
331
+ pixel_values.extend(norm_image)
332
+
333
+ pixel_values = np.array(pixel_values)
334
+ data = {"pixel_values": pixel_values}
335
+
336
+ return BatchFeature(data=data, tensor_type=return_tensors)
337
+
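# Illustrative usage sketch (annotation, not part of the committed diff). Assumes the
# processor defined in this file is passed in as `image_processor`, e.g. loaded with
# AutoImageProcessor.from_pretrained("<repo-id>", trust_remote_code=True).
from PIL import Image
import numpy as np

def demo_preprocess(image_processor):
    image = Image.fromarray(np.random.randint(0, 256, (384, 512, 3), dtype=np.uint8))
    batch = image_processor.preprocess(image, return_tensors="pt")
    # `pixel_values` is a stacked batch of normalized images, channels-first by default.
    return batch["pixel_values"]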
338
+ def postprocess(
339
+ self,
340
+ images: ImageInput,
341
+ do_rescale: Optional[bool] = None,
342
+ rescale_factor: Optional[float] = None,
343
+ do_normalize: Optional[bool] = None,
344
+ image_mean: Optional[Union[float, List[float]]] = None,
345
+ image_std: Optional[Union[float, List[float]]] = None,
346
+ return_tensors: Optional[Union[str, TensorType]] = "PIL.Image.Image",
347
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
348
+ ):
349
+ """
350
+ Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess.
351
+         The parameters should be the same as in `preprocess`.
352
+
353
+ Args:
354
+ images (`ImageInput`):
355
+ Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1.
356
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
357
+ Whether to rescale the image.
358
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
359
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
360
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
361
+ Whether to normalize the image.
362
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
363
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
364
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
365
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
366
+             return_tensors (`str` or `TensorType`, *optional*, defaults to `"PIL.Image.Image"`):
367
+ The type of tensors to return. Can be one of:
368
+                 - `"PIL.Image.Image"` (default): Return a list of `PIL.Image.Image`.
+                 - Unset: Return a list of `np.ndarray`.
369
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
370
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
371
+ input_data_format (`ChannelDimension` or `str`, *optional*):
372
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
373
+ from the input image. Can be one of:
374
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
375
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
376
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
377
+ """
378
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
379
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
380
+ rescale_factor = 1 / rescale_factor
381
+
382
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
383
+ image_mean = image_mean if image_mean is not None else self.image_mean
384
+ image_std = image_std if image_std is not None else self.image_std
385
+ image_mean, image_std = self.inverse_meanstd(image_mean, image_std)
386
+
387
+ images = make_list_of_images(images)
388
+ if isinstance(images[0], Image.Image):
389
+ return images if len(images) > 1 else images[0]
390
+
391
+ if input_data_format is None:
392
+ # We assume that all images have the same channel dimension format.
393
+ input_data_format = infer_channel_dimension_format(images[0])
394
+
395
+ pixel_values = []
396
+ for image in images:
397
+ image = to_numpy_array(image)
398
+ if do_normalize:
399
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
400
+
401
+ if do_rescale:
402
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
403
+ image = image.clip(0, 255).astype(np.uint8)
404
+
405
+ if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
406
+ image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
407
+ pixel_values.append(Image.fromarray(image))
408
+ else:
409
+ pixel_values.extend(image)
410
+
411
+ data = {"pixel_values": pixel_values}
412
+ return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
413
+
414
+ return BatchFeature(data=data, tensor_type=return_tensors)
415
+
416
+ def inverse_meanstd(self, image_mean, image_std):
417
+ image_mean = self.to_tuple(image_mean)
418
+ image_std = self.to_tuple(image_std)
419
+
420
+ rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std))
421
+ rev_image_std = tuple(1 / s for s in image_std)
422
+
423
+ return rev_image_mean, rev_image_std
424
+
425
+ def to_tuple(self, value, dim=3):
426
+ if isinstance(value, (int, float)):
427
+ return (value,) * dim
428
+
429
+ return tuple(value)
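# Sketch (annotation, not part of the committed diff): why `inverse_meanstd` undoes
# the forward normalization used in `preprocess`.
#   forward:  x_norm = (x - mean) / std
#   inverse:  normalizing x_norm with mean' = -mean/std and std' = 1/std gives
#             (x_norm - mean') / std' = std * x_norm + mean = x
import numpy as np

x = np.random.rand(3, 8, 8).astype(np.float64)
mean = np.array([0.5, 0.5, 0.5])[:, None, None]
std = np.array([0.5, 0.5, 0.5])[:, None, None]
x_norm = (x - mean) / std
inv_mean, inv_std = -mean / std, 1.0 / std
assert np.allclose((x_norm - inv_mean) / inv_std, x)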
image_utils.py ADDED
@@ -0,0 +1,812 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import base64
17
+ import os
18
+ from io import BytesIO
19
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import requests
23
+ from packaging import version
24
+
25
+
26
+ from transformers.utils import (
27
+ ExplicitEnum,
28
+ is_jax_tensor,
29
+ is_numpy_array,
30
+ is_tf_tensor,
31
+ is_torch_available,
32
+ is_torch_tensor,
33
+ is_torchvision_available,
34
+ is_vision_available,
35
+ logging,
36
+ requires_backends,
37
+ to_numpy,
38
+ )
39
+ from transformers.utils.constants import ( # noqa: F401
40
+ IMAGENET_DEFAULT_MEAN,
41
+ IMAGENET_DEFAULT_STD,
42
+ IMAGENET_STANDARD_MEAN,
43
+ IMAGENET_STANDARD_STD,
44
+ OPENAI_CLIP_MEAN,
45
+ OPENAI_CLIP_STD,
46
+ )
47
+
48
+
49
+ if is_vision_available():
50
+ import PIL.Image
51
+ import PIL.ImageOps
52
+
53
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
54
+ PILImageResampling = PIL.Image.Resampling
55
+ else:
56
+ PILImageResampling = PIL.Image
57
+
58
+ if is_torchvision_available():
59
+ from torchvision.transforms import InterpolationMode
60
+
61
+ pil_torch_interpolation_mapping = {
62
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
63
+ PILImageResampling.BOX: InterpolationMode.BOX,
64
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
65
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
66
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
67
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
68
+ }
69
+
70
+
71
+ if TYPE_CHECKING:
72
+ if is_torch_available():
73
+ import torch
74
+
75
+
76
+ logger = logging.get_logger(__name__)
77
+
78
+
79
+ ImageInput = Union[
80
+ "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
81
+ ] # noqa
82
+
83
+
84
+ VideoInput = Union[
85
+ List["PIL.Image.Image"],
86
+ "np.ndarray",
87
+ "torch.Tensor",
88
+ List["np.ndarray"],
89
+ List["torch.Tensor"],
90
+ List[List["PIL.Image.Image"]],
91
+     List[List["np.ndarray"]],
92
+ List[List["torch.Tensor"]],
93
+ ] # noqa
94
+
95
+
96
+ class ChannelDimension(ExplicitEnum):
97
+ FIRST = "channels_first"
98
+ LAST = "channels_last"
99
+
100
+
101
+ class AnnotationFormat(ExplicitEnum):
102
+ COCO_DETECTION = "coco_detection"
103
+ COCO_PANOPTIC = "coco_panoptic"
104
+
105
+
106
+ class AnnotionFormat(ExplicitEnum):
107
+ COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
108
+ COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
109
+
110
+
111
+ AnnotationType = Dict[str, Union[int, str, List[Dict]]]
112
+
113
+
114
+ def is_pil_image(img):
115
+ return is_vision_available() and isinstance(img, PIL.Image.Image)
116
+
117
+
118
+ class ImageType(ExplicitEnum):
119
+ PIL = "pillow"
120
+ TORCH = "torch"
121
+ NUMPY = "numpy"
122
+ TENSORFLOW = "tensorflow"
123
+ JAX = "jax"
124
+
125
+
126
+ def get_image_type(image):
127
+ if is_pil_image(image):
128
+ return ImageType.PIL
129
+ if is_torch_tensor(image):
130
+ return ImageType.TORCH
131
+ if is_numpy_array(image):
132
+ return ImageType.NUMPY
133
+ if is_tf_tensor(image):
134
+ return ImageType.TENSORFLOW
135
+ if is_jax_tensor(image):
136
+ return ImageType.JAX
137
+ raise ValueError(f"Unrecognised image type {type(image)}")
138
+
139
+
140
+ def is_valid_image(img):
141
+ return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
142
+
143
+
144
+ def valid_images(imgs):
145
+     # If we have a list of images, make sure every image is valid
146
+ if isinstance(imgs, (list, tuple)):
147
+ for img in imgs:
148
+ if not valid_images(img):
149
+ return False
150
+     # If not a list or tuple, we have been given a single image or a batched tensor of images
151
+ elif not is_valid_image(imgs):
152
+ return False
153
+ return True
154
+
155
+
156
+ def is_batched(img):
157
+ if isinstance(img, (list, tuple)):
158
+ return is_valid_image(img[0])
159
+ return False
160
+
161
+
162
+ def is_scaled_image(image: np.ndarray) -> bool:
163
+ """
164
+ Checks to see whether the pixel values have already been rescaled to [0, 1].
165
+ """
166
+ if image.dtype == np.uint8:
167
+ return False
168
+
169
+ # It's possible the image has pixel values in [0, 255] but is of floating type
170
+ return np.min(image) >= 0 and np.max(image) <= 1
171
+
172
+
173
+ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
174
+ """
175
+ Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
176
+ If the input is a batch of images, it is converted to a list of images.
177
+
178
+ Args:
179
+ images (`ImageInput`):
180
+             Image or images to turn into a list of images.
181
+ expected_ndims (`int`, *optional*, defaults to 3):
182
+ Expected number of dimensions for a single input image. If the input image has a different number of
183
+ dimensions, an error is raised.
184
+ """
185
+ if is_batched(images):
186
+ return images
187
+
188
+ # Either the input is a single image, in which case we create a list of length 1
189
+ if isinstance(images, PIL.Image.Image):
190
+ # PIL images are never batched
191
+ return [images]
192
+
193
+ if is_valid_image(images):
194
+ if images.ndim == expected_ndims + 1:
195
+ # Batch of images
196
+ images = list(images)
197
+ elif images.ndim == expected_ndims:
198
+ # Single image
199
+ images = [images]
200
+ else:
201
+ raise ValueError(
202
+ f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
203
+ f" {images.ndim} dimensions."
204
+ )
205
+ return images
206
+ raise ValueError(
207
+ "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
208
+ f"jax.ndarray, but got {type(images)}."
209
+ )
210
+
211
+
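# Sketch (annotation, not part of the committed diff): expected behavior of
# `make_list_of_images` for a single image vs. a batched array.
import numpy as np

single = np.zeros((3, 64, 64), dtype=np.uint8)    # one image, 3 dims
batch = np.zeros((4, 3, 64, 64), dtype=np.uint8)  # batch of 4 images, 4 dims
assert len(make_list_of_images(single)) == 1      # wrapped into a singleton list
assert len(make_list_of_images(batch)) == 4       # split along the batch axis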
212
+ def to_numpy_array(img) -> np.ndarray:
213
+ if not is_valid_image(img):
214
+ raise ValueError(f"Invalid image type: {type(img)}")
215
+
216
+ if is_vision_available() and isinstance(img, PIL.Image.Image):
217
+ return np.array(img)
218
+ return to_numpy(img)
219
+
220
+
221
+ def infer_channel_dimension_format(
222
+ image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
223
+ ) -> ChannelDimension:
224
+ """
225
+ Infers the channel dimension format of `image`.
226
+
227
+ Args:
228
+ image (`np.ndarray`):
229
+ The image to infer the channel dimension of.
230
+ num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
231
+ The number of channels of the image.
232
+
233
+ Returns:
234
+ The channel dimension of the image.
235
+ """
236
+ num_channels = num_channels if num_channels is not None else (1, 3)
237
+ num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
238
+
239
+ if image.ndim == 3:
240
+ first_dim, last_dim = 0, 2
241
+ elif image.ndim == 4:
242
+ first_dim, last_dim = 1, 3
243
+ else:
244
+ raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
245
+
246
+ if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
247
+ logger.warning(
248
+ f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
249
+ )
250
+ return ChannelDimension.FIRST
251
+ elif image.shape[first_dim] in num_channels:
252
+ return ChannelDimension.FIRST
253
+ elif image.shape[last_dim] in num_channels:
254
+ return ChannelDimension.LAST
255
+ raise ValueError("Unable to infer channel dimension format")
256
+
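# Sketch (annotation, not part of the committed diff): channel-format inference
# from array shapes.
import numpy as np

assert infer_channel_dimension_format(np.zeros((3, 32, 48))) == ChannelDimension.FIRST
assert infer_channel_dimension_format(np.zeros((32, 48, 3))) == ChannelDimension.LAST
# Ambiguous shapes such as (3, 48, 3) are resolved to channels-first with a warning.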
257
+
258
+ def get_channel_dimension_axis(
259
+ image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
260
+ ) -> int:
261
+ """
262
+ Returns the channel dimension axis of the image.
263
+
264
+ Args:
265
+ image (`np.ndarray`):
266
+ The image to get the channel dimension axis of.
267
+ input_data_format (`ChannelDimension` or `str`, *optional*):
268
+ The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
269
+
270
+ Returns:
271
+ The channel dimension axis of the image.
272
+ """
273
+ if input_data_format is None:
274
+ input_data_format = infer_channel_dimension_format(image)
275
+ if input_data_format == ChannelDimension.FIRST:
276
+ return image.ndim - 3
277
+ elif input_data_format == ChannelDimension.LAST:
278
+ return image.ndim - 1
279
+ raise ValueError(f"Unsupported data format: {input_data_format}")
280
+
281
+
282
+ def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
283
+ """
284
+ Returns the (height, width) dimensions of the image.
285
+
286
+ Args:
287
+ image (`np.ndarray`):
288
+ The image to get the dimensions of.
289
+ channel_dim (`ChannelDimension`, *optional*):
290
+ Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
291
+
292
+ Returns:
293
+ A tuple of the image's height and width.
294
+ """
295
+ if channel_dim is None:
296
+ channel_dim = infer_channel_dimension_format(image)
297
+
298
+ if channel_dim == ChannelDimension.FIRST:
299
+ return image.shape[-2], image.shape[-1]
300
+ elif channel_dim == ChannelDimension.LAST:
301
+ return image.shape[-3], image.shape[-2]
302
+ else:
303
+ raise ValueError(f"Unsupported data format: {channel_dim}")
304
+
305
+
306
+ def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
307
+ if (
308
+ isinstance(annotation, dict)
309
+ and "image_id" in annotation
310
+ and "annotations" in annotation
311
+ and isinstance(annotation["annotations"], (list, tuple))
312
+ and (
313
+ # an image can have no annotations
314
+ len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
315
+ )
316
+ ):
317
+ return True
318
+ return False
319
+
320
+
321
+ def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
322
+ if (
323
+ isinstance(annotation, dict)
324
+ and "image_id" in annotation
325
+ and "segments_info" in annotation
326
+ and "file_name" in annotation
327
+ and isinstance(annotation["segments_info"], (list, tuple))
328
+ and (
329
+ # an image can have no segments
330
+ len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
331
+ )
332
+ ):
333
+ return True
334
+ return False
335
+
336
+
337
+ def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
338
+ return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
339
+
340
+
341
+ def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
342
+ return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
343
+
344
+
345
+ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
346
+ """
347
+ Loads `image` to a PIL Image.
348
+
349
+ Args:
350
+ image (`str` or `PIL.Image.Image`):
351
+ The image to convert to the PIL Image format.
352
+ timeout (`float`, *optional*):
353
+ The timeout value in seconds for the URL request.
354
+
355
+ Returns:
356
+ `PIL.Image.Image`: A PIL Image.
357
+ """
358
+ requires_backends(load_image, ["vision"])
359
+ if isinstance(image, str):
360
+ if image.startswith("http://") or image.startswith("https://"):
361
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
362
+ # like http_huggingface_co.png
363
+ image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
364
+ elif os.path.isfile(image):
365
+ image = PIL.Image.open(image)
366
+ else:
367
+ if image.startswith("data:image/"):
368
+ image = image.split(",")[1]
369
+
370
+ # Try to load as base64
371
+ try:
372
+ b64 = base64.decodebytes(image.encode())
373
+ image = PIL.Image.open(BytesIO(b64))
374
+ except Exception as e:
375
+ raise ValueError(
376
+ f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
377
+ )
378
+ elif isinstance(image, PIL.Image.Image):
379
+ image = image
380
+ else:
381
+ raise TypeError(
382
+ "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
383
+ )
384
+ image = PIL.ImageOps.exif_transpose(image)
385
+ image = image.convert("RGB")
386
+ return image
387
+
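# Sketch (annotation, not part of the committed diff): the input forms accepted by
# `load_image`. The URL and path below are placeholders; the base64 branch is exercised.
import base64
from io import BytesIO
from PIL import Image

# load_image("https://example.com/cat.png", timeout=10.0)  # remote URL
# load_image("/path/to/local.png")                         # local file path
buf = BytesIO()
Image.new("RGB", (8, 8)).save(buf, format="PNG")
data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
assert load_image(data_uri).mode == "RGB"                  # base64 / data URI string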
388
+
389
+ def validate_preprocess_arguments(
390
+ do_rescale: Optional[bool] = None,
391
+ rescale_factor: Optional[float] = None,
392
+ do_normalize: Optional[bool] = None,
393
+ image_mean: Optional[Union[float, List[float]]] = None,
394
+ image_std: Optional[Union[float, List[float]]] = None,
395
+ do_pad: Optional[bool] = None,
396
+ size_divisibility: Optional[int] = None,
397
+ do_center_crop: Optional[bool] = None,
398
+ crop_size: Optional[Dict[str, int]] = None,
399
+ do_resize: Optional[bool] = None,
400
+ size: Optional[Dict[str, int]] = None,
401
+ resample: Optional["PILImageResampling"] = None,
402
+ ):
403
+ """
404
+ Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
405
+ Raises `ValueError` if arguments incompatibility is caught.
406
+ Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
407
+ sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
408
+ existing arguments when possible.
409
+
410
+ """
411
+ if do_rescale and rescale_factor is None:
412
+ raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
413
+
414
+ if do_pad and size_divisibility is None:
415
+ # Here, size_divisor might be passed as the value of size
416
+ raise ValueError(
417
+ "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
418
+ )
419
+
420
+ if do_normalize and (image_mean is None or image_std is None):
421
+ raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
422
+
423
+ if do_center_crop and crop_size is None:
424
+ raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
425
+
426
+ if do_resize and (size is None or resample is None):
427
+ raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
428
+
429
+
430
+ # In the future we can add a TF implementation here when we have TF models.
431
+ class ImageFeatureExtractionMixin:
432
+ """
433
+ Mixin that contain utilities for preparing image features.
434
+ """
435
+
436
+ def _ensure_format_supported(self, image):
437
+ if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
438
+ raise ValueError(
439
+ f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
440
+ "`torch.Tensor` are."
441
+ )
442
+
443
+ def to_pil_image(self, image, rescale=None):
444
+ """
445
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
446
+ needed.
447
+
448
+ Args:
449
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
450
+ The image to convert to the PIL Image format.
451
+ rescale (`bool`, *optional*):
452
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
453
+ default to `True` if the image type is a floating type, `False` otherwise.
454
+ """
455
+ self._ensure_format_supported(image)
456
+
457
+ if is_torch_tensor(image):
458
+ image = image.numpy()
459
+
460
+ if isinstance(image, np.ndarray):
461
+ if rescale is None:
462
+                 # rescale defaults to True if the array is of a floating type.
463
+ rescale = isinstance(image.flat[0], np.floating)
464
+             # If the channel has been moved to the first dim, we put it back at the end.
465
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
466
+ image = image.transpose(1, 2, 0)
467
+ if rescale:
468
+ image = image * 255
469
+ image = image.astype(np.uint8)
470
+ return PIL.Image.fromarray(image)
471
+ return image
472
+
473
+ def convert_rgb(self, image):
474
+ """
475
+ Converts `PIL.Image.Image` to RGB format.
476
+
477
+ Args:
478
+ image (`PIL.Image.Image`):
479
+ The image to convert.
480
+ """
481
+ self._ensure_format_supported(image)
482
+ if not isinstance(image, PIL.Image.Image):
483
+ return image
484
+
485
+ return image.convert("RGB")
486
+
487
+ def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
488
+ """
489
+ Rescale a numpy image by scale amount
490
+ """
491
+ self._ensure_format_supported(image)
492
+ return image * scale
493
+
494
+ def to_numpy_array(self, image, rescale=None, channel_first=True):
495
+ """
496
+ Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
497
+ dimension.
498
+
499
+ Args:
500
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
501
+ The image to convert to a NumPy array.
502
+ rescale (`bool`, *optional*):
503
+ Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
504
+ default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
505
+ channel_first (`bool`, *optional*, defaults to `True`):
506
+ Whether or not to permute the dimensions of the image to put the channel dimension first.
507
+ """
508
+ self._ensure_format_supported(image)
509
+
510
+ if isinstance(image, PIL.Image.Image):
511
+ image = np.array(image)
512
+
513
+ if is_torch_tensor(image):
514
+ image = image.numpy()
515
+
516
+ rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
517
+
518
+ if rescale:
519
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
520
+
521
+ if channel_first and image.ndim == 3:
522
+ image = image.transpose(2, 0, 1)
523
+
524
+ return image
525
+
526
+ def expand_dims(self, image):
527
+ """
528
+ Expands 2-dimensional `image` to 3 dimensions.
529
+
530
+ Args:
531
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
532
+ The image to expand.
533
+ """
534
+ self._ensure_format_supported(image)
535
+
536
+ # Do nothing if PIL image
537
+ if isinstance(image, PIL.Image.Image):
538
+ return image
539
+
540
+ if is_torch_tensor(image):
541
+ image = image.unsqueeze(0)
542
+ else:
543
+ image = np.expand_dims(image, axis=0)
544
+ return image
545
+
546
+ def normalize(self, image, mean, std, rescale=False):
547
+ """
548
+ Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
549
+ if it's a PIL Image.
550
+
551
+ Args:
552
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
553
+ The image to normalize.
554
+ mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
555
+ The mean (per channel) to use for normalization.
556
+ std (`List[float]` or `np.ndarray` or `torch.Tensor`):
557
+ The standard deviation (per channel) to use for normalization.
558
+ rescale (`bool`, *optional*, defaults to `False`):
559
+ Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
560
+ happen automatically.
561
+ """
562
+ self._ensure_format_supported(image)
563
+
564
+ if isinstance(image, PIL.Image.Image):
565
+ image = self.to_numpy_array(image, rescale=True)
566
+ # If the input image is a PIL image, it automatically gets rescaled. If it's another
567
+ # type it may need rescaling.
568
+ elif rescale:
569
+ if isinstance(image, np.ndarray):
570
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
571
+ elif is_torch_tensor(image):
572
+ image = self.rescale(image.float(), 1 / 255.0)
573
+
574
+ if isinstance(image, np.ndarray):
575
+ if not isinstance(mean, np.ndarray):
576
+ mean = np.array(mean).astype(image.dtype)
577
+ if not isinstance(std, np.ndarray):
578
+ std = np.array(std).astype(image.dtype)
579
+ elif is_torch_tensor(image):
580
+ import torch
581
+
582
+ if not isinstance(mean, torch.Tensor):
583
+ if isinstance(mean, np.ndarray):
584
+ mean = torch.from_numpy(mean)
585
+ else:
586
+ mean = torch.tensor(mean)
587
+ if not isinstance(std, torch.Tensor):
588
+ if isinstance(std, np.ndarray):
589
+ std = torch.from_numpy(std)
590
+ else:
591
+ std = torch.tensor(std)
592
+
593
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
594
+ return (image - mean[:, None, None]) / std[:, None, None]
595
+ else:
596
+ return (image - mean) / std
597
+
598
+ def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
599
+ """
600
+ Resizes `image`. Enforces conversion of input to PIL.Image.
601
+
602
+ Args:
603
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
604
+ The image to resize.
605
+ size (`int` or `Tuple[int, int]`):
606
+ The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
607
+ matched to this.
608
+
609
+ If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
610
+ `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
611
+                 this number, i.e., if height > width, then image will be rescaled to (size * height / width, size).
612
+ resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
613
+                 The filter to use for resampling.
614
+ default_to_square (`bool`, *optional*, defaults to `True`):
615
+ How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
616
+ square (`size`,`size`). If set to `False`, will replicate
617
+ [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
618
+ with support for resizing only the smallest edge and providing an optional `max_size`.
619
+ max_size (`int`, *optional*, defaults to `None`):
620
+ The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
621
+ greater than `max_size` after being resized according to `size`, then the image is resized again so
622
+                 that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e. the smaller
623
+ edge may be shorter than `size`. Only used if `default_to_square` is `False`.
624
+
625
+ Returns:
626
+ image: A resized `PIL.Image.Image`.
627
+ """
628
+ resample = resample if resample is not None else PILImageResampling.BILINEAR
629
+
630
+ self._ensure_format_supported(image)
631
+
632
+ if not isinstance(image, PIL.Image.Image):
633
+ image = self.to_pil_image(image)
634
+
635
+ if isinstance(size, list):
636
+ size = tuple(size)
637
+
638
+ if isinstance(size, int) or len(size) == 1:
639
+ if default_to_square:
640
+ size = (size, size) if isinstance(size, int) else (size[0], size[0])
641
+ else:
642
+ width, height = image.size
643
+ # specified size only for the smallest edge
644
+ short, long = (width, height) if width <= height else (height, width)
645
+ requested_new_short = size if isinstance(size, int) else size[0]
646
+
647
+ if short == requested_new_short:
648
+ return image
649
+
650
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
651
+
652
+ if max_size is not None:
653
+ if max_size <= requested_new_short:
654
+ raise ValueError(
655
+ f"max_size = {max_size} must be strictly greater than the requested "
656
+ f"size for the smaller edge size = {size}"
657
+ )
658
+ if new_long > max_size:
659
+ new_short, new_long = int(max_size * new_short / new_long), max_size
660
+
661
+ size = (new_short, new_long) if width <= height else (new_long, new_short)
662
+
663
+ return image.resize(size, resample=resample)
664
+
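# Sketch (annotation, not part of the committed diff): smallest-edge resizing with
# `default_to_square=False`. For a 400x600 image and size=200, the short edge maps
# to 200 and the long edge scales proportionally to 600 * 200 / 400 = 300.
from PIL import Image

def demo_resize(feature_extractor):  # any object exposing this mixin's `resize`
    out = feature_extractor.resize(Image.new("RGB", (400, 600)), size=200,
                                   default_to_square=False)
    assert out.size == (200, 300)    # PIL reports (width, height)
    return out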
665
+ def center_crop(self, image, size):
666
+ """
667
+ Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
668
+ size given, it will be padded (so the returned result has the size asked).
669
+
670
+ Args:
671
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
672
+ The image to resize.
673
+ size (`int` or `Tuple[int, int]`):
674
+ The size to which crop the image.
675
+
676
+ Returns:
677
+ new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
678
+ height, width).
679
+ """
680
+ self._ensure_format_supported(image)
681
+
682
+ if not isinstance(size, tuple):
683
+ size = (size, size)
684
+
685
+ # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
686
+ if is_torch_tensor(image) or isinstance(image, np.ndarray):
687
+ if image.ndim == 2:
688
+ image = self.expand_dims(image)
689
+ image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
690
+ else:
691
+ image_shape = (image.size[1], image.size[0])
692
+
693
+ top = (image_shape[0] - size[0]) // 2
694
+ bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
695
+ left = (image_shape[1] - size[1]) // 2
696
+ right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
697
+
698
+ # For PIL Images we have a method to crop directly.
699
+ if isinstance(image, PIL.Image.Image):
700
+ return image.crop((left, top, right, bottom))
701
+
702
+ # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
703
+ channel_first = True if image.shape[0] in [1, 3] else False
704
+
705
+ # Transpose (height, width, n_channels) format images
706
+ if not channel_first:
707
+ if isinstance(image, np.ndarray):
708
+ image = image.transpose(2, 0, 1)
709
+ if is_torch_tensor(image):
710
+ image = image.permute(2, 0, 1)
711
+
712
+ # Check if cropped area is within image boundaries
713
+ if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
714
+ return image[..., top:bottom, left:right]
715
+
716
+ # Otherwise, we may need to pad if the image is too small. Oh joy...
717
+ new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
718
+ if isinstance(image, np.ndarray):
719
+ new_image = np.zeros_like(image, shape=new_shape)
720
+ elif is_torch_tensor(image):
721
+ new_image = image.new_zeros(new_shape)
722
+
723
+ top_pad = (new_shape[-2] - image_shape[0]) // 2
724
+ bottom_pad = top_pad + image_shape[0]
725
+ left_pad = (new_shape[-1] - image_shape[1]) // 2
726
+ right_pad = left_pad + image_shape[1]
727
+ new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
728
+
729
+ top += top_pad
730
+ bottom += top_pad
731
+ left += left_pad
732
+ right += left_pad
733
+
734
+ new_image = new_image[
735
+ ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
736
+ ]
737
+
738
+ return new_image
739
+
740
+ def flip_channel_order(self, image):
741
+ """
742
+ Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
743
+ `image` to a NumPy array if it's a PIL Image.
744
+
745
+ Args:
746
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
747
+ The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
748
+ be first.
749
+ """
750
+ self._ensure_format_supported(image)
751
+
752
+ if isinstance(image, PIL.Image.Image):
753
+ image = self.to_numpy_array(image)
754
+
755
+ return image[::-1, :, :]
756
+
757
+ def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
758
+ """
759
+ Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
760
+ counter clockwise around its centre.
761
+
762
+ Args:
763
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
764
+ The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
765
+ rotating.
766
+
767
+ Returns:
768
+ image: A rotated `PIL.Image.Image`.
769
+ """
770
+ resample = resample if resample is not None else PIL.Image.NEAREST
771
+
772
+ self._ensure_format_supported(image)
773
+
774
+ if not isinstance(image, PIL.Image.Image):
775
+ image = self.to_pil_image(image)
776
+
777
+ return image.rotate(
778
+ angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
779
+ )
780
+
781
+
782
+ def validate_annotations(
783
+ annotation_format: AnnotationFormat,
784
+ supported_annotation_formats: Tuple[AnnotationFormat, ...],
785
+ annotations: List[Dict],
786
+ ) -> None:
787
+ if annotation_format not in supported_annotation_formats:
788
+         raise ValueError(f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}")
789
+
790
+ if annotation_format is AnnotationFormat.COCO_DETECTION:
791
+ if not valid_coco_detection_annotations(annotations):
792
+ raise ValueError(
793
+                 "Invalid COCO detection annotations. Annotations must be a dict (single image) or list of dicts "
794
+ "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
795
+ "being a list of annotations in the COCO format."
796
+ )
797
+
798
+ if annotation_format is AnnotationFormat.COCO_PANOPTIC:
799
+ if not valid_coco_panoptic_annotations(annotations):
800
+ raise ValueError(
801
+                 "Invalid COCO panoptic annotations. Annotations must be a dict (single image) or list of dicts "
802
+ "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
803
+ "the latter being a list of annotations in the COCO format."
804
+ )
805
+
806
+
807
+ def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
808
+ unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
809
+ if unused_keys:
810
+ unused_key_str = ", ".join(unused_keys)
811
+ # TODO raise a warning here instead of simply logging?
812
+ logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
inference_utils.py ADDED
@@ -0,0 +1,412 @@
1
+ import re
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import random
6
+ import numpy as np
7
+ from transformers import LogitsProcessor, LogitsProcessorList, logging
8
+ from typing import Optional, List, Dict, Tuple, Union, Set
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ def parse_interleaved_text_image(
14
+ full_output_text: str,
15
+ num_levels: int = 2,
16
+ image_placeholder: str = "<image>",
17
+ start_tag: str = "<start_of_image>",
18
+ end_tag: str = "<end_of_image>"
19
+ ) -> Tuple[str, List[List[List[int]]], List[str]]:
20
+ """
21
+ Parses text containing interleaved image token blocks.
22
+
23
+ Identifies blocks enclosed by start_tag and end_tag, extracts image tokens
24
+ (<|image_levelX_Y|>) within them, and replaces the blocks with a placeholder
25
+ in the output text.
26
+
27
+ Args:
28
+ full_output_text: The raw input string containing text and image blocks.
29
+ num_levels: The expected number of levels for image tokens (e.g., 2).
30
+ image_placeholder: The string to replace image blocks with in the output text.
31
+ start_tag: The exact string marking the beginning of an image block.
32
+ end_tag: The exact string marking the end of an image block.
33
+ eos_token: If provided, this token will be removed from the final text.
34
+
35
+ Returns:
36
+ A tuple containing:
37
+ - generated_text (str): The text with image blocks replaced by placeholders.
38
+ - all_image_indices (List[List[List[int]]]): A list where each element
39
+ represents one image. Each image element is a list containing lists
40
+ of token indices for each level.
41
+ Example for 2 images, 2 levels:
42
+ [
43
+ [[level0_indices_img1], [level1_indices_img1]], # Image 1
44
+ [[level0_indices_img2], [level1_indices_img2]] # Image 2
45
+ ]
46
+ """
47
+ all_image_indices: List[List[List[int]]] = []
48
+ processed_text_parts: List[str] = []
49
+ list_image_token_parts: List[str] = []
50
+ last_end: int = 0
51
+
52
+ # Escape start/end tags for regex safety if they contain special characters
53
+ escaped_start_tag = re.escape(start_tag)
54
+ escaped_end_tag = re.escape(end_tag)
55
+
56
+ # Pattern to find image blocks: start_tag ... end_tag (non-greedy)
57
+ image_block_pattern = rf'{escaped_start_tag}(.*?){escaped_end_tag}'
58
+ # Pattern to find individual image tokens within a block
59
+ token_pattern = r'<\|image_level(\d+)_(\d+)\|>'
60
+
61
+ # Find all image blocks (re.DOTALL allows '.' to match newlines)
62
+ for match in re.finditer(image_block_pattern, full_output_text, re.DOTALL):
63
+ # 1. Add text preceding this image block
64
+ processed_text_parts.append(full_output_text[last_end:match.start()])
65
+
66
+ # collect the image token ids.
67
+ list_image_token_parts.append(full_output_text[match.start(): match.end()])
68
+
69
+ # 2. Add the placeholder for the image
70
+ processed_text_parts.append(image_placeholder)
71
+
72
+ # 3. Process the content *within* the current image block
73
+ image_token_content = match.group(1) # Content between tags
74
+ parsed_level_indices = {} # {level: [indices]} for *this* image
75
+
76
+ # Find all image tokens within this block
77
+ for token_match in re.finditer(token_pattern, image_token_content):
78
+ try:
79
+ level = int(token_match.group(1))
80
+ index = int(token_match.group(2))
81
+ if level >= num_levels:
82
+ logger.warning(f"Parsed token level {level} >= num_levels {num_levels}. Ignoring token.")
83
+ continue
84
+ if level not in parsed_level_indices:
85
+ parsed_level_indices[level] = []
86
+ parsed_level_indices[level].append(index)
87
+ except (ValueError, IndexError):
88
+ logger.warning(f"Could not parse token: {token_match.group(0)}")
89
+ continue # Skip malformed tokens
90
+
91
+ # Structure the indices for the current image based on expected levels
92
+ current_image_indices = []
93
+ logger.debug(f"Processing Image Block. Found levels: {parsed_level_indices.keys()}")
94
+ for level in range(num_levels):
95
+ # Get indices for the level, default to empty list if level not found
96
+ indices = parsed_level_indices.get(level, [])
97
+ # Optional: Sort indices if order isn't guaranteed (usually is by finditer)
98
+ # indices.sort()
99
+ current_image_indices.append(indices)
100
+ logger.debug(f" Level {level} indices count: {len(indices)}")
101
+
102
+ all_image_indices.append(current_image_indices)
103
+ logger.info(f"Parsed Image {len(all_image_indices)}: Found indices for {len(current_image_indices)} levels.")
104
+
105
+ # 4. Update position for the next iteration
106
+ last_end = match.end()
107
+
108
+ # Add any remaining text after the last image block
109
+ processed_text_parts.append(full_output_text[last_end:])
110
+
111
+ # Join the text parts to form the final generated text
112
+ generated_text = "".join(processed_text_parts)
113
+
114
+ return generated_text, all_image_indices, list_image_token_parts
115
+
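# Sketch (annotation, not part of the committed diff): parsing a tiny interleaved
# output string with one image block and two levels.
text = (
    "A cat: <start_of_image>"
    "<|image_level0_5|><|image_level0_6|><|image_level1_9|>"
    "<end_of_image> done."
)
gen_text, indices, raw_blocks = parse_interleaved_text_image(text, num_levels=2)
assert gen_text == "A cat: <image> done."
assert indices == [[[5, 6], [9]]]               # one image: level-0 then level-1 ids
assert raw_blocks[0].startswith("<start_of_image>")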
116
+
117
+ def calculate_image_token_num(h, w, downsample_rate_per_level=[28, 16]):
118
+ # Assuming RESOLUTION_MAPPING is accessible or hardcoded if needed
119
+ # For simplicity, let's assume direct calculation based on downsampling
120
+ # Replace with actual RESOLUTION_MAPPING logic if necessary
121
+ # Example: w1, h1 = RESOLUTION_MAPPING.get((w, h), (w, h)) # Get from mapping
122
+ w1, h1 = w, h # Placeholder if mapping not available/needed here
123
+ w1, h1 = w1 // downsample_rate_per_level[0], h1 // downsample_rate_per_level[0]
124
+ semantic_token_num = w1 * h1
125
+
126
+ w2, h2 = w // downsample_rate_per_level[1], h // downsample_rate_per_level[1]
127
+ pixel_token_num = w2 * h2
128
+ logger.info(f"Calculated token nums: semantic={semantic_token_num}, pixel={pixel_token_num} for target ({h},{w})")
129
+ # Estimate max_token_length (adjust based on special tokens in your format)
130
+ max_token_length = (h1 * (w1 + 1) + 2) + (h2 * (w2 + 1) + 2) + 2 + 2 + 1 + 1 + 50 # Add buffer
131
+ return [semantic_token_num, pixel_token_num], max_token_length, h1, w1, h2, w2
132
+
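# Sketch (annotation, not part of the committed diff): token-count arithmetic for a
# 512x512 target with the default per-level downsample rates [28, 16]:
#   level 0: 512 // 28 = 18 -> 18 * 18 = 324 semantic tokens
#   level 1: 512 // 16 = 32 -> 32 * 32 = 1024 pixel tokens
#   max length: (18 * 19 + 2) + (32 * 33 + 2) + 2 + 2 + 1 + 1 + 50 = 1458
token_nums, max_len, h1, w1, h2, w2 = calculate_image_token_num(512, 512)
assert token_nums == [324, 1024]
assert (h1, w1, h2, w2) == (18, 18, 32, 32)
assert max_len == 1458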
133
+
134
+ class InterleavedLogitsProcessor(LogitsProcessor):
135
+ """
136
+ Combines CFG, Dual VQ Image Token Structure Enforcement, and Dynamic Sampling
137
+ for interleaved text and image generation.
138
+
139
+ Includes refined masking during text generation to only allow text,
140
+ a specific resolution tag, and the start_of_image token.
141
+ """
142
+
143
+ def __init__(self,
144
+ # CFG parameters
145
+ guidance_scale=1.0,
146
+ uncond=None,
147
+ attention_mask=None,
148
+ model=None,
149
+ # DualVQ parameters
150
+ level0_range=None, level1_range=None,
151
+ num_level0_rows=None, num_level0_tokens=None,
152
+ num_level1_rows=None, num_level1_tokens=None,
153
+ special_tokens=None,
154
+ *,
155
+ # Dynamic Sampling parameters
156
+ default_temp=1.0, level0_temp=1.0, level1_temp=2.0,
157
+ default_top_k=2048, level0_top_k=2048, level1_top_k=2048 * 3,
158
+ default_top_p=0.8, level0_top_p=0.8, level1_top_p=1.0,
159
+ # General
160
+ images=None,
161
+ ):
162
+
163
+ # --- CFG ---
164
+ self.guidance_scale = guidance_scale
165
+ self.uncond = uncond
166
+ self.attention_mask = attention_mask
167
+ self.images = images
168
+ self.model = model
169
+ self.out = None
170
+
171
+ # --- DualVQ ---
172
+ self.level0_range = level0_range
173
+ self.level1_range = level1_range
174
+ self.num_level0_rows = num_level0_rows
175
+ self.num_level0_tokens = num_level0_tokens
176
+ self.num_level1_rows = num_level1_rows
177
+ self.num_level1_tokens = num_level1_tokens
178
+ self.special_tokens = special_tokens
179
+
180
+ # DualVQ State
181
+ self.generating_image = False
182
+ self.current_level = None
183
+ self.tokens_in_row = 0
184
+ self.rows_in_level = 0
185
+
186
+ # --- Dynamic Sampling ---
187
+ self.start_of_level0_token_id = special_tokens["start_of_level0"]
188
+ self.end_of_level0_token_id = special_tokens["end_of_level0"]
189
+ self.start_of_level1_token_id = special_tokens["start_of_level1"]
190
+ self.end_of_level1_token_id = special_tokens["end_of_level1"]
191
+ self.start_of_image_token_id = special_tokens["start_of_image"]
192
+ self.end_of_image_token_id = special_tokens["end_of_image"]
193
+
194
+ self.default_temp = default_temp
195
+ self.default_top_k = default_top_k
196
+ self.default_top_p = default_top_p
197
+ self.level0_temp = level0_temp
198
+ self.level0_top_k = level0_top_k
199
+ self.level0_top_p = level0_top_p
200
+ self.level1_temp = level1_temp
201
+ self.level1_top_k = level1_top_k
202
+ self.level1_top_p = level1_top_p
203
+
204
+ # Dynamic Sampling State
205
+ self.in_level0_mode = False
206
+ self.in_level1_mode = False
207
+
208
+ # --- Validation ---
209
+ if not self.special_tokens:
210
+ raise ValueError("special_tokens dictionary cannot be empty.")
211
+ # *** Updated required keys ***
212
+ required_keys = ["start_of_image", "end_of_image", "start_of_level0",
213
+ "end_of_level0", "start_of_level1", "end_of_level1",
214
+ "end_of_line", "end_of_text"]
215
+ for key in required_keys:
216
+ if key not in self.special_tokens:
217
+ raise ValueError(f"Missing required key in special_tokens: {key}")
218
+
219
+ def _apply_cfg(self, input_ids, scores):
220
+ """Applies Classifier-Free Guidance."""
221
+ scores = F.log_softmax(scores, dim=-1)
222
+ if self.guidance_scale == 1:
223
+ return scores
224
+
225
+ if self.out is None:
226
+ self.out = self.model(self.uncond,
227
+ attention_mask=self.attention_mask,
228
+ pixel_values=self.images)
229
+ else:
230
+ self.out = self.model(
231
+ input_ids[:, -1:],
232
+ use_cache=True,
233
+ past_key_values=self.out.past_key_values,
234
+ )
235
+ unconditional_logits = F.log_softmax(self.out.logits[:, -1, :], dim=-1)
236
+ out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
237
+ return out
238
+
239
+ def _apply_sampling(self, scores, temp, top_k, top_p):
240
+ """ Apply top-k, top-p, and temperature """
241
+ if temp > 0.0:
242
+ scores = scores / temp # Adjust temperature
243
+
244
+ # Top-K filtering
245
+ if top_k > 0:
246
+ top_k_values, _ = torch.topk(scores, min(top_k, scores.size(-1)))
247
+ scores[scores < top_k_values[:, -1].unsqueeze(-1)] = -float("Inf")
248
+
249
+ # Top-P filtering
250
+ if top_p < 1.0:
251
+ sorted_logits, sorted_indices = torch.sort(scores, descending=True)
252
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
253
+
254
+ # Only keep tokens with cumulative probabilities within top_p
255
+ sorted_indices_to_remove = cumulative_probs > top_p
256
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
257
+ sorted_indices_to_remove[:, 0] = False
258
+
259
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
260
+ scores[indices_to_remove] = -float("Inf")
261
+
262
+ return scores
263
+
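# Sketch (annotation, not part of the committed diff): effect of the top-k / top-p
# filter on a toy distribution. top_k=3 drops the smallest logit; top_p=0.8 then
# keeps the smallest prefix of sorted tokens whose cumulative probability reaches
# 0.8 (here the 0.5 and 0.3 tokens).
import torch

def demo_sampling(processor):  # an InterleavedLogitsProcessor instance
    logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
    out = processor._apply_sampling(logits, temp=1.0, top_k=3, top_p=0.8)
    return torch.isfinite(out)  # tensor([[ True,  True, False, False]])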
264
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
265
+
266
+ # --- Step 0: Get last token and Vocab Size ---
267
+ last_token = None
268
+ if input_ids.shape[1] > 0:
269
+ last_token = input_ids[0, -1].item() # Assuming batch size 1
270
+
271
+ # --- Step 1: Update State & Apply Constraints ---
272
+ # State updates based on the *last generated* token
273
+ if last_token == self.start_of_image_token_id:
274
+ self.generating_image = True
275
+ self.current_level = None
276
+ self.tokens_in_row = 0
277
+ self.rows_in_level = 0
278
+ self.in_level0_mode = False
279
+ self.in_level1_mode = False
280
+ elif last_token == self.start_of_level0_token_id:
281
+ self.current_level = "level0"
282
+ self.tokens_in_row = 0
283
+ self.rows_in_level = 0
284
+ self.in_level0_mode = True
285
+ self.in_level1_mode = False
286
+ elif last_token == self.end_of_level0_token_id:
287
+ self.current_level = None
288
+ self.in_level0_mode = False
289
+ elif last_token == self.start_of_level1_token_id:
290
+ self.current_level = "level1"
291
+ self.tokens_in_row = 0
292
+ self.rows_in_level = 0
293
+ self.in_level0_mode = False
294
+ self.in_level1_mode = True
295
+ elif last_token == self.end_of_level1_token_id:
296
+ self.current_level = None
297
+ self.in_level1_mode = False
298
+ elif last_token == self.end_of_image_token_id:
299
+ self.generating_image = False
300
+ self.current_level = None
301
+ self.tokens_in_row = 0
302
+ self.rows_in_level = 0
303
+ self.in_level0_mode = False
304
+ self.in_level1_mode = False
305
+ elif last_token == self.special_tokens["end_of_line"] and self.generating_image:
306
+ self.tokens_in_row = 0
307
+ self.rows_in_level += 1
308
+ elif self.generating_image and self.current_level is not None:
309
+ if (self.current_level == "level0" and self.level0_range[0] <= last_token < self.level0_range[1]) or \
310
+ (self.current_level == "level1" and self.level1_range[0] <= last_token < self.level1_range[1]):
311
+ self.tokens_in_row += 1
312
+
313
+ # --- Step 2: Apply CFG ---
314
+ if self.generating_image:
315
+ scores = self._apply_cfg(input_ids, scores)
316
+ else:
317
+ if self.out:
318
+ self.out = None
319
+
320
+ # Apply constraints based on the *current* state (determining the *next* token)
321
+ mask = torch.zeros_like(scores, dtype=torch.bool) # True means ALLOWED
322
+
323
+ if self.generating_image:
324
+ # --- Image Generation Masking ---
325
+ if self.current_level == "level0":
326
+ if self.rows_in_level == self.num_level0_rows:
327
+ mask[:, self.special_tokens["end_of_level0"]] = True
328
+ elif self.tokens_in_row == self.num_level0_tokens:
329
+ mask[:, self.special_tokens["end_of_line"]] = True
330
+ else:
331
+ mask[:, self.level0_range[0]:self.level0_range[1]] = True
332
+ elif self.current_level == "level1":
333
+ if self.rows_in_level == self.num_level1_rows:
334
+ mask[:, self.special_tokens["end_of_level1"]] = True
335
+ elif self.tokens_in_row == self.num_level1_tokens:
336
+ mask[:, self.special_tokens["end_of_line"]] = True
337
+ else:
338
+ mask[:, self.level1_range[0]:self.level1_range[1]] = True
339
+ else: # Between structure tokens
340
+ if last_token == self.start_of_image_token_id:
341
+ mask[:, self.special_tokens["start_of_level0"]] = True
342
+ elif last_token == self.end_of_level0_token_id:
343
+ mask[:, self.special_tokens["start_of_level1"]] = True
344
+ elif last_token == self.end_of_level1_token_id:
345
+ mask[:, self.special_tokens["end_of_image"]] = True
346
+ elif last_token is None and input_ids.shape[1] == 0: # Very first token is image?
347
+ mask[:, self.start_of_image_token_id] = True
348
+ else: # Allow relevant structural tokens if needed
349
+ mask[:, self.special_tokens["start_of_level0"]] = True
350
+ mask[:, self.special_tokens["start_of_level1"]] = True
351
+ mask[:, self.special_tokens["end_of_image"]] = True
352
+
353
+ else:
354
+ # Allow *all* tokens by default...
355
+ mask[:, :] = True
356
+ # ...then specifically *disallow* image content and intermediate structure tokens
357
+ mask[:, self.level0_range[0]:self.level0_range[1]] = False
358
+ mask[:, self.level1_range[0]:self.level1_range[1]] = False
359
+ mask[:, self.special_tokens["start_of_level0"]] = False
360
+ mask[:, self.special_tokens["end_of_level0"]] = False
361
+ mask[:, self.special_tokens["start_of_level1"]] = False
362
+ mask[:, self.special_tokens["end_of_level1"]] = False
363
+ mask[:, self.special_tokens["end_of_line"]] = False # EOL only allowed within image context
364
+
365
+ # Ensure the specific allowed tokens for text phase are indeed allowed
366
+ # (This overrides any potential disallowing above if IDs overlap, e.g., if EOS was in image range)
367
+ mask[:, self.special_tokens["end_of_text"]] = True
368
+ mask[:, self.special_tokens["start_of_image"]] = True
369
+
370
+ # Apply the mask
371
+ scores[~mask] = -float("Inf")
372
+
373
+ # Handle edge case: If all tokens are masked
374
+ if not torch.any(scores > -float("Inf"), dim=-1).all():
375
+ print("WARN: All tokens masked, allowing EOS.")
376
+ # Allow EOS and potentially other safe tokens if needed
377
+ scores[:] = -float("Inf") # Reset all to -inf first
378
+ scores[:, self.special_tokens["end_of_text"]] = 0
379
+
380
+ # --- Step 3: Apply Dynamic Sampling ---
381
+ current_temp, current_top_k, current_top_p = self.default_temp, self.default_top_k, self.default_top_p
382
+ if self.in_level0_mode:
383
+ current_temp, current_top_k, current_top_p = self.level0_temp, self.level0_top_k, self.level0_top_p
384
+ elif self.in_level1_mode:
385
+ current_temp, current_top_k, current_top_p = self.level1_temp, self.level1_top_k, self.level1_top_p
386
+
387
+ scores = self._apply_sampling(scores, current_temp, current_top_k, current_top_p)
388
+
389
+ return scores
390
+
391
+
392
+ def replace_placeholder_with_list(
393
+ tensor_a: torch.Tensor,
394
+ tensor_b: torch.Tensor,
395
+ placeholder_value: Union[int, float]
396
+ ) -> torch.Tensor:
397
+ if tensor_a.dim() != 1:
398
+ raise ValueError("Input tensor_a must be 1-dimensional.")
399
+
400
+ indices = torch.where(tensor_a == placeholder_value)[0]
401
+
402
+ if len(indices) == 0:
403
+ # Placeholder not found, return the original tensor
404
+ print(
405
+ f"Warning: Placeholder value {placeholder_value} not found in the tensor. Returning original tensor.")
406
+ return tensor_a
407
+
408
+ # Get the index of the *first* occurrence
409
+ idx = indices[0].item()
410
+
411
+ result_tensor = torch.cat((tensor_a[:idx], tensor_b.to(tensor_a), tensor_a[idx + 1:]), dim=0)
412
+ return result_tensor
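A minimal usage sketch for `replace_placeholder_with_list` (the ids below are made up for illustration): the first occurrence of the placeholder id in a 1-D prompt is spliced out and replaced by the whole image-token sequence.

```python
import torch

prompt = torch.tensor([101, 7, -200, 8, 102])         # -200 is an illustrative placeholder id
image_tokens = torch.tensor([9001, 9002, 9003])
merged = replace_placeholder_with_list(prompt, image_tokens, placeholder_value=-200)
print(merged)  # tensor([ 101,    7, 9001, 9002, 9003,    8,  102])
```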
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_dualvitok.py ADDED
@@ -0,0 +1,653 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import math
6
+ from typing import Optional, Tuple, Union, List, Callable
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn import Module
12
+
13
+ from einops import rearrange, repeat, pack, unpack
14
+ from einx import get_at
15
+
16
+ from torch.utils.checkpoint import checkpoint
17
+ from transformers import AutoImageProcessor
18
+ from transformers.modeling_utils import PreTrainedModel, get_parameter_device, get_parameter_dtype
19
+
20
+ from .configuration_dualvitok import DualViTokConfig
21
+ from .modeling_movqgan import MoVQModel, MoVQEncoder, MoVQDecoder, Decoder
22
+
23
+ from .configuration_qwen2vit import Qwen2VLVisionConfig
24
+ from .modeling_qwen2vit import Qwen2VisionTransformerPretrainedModel, \
25
+ VisionRotaryEmbedding, Qwen2VLBatchVisionBlock
26
+
27
+ try:
28
+ import xformers.ops as xops
29
+
30
+ is_xformers_available = True
31
+ except Exception as e:
32
+ is_xformers_available = False
33
+
34
+ if torch.__version__ > "2.1.2":
35
+ IS_SDPA_AVAILABLE = True
36
+ else:
37
+ IS_SDPA_AVAILABLE = False
38
+
39
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
40
+ sys.path.append(cur_dir)
41
+
42
+
43
+ # helper functions
44
+
45
+ def exists(v):
46
+ return v is not None
47
+
48
+
49
+ def identity(t):
50
+ return t
51
+
52
+
53
+ def default(v, d):
54
+ return v if exists(v) else d
55
+
56
+
57
+ def pack_one(t, pattern):
58
+ packed, packed_shape = pack([t], pattern)
59
+
60
+ def inverse(out, inv_pattern=None):
61
+ inv_pattern = default(inv_pattern, pattern)
62
+ out, = unpack(out, packed_shape, inv_pattern)
63
+ return out
64
+
65
+ return packed, inverse
66
+
67
+
68
+ # class
69
+
70
+
71
+ class SimVQ(Module):
72
+ def __init__(
73
+ self,
74
+ dim,
75
+ codebook_size,
76
+ codebook_transform: Module | None = None,
77
+ init_fn: Callable = identity,
78
+ channel_first=True,
79
+ input_to_quantize_commit_loss_weight=0.25,
80
+ commitment_weight=1.,
81
+ frozen_codebook_dim=None # frozen codebook dim could have different dimensions than projection
82
+ ):
83
+ super().__init__()
84
+ self.codebook_size = codebook_size
85
+ self.channel_first = channel_first
86
+
87
+ frozen_codebook_dim = default(frozen_codebook_dim, dim)
88
+ codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
89
+ codebook = init_fn(codebook)
90
+
91
+ # the codebook is actually implicit from a linear layer from frozen gaussian or uniform
92
+
93
+ if not exists(codebook_transform):
94
+ codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias=False)
95
+
96
+ self.code_transform = codebook_transform
97
+
98
+ self.register_buffer('frozen_codebook', codebook)
99
+
100
+ # commit loss weighting - weighing input to quantize a bit less is crucial for it to work
101
+ self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
102
+
103
+ # total commitment loss weight
104
+ self.commitment_weight = commitment_weight
105
+
106
+ @property
107
+ def codebook(self):
108
+ return self.code_transform(self.frozen_codebook)
109
+
110
+ def indices_to_codes(
111
+ self,
112
+ indices
113
+ ):
114
+ implicit_codebook = self.codebook
115
+
116
+ frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
117
+ quantized = self.code_transform(frozen_codes)
118
+
119
+ if self.channel_first:
120
+ quantized = rearrange(quantized, 'b ... d -> b d ...')
121
+
122
+ return quantized
123
+
124
+ def forward(
125
+ self,
126
+ x
127
+ ):
128
+ if self.channel_first:
129
+ x = rearrange(x, 'b d ... -> b ... d')
130
+
131
+ x, inverse_pack = pack_one(x, 'b * d')
132
+
133
+ implicit_codebook = self.codebook
134
+
135
+ with torch.no_grad():
136
+ dist = torch.cdist(x, implicit_codebook)
137
+ indices = dist.argmin(dim=-1)
138
+
139
+ # select codes
140
+
141
+ quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
142
+
143
+ # commit loss and straight through, as was done in the paper
144
+
145
+ commit_loss = (
146
+ F.mse_loss(x.detach(), quantized) +
147
+ F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
148
+ )
149
+
150
+ quantized = (quantized - x).detach() + x
151
+
152
+ quantized = inverse_pack(quantized)
153
+ indices = inverse_pack(indices, 'b *')
154
+
155
+ if self.channel_first:
156
+ quantized = rearrange(quantized, 'b ... d-> b d ...')
157
+
158
+ return quantized, commit_loss * self.commitment_weight, indices
159
+
160
+
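A minimal sketch of quantizing a small channel-first feature map with `SimVQ`; the shapes are illustrative only.

```python
import torch

vq = SimVQ(dim=32, codebook_size=1024)                 # channel_first=True by default
feats = torch.randn(2, 32, 8, 8)                       # (B, C, H, W)
quantized, commit_loss, indices = vq(feats)
print(quantized.shape, indices.shape)                  # torch.Size([2, 32, 8, 8]) torch.Size([2, 8, 8])
```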
161
+ def init_weights(m):
162
+ if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
163
+ if m.weight is not None:
164
+ nn.init.constant_(m.weight, 1)
165
+ if m.bias is not None:
166
+ nn.init.constant_(m.bias, 0)
167
+ elif isinstance(m, nn.Linear):
168
+ nn.init.xavier_uniform_(m.weight)
169
+ if m.bias is not None:
170
+ nn.init.constant_(m.bias, 0)
171
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) \
172
+ or isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
173
+ w = m.weight.data
174
+ nn.init.xavier_uniform_(w)
175
+ if m.bias is not None:
176
+ nn.init.constant_(m.bias, 0)
177
+ elif isinstance(m, nn.Embedding):
178
+ nn.init.normal_(m.weight, mean=0, std=1)
179
+
180
+
181
+ class ScalingLayerForQwen2ViT:
182
+ def __init__(
183
+ self,
184
+ min_pixels: int = 56 * 56,
185
+ max_pixels: int = 28 * 28 * 1280,
186
+ patch_size: int = 14,
187
+ temporal_patch_size: int = 2,
188
+ merge_size: int = 2,
189
+ **kwargs,
190
+ ) -> None:
191
+ super().__init__(**kwargs)
192
+ OPENAI_CLIP_MEAN = torch.as_tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]
193
+ OPENAI_CLIP_STD = torch.as_tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]
194
+
195
+ self.image_mean = OPENAI_CLIP_MEAN
196
+ self.image_std = OPENAI_CLIP_STD
197
+ self.min_pixels = min_pixels
198
+ self.max_pixels = max_pixels
199
+ self.patch_size = patch_size
200
+ self.temporal_patch_size = temporal_patch_size
201
+ self.merge_size = merge_size
202
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
203
+
204
+ def __call__(self, images):
205
+ if images.ndim == 4:
206
+ images = images.unsqueeze(1)
207
+ batch_size, temporal, channel, height, width = images.shape
208
+
209
+ factor = self.patch_size * self.merge_size
210
+
211
+ resized_height, resized_width = height // factor * factor, width // factor * factor
212
+
213
+ images = (images + 1) / 2 # rescale to [0, 1.]
214
+
215
+ images = torch.nn.functional.interpolate(
216
+ images.flatten(0, 1).float(),
217
+ size=(resized_height, resized_width),
218
+ mode='bicubic',
219
+ align_corners=False,
220
+ antialias=True
221
+ ).to(images.dtype)
222
+
223
+ images = images.clamp(0, 1) # rescale to [0, 1.]
224
+ images = ((images - self.image_mean.to(images)) / self.image_std.to(images))
225
+
226
+ images = rearrange(images, '(b t) c h w -> b t c h w', b=batch_size, t=temporal)
227
+ if temporal == 1:
228
+ images = images.repeat_interleave(self.temporal_patch_size, dim=1)
229
+ temporal = self.temporal_patch_size
230
+
231
+ grid_t = temporal // self.temporal_patch_size
232
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
233
+
234
+ images = images.reshape(
235
+ batch_size * grid_t,
236
+ self.temporal_patch_size,
237
+ channel,
238
+ -1
239
+ )
240
+
241
+ images = rearrange(images, 'b p c n -> b n (c p)')
242
+ images = images.reshape(
243
+ batch_size * grid_t,
244
+ grid_h // self.merge_size,
245
+ self.merge_size,
246
+ self.patch_size,
247
+ grid_w // self.merge_size,
248
+ self.merge_size,
249
+ self.patch_size,
250
+ -1
251
+ )
252
+ images = rearrange(images, 'b h k s1 w l s2 n -> (b h w k l) (n s1 s2)')
253
+
254
+ return dict(image=images, image_grid_thw=torch.as_tensor([[grid_t, grid_h, grid_w] for _ in range(batch_size)]))
255
+
256
+
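Worked arithmetic for the resizing above, assuming a single 512×512 frame and the default patch settings:

```python
patch_size, merge_size, temporal_patch_size, channels = 14, 2, 2, 3
factor = patch_size * merge_size                # 28
resized = 512 // factor * factor                # 504: nearest multiple of 28 below 512
grid_h = grid_w = resized // patch_size         # 36
patch_dim = channels * temporal_patch_size * patch_size * patch_size  # 1176 values per flattened patch
print(grid_h, grid_w, patch_dim)                # 36 36 1176
```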
257
+ class SemanticEncoder(nn.Module):
258
+ def __init__(self,
259
+ semantic_encoder,
260
+ z_channels=4,
261
+ num_blocks=2,
262
+ embed_dim=1280,
263
+ proj_layer='linear',
264
+ attn_implementation='xformers',
265
+ target_mlp='identity',
266
+ ):
267
+ super().__init__()
268
+ self.embed_dim = embed_dim
269
+
270
+ if isinstance(semantic_encoder, str):
271
+ self.model = Qwen2VisionTransformerPretrainedModel.from_pretrained(
272
+ semantic_encoder,
273
+ attn_implementation=attn_implementation
274
+ )
275
+ elif isinstance(semantic_encoder, dict):
276
+ config = Qwen2VLVisionConfig(**semantic_encoder, attn_implementation=attn_implementation)
277
+ self.model = Qwen2VisionTransformerPretrainedModel(config)
278
+ else:
279
+ raise ValueError(f"Invalid semantic_encoder: {semantic_encoder}")
280
+ input_channels = self.model.config.hidden_size
281
+
282
+ for p in self.model.parameters():
283
+ p.requires_grad = False
284
+
285
+ self.proj_in = nn.Conv2d(input_channels, embed_dim, 1, 1) if input_channels != embed_dim else nn.Identity()
286
+
287
+ config = Qwen2VLVisionConfig(depth=num_blocks,
288
+ embed_dim=embed_dim, )
289
+ head_dim = config.embed_dim // config.num_heads
290
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
291
+
292
+ self.blocks = nn.ModuleList(
293
+ [Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
294
+ )
295
+
296
+ if proj_layer == 'norm_linear':
297
+ self.proj_out = nn.Sequential(
298
+ nn.LayerNorm(embed_dim),
299
+ nn.Linear(
300
+ embed_dim,
301
+ z_channels,
302
+ )
303
+ )
304
+ elif proj_layer == 'linear':
305
+ self.proj_out = nn.Sequential(
306
+ nn.Linear(
307
+ embed_dim,
308
+ z_channels,
309
+ )
310
+ )
311
+ elif proj_layer == 'mlp':
312
+ self.proj_out = nn.Sequential(
313
+ nn.Linear(embed_dim, embed_dim),
314
+ nn.Tanh(),
315
+ nn.Linear(embed_dim, z_channels),
316
+ )
317
+ else:
318
+ raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
319
+
320
+ if target_mlp == 'identity':
321
+ self.target_mlp = nn.Sequential(
322
+ nn.Identity(),
323
+ )
324
+ elif target_mlp == 'norm':
325
+ self.target_mlp = nn.Sequential(
326
+ nn.LayerNorm(input_channels, eps=1e-6, elementwise_affine=False),
327
+ )
328
+ self.init_weight()
329
+
330
+ def init_weight(self):
331
+ self.proj_in.apply(init_weights)
332
+ self.blocks.apply(init_weights)
333
+ self.proj_out.apply(init_weights)
334
+ self.target_mlp.apply(init_weights)
335
+
336
+ def rot_pos_emb(self, grid_thw, max_seq_len):
337
+ pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
338
+ for idx, (t, h, w) in enumerate(grid_thw):
339
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
340
+ hpos_ids = hpos_ids.flatten()
341
+
342
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
343
+ wpos_ids = wpos_ids.flatten()
344
+
345
+ current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
346
+ pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
347
+ max_grid_size = grid_thw[:, 1:].max()
348
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
349
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
350
+ return rotary_pos_emb
351
+
352
+ def forward(self, x, grid_thw):
353
+ x = self.model(x, grid_thw=grid_thw)
354
+
355
+ x = x_target = self.target_mlp(x)
356
+
357
+ x = F.linear(x,
358
+ self.proj_in.weight.view(self.proj_in.weight.shape[0], -1),
359
+ self.proj_in.bias)
360
+
361
+ new_grid_thw = torch.as_tensor([[t, h // 2, w // 2] for t, h, w in grid_thw])
362
+
363
+ seq_lens = [t_i * h_i * w_i for t_i, h_i, w_i in new_grid_thw]
364
+ max_seq_len = max(seq_lens)
365
+
366
+ x = rearrange(x, '(b h w) c -> b (h w) c', h=new_grid_thw[0, 1], w=new_grid_thw[0, 2])
367
+
368
+ rotary_pos_emb = self.rot_pos_emb(new_grid_thw, max_seq_len)
369
+
370
+ for blk in self.blocks:
371
+ x = blk(x, rotary_pos_emb=rotary_pos_emb)
372
+
373
+ x = self.proj_out(x) # [b, max_length, d]
374
+
375
+ t, h, w = new_grid_thw[0]
376
+ b = len(grid_thw)
377
+ x = rearrange(x, 'b (h w) c ->b c h w', b=b, h=h, w=w)
378
+ x_target = rearrange(x_target, '(b h w) c ->b c h w', b=b, h=h, w=w)
379
+ return x, x_target
380
+
381
+
382
+ class SemanticDecoder(nn.Module):
383
+ def __init__(self,
384
+ z_channels=4,
385
+ embed_dim=1280,
386
+ num_blocks=2,
387
+ output_channels=1280,
388
+ attn_implementation='xformers',
389
+ proj_layer='linear_norm'):
390
+ super().__init__()
391
+ self.proj_in = nn.Linear(z_channels, embed_dim)
392
+
393
+ self.output_channels = output_channels
394
+ config = Qwen2VLVisionConfig(depth=num_blocks, embed_dim=embed_dim)
395
+
396
+ self.blocks = nn.ModuleList(
397
+ [Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
398
+ )
399
+ head_dim = config.embed_dim // config.num_heads
400
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
401
+
402
+ if proj_layer == 'norm_linear':
403
+ self.proj_out = nn.Sequential(
404
+ nn.LayerNorm(embed_dim),
405
+ nn.Linear(embed_dim, output_channels),
406
+ )
407
+ elif proj_layer == 'linear':
408
+ self.proj_out = nn.Sequential(
409
+ nn.Linear(embed_dim, output_channels)
410
+ )
411
+ elif proj_layer == 'mlp':
412
+ self.proj_out = nn.Sequential(
413
+ nn.Linear(embed_dim, embed_dim),
414
+ nn.Tanh(),
415
+ nn.Linear(embed_dim, output_channels),
416
+ )
417
+ elif proj_layer == 'linear_norm':
418
+ self.proj_out = nn.Sequential(
419
+ nn.Linear(embed_dim, output_channels),
420
+ nn.LayerNorm(output_channels),
421
+ )
422
+
423
+ self.apply(init_weights)
424
+
425
+ @property
426
+ def last_layer(self):
427
+ return self.proj_out[-1].weight
428
+
429
+ def rot_pos_emb(self, grid_thw, max_seq_len):
430
+ pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
431
+ for idx, (t, h, w) in enumerate(grid_thw):
432
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
433
+ hpos_ids = hpos_ids.flatten()
434
+
435
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
436
+ wpos_ids = wpos_ids.flatten()
437
+
438
+ current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
439
+ pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
440
+ max_grid_size = grid_thw[:, 1:].max()
441
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
442
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
443
+ return rotary_pos_emb
444
+
445
+ def forward(self, z: torch.Tensor):
446
+ x = z
447
+
448
+ b, c, h, w = x.shape
449
+
450
+ x = rearrange(x, 'b c h w -> b (h w) c')
451
+
452
+ grid_thw = torch.as_tensor([[1, h, w] for _ in range(b)])
453
+ seq_lens = [t * h * w for t, h, w in grid_thw]
454
+ max_seq_len = max(seq_lens)
455
+
456
+ x = self.proj_in(x)
457
+
458
+ rotary_pos_emb = self.rot_pos_emb(grid_thw, max_seq_len)
459
+
460
+ for blk in self.blocks:
461
+ x = blk(x, rotary_pos_emb=rotary_pos_emb)
462
+
463
+ x = self.proj_out(x)
464
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
465
+ return x
466
+
467
+
468
+ class DualViTokPretrainModel(PreTrainedModel):
469
+ """
470
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
471
+ models.
472
+ """
473
+
474
+ config_class = DualViTokConfig
475
+ base_model_prefix = "dualvitok"
476
+ main_input_name = "pixel_values"
477
+ _no_split_modules = ["BatchQwen2VLVisionBlock", "MoVQResnetBlock", "MoVQAttnBlock", "MoVQResnetTemporalBlock"]
478
+ _supports_flash_attn_2 = True
479
+ _supports_sdpa = True
480
+ _supports_cache_class = True
481
+ _supports_static_cache = True
482
+
483
+ def _init_weights(self, module):
484
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)):
485
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
486
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
487
+ elif isinstance(module, nn.Linear):
488
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
489
+ if module.bias is not None:
490
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
491
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
492
+ nn.init.uniform_(module.bias, -bound, bound)
493
+ elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
494
+ nn.init.constant_(module.weight, 1)
495
+ nn.init.constant_(module.bias, 0)
496
+
497
+
498
+ class DualViTok(DualViTokPretrainModel):
499
+ def __init__(self, config: DualViTokConfig):
500
+ super().__init__(config)
501
+ self.config = config
502
+
503
+ self._semantic_channel = config.semantic_encoder.z_channels
504
+ self._pixel_channel = config.pixel_encoder.z_channels
505
+
506
+ self.semantic_encoder = SemanticEncoder(
507
+ semantic_encoder=config.semantic_encoder.pretrained_semantic_encoder,
508
+ z_channels=config.semantic_encoder.z_channels,
509
+ num_blocks=config.semantic_encoder.num_blocks,
510
+ embed_dim=config.semantic_encoder.embed_dim,
511
+ proj_layer=config.semantic_encoder.out_layer,
512
+ attn_implementation=config.attn_implementation,
513
+ target_mlp=config.semantic_encoder.target_mlp, )
514
+ self.semantic_decoder = SemanticDecoder(
515
+ z_channels=config.semantic_decoder.z_channels,
516
+ embed_dim=config.semantic_decoder.embed_dim,
517
+ num_blocks=config.semantic_decoder.num_blocks,
518
+ output_channels=config.semantic_decoder.out_channels,
519
+ attn_implementation=config.attn_implementation,
520
+ proj_layer=config.semantic_decoder.out_layer,
521
+ )
522
+
523
+ if config.semantic_quantizer_type.lower() == 'simvq':
524
+ self.semantic_quantizer = SimVQ(
525
+ dim=config.semantic_encoder.z_channels,
526
+ codebook_size=config.semantic_quantizer_codebook_size,
527
+ )
528
+ elif config.semantic_quantizer_type.lower() == 'vq':
529
+ raise NotImplementedError("The 'vq' semantic quantizer is not implemented; use 'simvq'.")
530
+ self.semantic_quantizer = VQ(
531
+ dim=config.semantic_encoder.z_channels,
532
+ codebook_size=config.semantic_quantizer_codebook_size,
533
+ )
534
+
535
+ self.pixel_encoder = MoVQEncoder(config.pixel_encoder)
536
+ self.pixel_quant_conv = nn.Conv2d(config.pixel_encoder.z_channels, config.pixel_encoder.embed_dim, 1)
537
+
538
+ if config.pixel_quantizer_type.lower() == 'simvq':
539
+ self.pixel_quantizer = SimVQ(
540
+ dim=config.pixel_encoder.z_channels,
541
+ codebook_size=config.pixel_quantizer_codebook_size,
542
+ )
543
+ elif config.pixel_quantizer_type.lower() == 'vq':
544
+ raise NotImplementedError("The 'vq' pixel quantizer is not implemented; use 'simvq'.")
545
+ self.pixel_quantizer = VQ(
546
+ dim=config.pixel_encoder.z_channels,
547
+ codebook_size=config.pixel_quantizer_codebook_size,
548
+ )
549
+
550
+ self.pixel_post_quant_conv = nn.Conv2d(config.pixel_decoder.embed_dim,
551
+ config.pixel_decoder.z_channels, 1)
552
+
553
+ self.pixel_decoder = MoVQDecoder(config.pixel_decoder)
554
+
555
+ self.scaling_layer = ScalingLayerForQwen2ViT()
556
+
557
+ @property
558
+ def device(self):
559
+ return get_parameter_device(self)
560
+
561
+ @property
562
+ def dtype(self):
563
+ return get_parameter_dtype(self)
564
+
565
+ @property
566
+ def pixel_channel(self):
567
+ return self._pixel_channel
568
+
569
+ @property
570
+ def semantic_channel(self):
571
+ return self._semantic_channel
572
+
573
+ def encode(self, image: torch.FloatTensor):
574
+ scale_output = self.scaling_layer(image)
575
+ image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
576
+
577
+ h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
578
+ quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
579
+
580
+ h_pixel = self.pixel_encoder(image_gen)
581
+ h_pixel = self.pixel_quant_conv(h_pixel)
582
+
583
+ quant_pixel, emb_loss_pixel, info_pixel = self.pixel_quantizer(h_pixel.float())
584
+
585
+ return (quant_semantic, emb_loss_semantic, info_semantic, target_semantic), \
586
+ (quant_pixel, emb_loss_pixel, info_pixel)
587
+
588
+ def encode_code(self, *args, **kwargs):
589
+ (_, _, semantic_indices, _), \
590
+ (_, _, pixel_indices) = self.encode(*args, **kwargs)
591
+ return semantic_indices, pixel_indices
592
+
593
+ def indices_to_codes(self, semantic_indices, pixel_indices):
594
+ quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
595
+ quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
596
+ return quant_semantic, quant_pixel
597
+
598
+ def encode_semantic(self, image: torch.FloatTensor):
599
+ scale_output = self.scaling_layer(image)
600
+ image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
601
+
602
+ h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
603
+ quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
604
+ return quant_semantic, emb_loss_semantic, info_semantic, target_semantic
605
+
606
+ def merge_quants(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
607
+ quant_semantic_resized = F.interpolate(
608
+ quant_semantic, quant_pixel.shape[-2:], mode='bicubic'
609
+ ).to(quant_semantic.dtype)
610
+ quant_semantic = quant_semantic_resized
611
+
612
+ quant = torch.cat([quant_semantic, quant_pixel], dim=1)
613
+
614
+ return quant
615
+
616
+ def decode(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor, ):
617
+ quant = self.merge_quants(quant_semantic, quant_pixel)
618
+ quant2 = self.pixel_post_quant_conv(quant)
619
+ x = self.pixel_decoder(quant2, quant)
620
+ return x
621
+
622
+ def decode_code(self, semantic_indices, pixel_indices):
623
+ quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
624
+ quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
625
+ return self.decode(quant_semantic, quant_pixel)
626
+
627
+ def decode_semantic(self, x: List[torch.Tensor]):
628
+ return self.semantic_decoder(x)
629
+
630
+ def forward(self, pixel_values: torch.FloatTensor):
631
+ (quant_semantic, diff_semantic, _, target_semantic), \
632
+ (quant_pixel, diff_pixel, _) = self.encode(pixel_values)
633
+ dec = self.decode(quant_semantic, quant_pixel)
634
+ dec_semantic = self.decode_semantic(quant_semantic)
635
+ return (dec_semantic, diff_semantic, target_semantic), (dec, diff_pixel)
636
+
637
+ def build_sdxl_decoder(self, path='ILLUME-MLLM/dualvitok-sdxl-decoder',
638
+ image_processor=None,
639
+ torch_dtype=torch.float16,
640
+ add_watermarker=False,
641
+ device='cuda',
642
+ ):
643
+ from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
644
+
645
+ if image_processor is None:
646
+ image_processor = AutoImageProcessor.from_pretrained('ILLUME-MLLM/dualvitok', trust_remote_code=True)
647
+
648
+ return StableDiffusionXLDecoderPipeline.from_pretrained(path,
649
+ torch_dtype=torch_dtype,
650
+ add_watermarker=add_watermarker,
651
+ vq_model=self,
652
+ vq_image_processor=image_processor).to(device)
653
+
modeling_illume.py ADDED
@@ -0,0 +1,883 @@
1
+ """PyTorch ILLUME model."""
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from functools import partial
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+
12
+ from transformers import PreTrainedModel
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache
15
+ from transformers.modeling_outputs import ModelOutput
16
+ from transformers.utils import (
17
+ add_start_docstrings,
18
+ add_start_docstrings_to_model_forward,
19
+ logging,
20
+ replace_return_docstrings,
21
+ )
22
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
23
+ from transformers import LogitsProcessorList
24
+
25
+ from .configuration_illume import ILLUMEConfig
26
+ from .modeling_qwen2vit import Qwen2VisionTransformerPretrainedModel
27
+ from .modeling_dualvitok import ScalingLayerForQwen2ViT, SemanticEncoder
28
+ from .modeling_movqgan import MoVQEncoder
29
+ from .inference_utils import InterleavedLogitsProcessor, \
30
+ parse_interleaved_text_image, calculate_image_token_num
31
+
32
+
33
+ from einops import rearrange
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+ _CONFIG_FOR_DOC = "ILLUMEConfig"
38
+
39
+ # Define common resolutions
40
+ DEFAULT_RESOLUTIONS = [
41
+ (256, 256), (512, 512), (384, 640), (640, 384), (512, 384),
42
+ (384, 512), (256, 384), (384, 256), (256, 512), (512, 256)
43
+ ]
44
+
45
+ # qwen2.5
46
+ special_tokens_ids = [151665, 151666, 151667, 151668, 151669, 151670, 151671]
47
+ start_token = 151672 + 32
48
+ level0_range = (start_token, start_token + 32768) # Level 0 token ID range
49
+ level1_range = (start_token + 32768, start_token + 32768 * 4) # Level 1 token ID range
50
+
51
+ special_tokens_dict = {
52
+ "start_of_image": 151665,
53
+ "end_of_image": 151666,
54
+ "start_of_level0": 151668,
55
+ "end_of_level0": 151669,
56
+ "start_of_level1": 151670,
57
+ "end_of_level1": 151671,
58
+ "end_of_line": 151667,
59
+ "end_of_text": 151645,
60
+ #
61
+ "level0_range": level0_range,
62
+ "level1_range": level1_range,
63
+ }
64
+
65
+
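For reference, the ranges defined above evaluate to the following concrete ids (pure arithmetic on the constants):

```python
start_token = 151672 + 32           # 151704
level0_range = (151704, 184472)     # start_token .. start_token + 32768
level1_range = (184472, 282776)     # start_token + 32768 .. start_token + 4 * 32768
```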
66
+ @dataclass
67
+ class ILLUMECausalLMOutputWithPast(ModelOutput):
68
+ """
69
+ Base class for ILLUME causal language model (or autoregressive) outputs.
70
+
71
+ Args:
72
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
73
+ Language modeling loss (for next-token prediction).
74
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
75
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
76
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
77
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
78
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
79
+
80
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
81
+ `past_key_values` input) to speed up sequential decoding.
82
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
83
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
84
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
85
+
86
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
87
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
88
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
89
+ sequence_length)`.
90
+
91
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
92
+ heads.
93
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
94
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
95
+ sequence_length, hidden_size)`.
96
+
97
+ image_hidden_states of the model produced by the vision encoder.
98
+ """
99
+
100
+ loss: Optional[torch.FloatTensor] = None
101
+ logits: torch.FloatTensor = None
102
+ past_key_values: Optional[List[torch.FloatTensor]] = None
103
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
104
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
105
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
106
+
107
+
108
+ class MLPProjector(nn.Sequential):
109
+ # CAbstractor
110
+ def __init__(self, mlp_depth, hidden_size, mm_hidden_size):
111
+ super(MLPProjector, self).__init__()
112
+ modules = [nn.Linear(mm_hidden_size, hidden_size)]
113
+ for _ in range(1, mlp_depth):
114
+ modules.append(nn.GELU())
115
+ modules.append(nn.Linear(hidden_size, hidden_size))
116
+ super(MLPProjector, self).__init__(*modules)
117
+
118
+
119
+ class ILLUMEMultiModalProjector(nn.Sequential):
120
+ # CAbstractor
121
+ def __init__(self, config):
122
+ super(ILLUMEMultiModalProjector, self).__init__()
123
+ hidden_size = config.text_config.hidden_size
124
+ mm_hidden_size1, mm_hidden_size2 = config.mm_projector_config['mm_hidden_size']
125
+ self.projector_1 = MLPProjector(mlp_depth=config.mm_projector_config['projector_cfg1']['mlp_depth'],
126
+ mm_hidden_size=mm_hidden_size1, hidden_size=hidden_size)
127
+ self.projector_2 = MLPProjector(mlp_depth=config.mm_projector_config['projector_cfg2']['mlp_depth'],
128
+ mm_hidden_size=mm_hidden_size2, hidden_size=hidden_size)
129
+
130
+ def forward(self, image_features):
131
+ image_feature_1, image_feature_2 = image_features
132
+ image_feature_1 = self.projector_1(image_feature_1)
133
+ image_feature_2 = self.projector_2(image_feature_2)
134
+ image_features = torch.concat([image_feature_1, image_feature_2], dim=1)
135
+ return image_features
136
+
137
+
138
+ class ILLUMEDualVisionTower(nn.Module):
139
+ def __init__(self,
140
+ vision_config,
141
+ attn_implementation='sdpa',
142
+ ):
143
+ super().__init__()
144
+ self._config = vision_config
145
+ self.semantic_encoder = SemanticEncoder(
146
+ semantic_encoder=vision_config.semantic_encoder.pretrained_semantic_encoder,
147
+ z_channels=vision_config.semantic_encoder.z_channels,
148
+ num_blocks=vision_config.semantic_encoder.num_blocks,
149
+ embed_dim=vision_config.semantic_encoder.embed_dim,
150
+ proj_layer=vision_config.semantic_encoder.out_layer,
151
+ attn_implementation=attn_implementation,
152
+ target_mlp=vision_config.semantic_encoder.target_mlp, ).model
153
+ self.pixel_encoder = MoVQEncoder(vision_config.pixel_encoder)
154
+ self.scaling_layer = ScalingLayerForQwen2ViT()
155
+
156
+ def forward(self, images):
157
+ if isinstance(images, list) and all(x is not None and x.shape == images[0].shape for x in images):
158
+ images = torch.concat(images, dim=0)
159
+ images = images.to(device=self.device, dtype=self.dtype)
160
+ else:
161
+ images = [image.to(device=self.device, dtype=self.dtype) for image in images]
162
+
163
+ image_feature_shape_pixels, image_feature_shape_semantics = [], []
164
+ if isinstance(images, list): # anyres setting
165
+ h_pixels = []
166
+ for image in images:
167
+ if image.ndim == 3:
168
+ image = image.unsqueeze(0)
169
+ h_pixel = self.pixel_encoder(image)
170
+ b, c, h, w = h_pixel.shape
171
+ image_feature_shape_pixels.append((h, w))
172
+ h_pixel = rearrange(h_pixel, 'b c h w -> b (h w) c')
173
+ h_pixels.append(h_pixel)
174
+ h_pixels = torch.cat(h_pixels, dim=1)
175
+
176
+ h_semantics = []
177
+ for image in images:
178
+ if image.ndim == 3:
179
+ image = image.unsqueeze(0)
180
+ image = image.unsqueeze(dim=1)
181
+ scale_output = self.scaling_layer(image.clone())
182
+ image_2, image_grid_thw = scale_output['image'], scale_output['image_grid_thw']
183
+ image_feature_shape_semantics.append((int(image_grid_thw[0][1]) // 2, int(image_grid_thw[0][2] // 2)))
184
+ h_semantic = self.semantic_encoder(image_2, image_grid_thw)
185
+ h_semantics.append(h_semantic)
186
+ h_semantics = torch.cat(h_semantics, dim=0)
187
+ h_semantics = h_semantics.unsqueeze(dim=0)
188
+
189
+ image_feature_shapes = [[shape_semantic, shape_pixel] for shape_semantic, shape_pixel in
190
+ zip(image_feature_shape_semantics, image_feature_shape_pixels)]
191
+
192
+ else: # fixed res setting
193
+ assert images.ndim == 4
194
+ h_pixels = self.pixel_encoder(images)
195
+ b, c, h, w = h_pixels.shape
196
+ h_pixels = rearrange(h_pixels, 'b c h w -> (b h w) c')
197
+ h_pixels = h_pixels.unsqueeze(dim=0)
198
+
199
+ images = images.unsqueeze(dim=1)
200
+ scale_output = self.scaling_layer(images.clone())
201
+ images_2, images_grid_thw = scale_output['image'], scale_output['image_grid_thw']
202
+
203
+ h_semantics = self.semantic_encoder(images_2, images_grid_thw)
204
+ h_semantics = h_semantics.unsqueeze(dim=0)
205
+
206
+ shape_semantic = (int(images_grid_thw[0][1]) // 2, int(images_grid_thw[0][2] // 2))
207
+ shape_pixel = (h, w)
208
+ image_feature_shapes = [[shape_semantic, shape_pixel] for i in range(b)]
209
+
210
+ return [h_semantics, h_pixels], image_feature_shapes
211
+
212
+ @property
213
+ def dtype(self):
214
+ return self.semantic_encoder.dtype
215
+
216
+ @property
217
+ def device(self):
218
+ return self.semantic_encoder.device
219
+
220
+ @property
221
+ def config(self):
222
+ return self._config
223
+
224
+ @property
225
+ def hidden_size(self):
226
+ return self.config.hidden_size
227
+
228
+
229
+ class ILLUMEPreTrainedModel(PreTrainedModel):
230
+ config_class = ILLUMEConfig
231
+ base_model_prefix = "model"
232
+ supports_gradient_checkpointing = True
233
+ _no_split_modules = ["ILLUMEVisionAttention"]
234
+ _skip_keys_device_placement = "past_key_values"
235
+ _supports_flash_attn_2 = True
236
+ _supports_cache_class = True
237
+
238
+ def _init_weights(self, module):
239
+ std = (
240
+ self.config.initializer_range
241
+ if hasattr(self.config, "initializer_range")
242
+ else self.config.text_config.initializer_range
243
+ )
244
+
245
+ if hasattr(module, "class_embedding"):
246
+ module.class_embedding.data.normal_(mean=0.0, std=std)
247
+
248
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
249
+ module.weight.data.normal_(mean=0.0, std=std)
250
+ if module.bias is not None:
251
+ module.bias.data.zero_()
252
+ elif isinstance(module, nn.Embedding):
253
+ module.weight.data.normal_(mean=0.0, std=std)
254
+ if module.padding_idx is not None:
255
+ module.weight.data[module.padding_idx].zero_()
256
+
257
+ @property
258
+ def _supports_sdpa(self):
259
+ """
260
+ Retrieve language_model's attribute to check whether the model supports
261
+ SDPA or not.
262
+ """
263
+ return self.language_model._supports_sdpa
264
+
265
+
266
+ class ILLUMEForConditionalGeneration(ILLUMEPreTrainedModel):
267
+ def __init__(self, config: ILLUMEConfig, **kwargs):
268
+ super().__init__(config)
269
+ self.vision_tower = ILLUMEDualVisionTower(config.vision_config)
270
+ self.mm_projector = ILLUMEMultiModalProjector(config)
271
+
272
+ self.vocab_size = config.text_config.vocab_size
273
+ self.language_model = AutoModelForCausalLM.from_config(
274
+ config.text_config, attn_implementation=config._attn_implementation
275
+ )
276
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
277
+ self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
278
+ self.post_init()
279
+
280
+ @property
281
+ def padding_side(self):
282
+ return self._padding_side
283
+
284
+ @padding_side.setter
285
+ def padding_side(self, padding_side: str):
286
+ if padding_side not in ["left", "right"]:
287
+ raise ValueError(f"{padding_side} is not `left` or `right`.")
288
+ self._padding_side = padding_side
289
+
290
+ def get_input_embeddings(self):
291
+ return self.language_model.get_input_embeddings()
292
+
293
+ def set_input_embeddings(self, value):
294
+ self.language_model.set_input_embeddings(value)
295
+
296
+ def get_output_embeddings(self):
297
+ return self.language_model.get_output_embeddings()
298
+
299
+ def set_output_embeddings(self, new_embeddings):
300
+ self.language_model.set_output_embeddings(new_embeddings)
301
+
302
+ def set_decoder(self, decoder):
303
+ self.language_model.set_decoder(decoder)
304
+
305
+ def get_decoder(self):
306
+ return self.language_model.get_decoder()
307
+
308
+ def tie_weights(self):
309
+ return self.language_model.tie_weights()
310
+
311
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
312
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
313
+ # update vocab size
314
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
315
+ self.vocab_size = model_embeds.num_embeddings
316
+ return model_embeds
317
+
318
+ def _add_eol(self, x_feat, eol_feature):
319
+ h, w, C = x_feat.shape
320
+ eol_feature = eol_feature.unsqueeze(0).unsqueeze(0).expand(h, 1, C)
321
+ x_feat = torch.cat([x_feat, eol_feature], dim=1)
322
+ x_feat = x_feat.view(-1, C)
323
+ return x_feat
324
+
325
+ def _reformat_image_sequence(self, x, special_tokens_features, level):
326
+ # add end_of_line
327
+ x = self._add_eol(x, special_tokens_features[2])
328
+
329
+ # add soi, eoi, sol, eol
330
+ x = torch.cat([
331
+ special_tokens_features[3 + level * 2].unsqueeze(0),
332
+ x,
333
+ special_tokens_features[3 + level * 2 + 1].unsqueeze(0),
334
+ ], dim=0)
335
+ return x
336
+
337
+ def _merge_input_ids_with_image_features(
338
+ self,
339
+ image_features,
340
+ feature_lens,
341
+ inputs_embeds,
342
+ input_ids,
343
+ attention_mask,
344
+ position_ids=None,
345
+ labels=None,
346
+ image_token_index=None,
347
+ ignore_index=-100,
348
+ ):
349
+ image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
350
+ ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
351
+
352
+ with torch.no_grad():
353
+ num_images = feature_lens.size(0)
354
+ num_image_features, embed_dim = image_features.shape
355
+ if feature_lens.sum() != num_image_features:
356
+ raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
357
+ batch_size = input_ids.shape[0]
358
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
359
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
360
+
361
+ left_padding = True if not self.training else False
362
+ if batch_size > 1 and not self.training:
363
+ if _left_padding and not _right_padding:
364
+ left_padding = True
365
+ elif not _left_padding and _right_padding:
366
+ left_padding = False
367
+ elif not _left_padding and not _right_padding:
368
+ # both side is 1, so cannot tell
369
+ left_padding = self.padding_side == "left"
370
+ else:
371
+ # invalid attention_mask
372
+ raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
373
+
374
+ # Whether to turn off right padding
375
+ # 1. Create a mask to know where special image tokens are
376
+ special_image_token_mask = input_ids == image_token_index
377
+ # special_image_token_mask: [bsz, seqlen]
378
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
379
+ # num_special_image_tokens: [bsz]
380
+ # Reserve for padding of num_images
381
+ total_num_special_image_tokens = torch.sum(special_image_token_mask)
382
+ if total_num_special_image_tokens != num_images:
383
+ raise ValueError(
384
+ f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
385
+ )
386
+ # Compute the maximum embed dimension
387
+ # max_image_feature_lens is max_feature_lens per batch
388
+ feature_lens = feature_lens.to(input_ids.device)
389
+ feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
390
+ feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
391
+ embed_sequence_lengths = (
392
+ (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
393
+ )
394
+ max_embed_dim = embed_sequence_lengths.max()
395
+
396
+ batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
397
+ # 2. Compute the positions where text should be written
398
+ # Calculate new positions for text tokens in merged image-text sequence.
399
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
400
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
401
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
402
+ # ! instead of special_image_token_mask * (num_image_patches - 1)
403
+ # special_image_token_mask * (num_feature_len - 1)
404
+ special_image_token_mask = special_image_token_mask.long()
405
+ special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
406
+ new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
407
+ if left_padding:
408
+ # shift right token positions so that they are ending at the same number
409
+ # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
410
+ new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
411
+
412
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
413
+
414
+ # 3. Create the full embedding, already padded to the maximum position
415
+ final_embedding = torch.zeros(
416
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
417
+ )
418
+ final_attention_mask = torch.zeros(
419
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
420
+ )
421
+ final_input_ids = torch.full(
422
+ (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
423
+ )
424
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
425
+ # set the corresponding tensors into their correct target device.
426
+ target_device = inputs_embeds.device
427
+ batch_indices, non_image_indices, text_to_overwrite = (
428
+ batch_indices.to(target_device),
429
+ non_image_indices.to(target_device),
430
+ text_to_overwrite.to(target_device),
431
+ )
432
+ attention_mask = attention_mask.to(target_device)
433
+ input_ids = input_ids.to(target_device)
434
+
435
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
436
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
437
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
438
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
439
+ final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
440
+ final_labels = None
441
+ if labels is not None:
442
+ labels = labels.to(target_device)
443
+ final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
444
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
445
+
446
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
447
+ with torch.no_grad():
448
+ image_to_overwrite = torch.full(
449
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
450
+ )
451
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
452
+ embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
453
+ embed_indices = embed_indices.expand(batch_size, max_embed_dim)
454
+ embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
455
+
456
+ if left_padding:
457
+ # exclude padding on the left
458
+ max_embed_dim = max_embed_dim.to(target_device)
459
+ val = (max_embed_dim - embed_indices) <= embed_seq_lens
460
+ else:
461
+ # exclude padding on the right
462
+ val = embed_indices < embed_seq_lens
463
+ image_to_overwrite &= val
464
+
465
+ if image_to_overwrite.sum() != num_image_features:
466
+ raise ValueError(
467
+ f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
468
+ f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
469
+ f" the number of image given to the model is {num_images}. "
470
+ f"This prevents correct indexing and breaks batch generation."
471
+ )
472
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
473
+ final_attention_mask |= image_to_overwrite
474
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
475
+
476
+ return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
477
+
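A toy illustration of the position bookkeeping in `_merge_input_ids_with_image_features`: one image token (id 32000 here, purely illustrative) that expands into 4 feature slots shifts every later text token by 3 positions.

```python
import torch

input_ids = torch.tensor([[11, 32000, 12, 13]])
special = (input_ids == 32000).long()
special[special == 1] = 4 - 1                       # feature_lens - 1
new_token_positions = torch.cumsum(special + 1, -1) - 1
print(new_token_positions)                          # tensor([[0, 4, 5, 6]])
```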
478
+ def forward(
479
+ self,
480
+ input_ids: torch.LongTensor = None,
481
+ pixel_values: torch.FloatTensor = None,
482
+ attention_mask: Optional[torch.Tensor] = None,
483
+ position_ids: Optional[torch.LongTensor] = None,
484
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
485
+ inputs_embeds: Optional[torch.FloatTensor] = None,
486
+ labels: Optional[torch.LongTensor] = None,
487
+ use_cache: Optional[bool] = None,
488
+ output_attentions: Optional[bool] = None,
489
+ output_hidden_states: Optional[bool] = None,
490
+ return_dict: Optional[bool] = None,
491
+ ) -> Union[Tuple, ILLUMECausalLMOutputWithPast]:
492
+ r"""
493
+ Args:
494
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
495
+ Indices of input sequence tokens in the vocabulary.
496
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
497
+ [`PreTrainedTokenizer.__call__`] for details.
498
+
499
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
500
+ Pixel values of the input image(s) to condition the generation on.
501
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
502
+ Mask to avoid performing attention on padding token indices.
503
+ Mask values selected in `[0, 1]`:
504
+ - 1 for tokens that are not masked,
505
+ - 0 for tokens that are masked.
506
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
507
+ Indices of positions of each input sequence tokens in the position embeddings.
508
+ Selected in the range `[0, config.max_position_embeddings - 1]`.
509
+
510
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
511
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
512
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
513
+ use_cache (`bool`, *optional*):
514
+ If `True`, past_key_values are returned and can be used to speed up decoding.
515
+ output_attentions (`bool`, *optional*):
516
+ Whether to return the attentions tensors of all attention layers.
517
+ output_hidden_states (`bool`, *optional*):
518
+ Whether to return the hidden states of all layers.
519
+ return_dict (`bool`, *optional*):
520
+ Whether to return a dictionary of outputs instead of a plain tuple.
521
+
522
+ Returns:
523
+ An instance of [`ILLUMECausalLMOutputWithPast`] if `return_dict=True`. Otherwise, it returns a tuple of tensors.
524
+
525
+ Example:
526
+
527
+ ```python
528
+ >>> import torch
529
+ >>> from PIL import Image
530
+ >>> import requests
531
+ >>> from transformers import AutoProcessor, ILLUMEForConditionalGeneration
532
+
533
+ >>> # Load model and processor
534
+ >>> # Specify `torch_dtype` for mixed-precision inference, e.g., torch.bfloat16 or torch.float16
535
+ >>> # Specify `attn_implementation="flash_attention_2"` if Flash Attention 2 is installed and supported, or "sdpa" for PyTorch SDPA.
536
+ >>> # `low_cpu_mem_usage=True` can help reduce CPU memory for large models.
537
+ >>> model = ILLUMEForConditionalGeneration.from_pretrained(
538
+ ... "illume-unified-mllm/illume_plus-qwen2_5-3b-hf",
539
+ ... torch_dtype=torch.bfloat16, # Optional: Or torch.float16. Adjust based on your hardware.
540
+ ... low_cpu_mem_usage=True, # Optional: Reduces CPU RAM during model loading.
541
+ ... attn_implementation="sdpa", # Optional: Use "flash_attention_2" if available for better performance.
542
+ ... trust_remote_code=True
543
+ ... ).eval()
544
+ >>> # To use GPU: model = model.to("cuda") # Ensure the model is on the correct device
545
+
546
+ >>> processor = AutoProcessor.from_pretrained(
547
+ ... "illume-unified-mllm/illume_plus-qwen2_5-3b-hf",
548
+ ... trust_remote_code=True
549
+ ... )
550
+
551
+ >>> # Prepare inputs: a text prompt and an image
552
+ >>> # The processor formats the input for the model, including applying the chat template.
553
+ >>> messages = [
554
+ ... {"role": "user", "content": [
555
+ ... {"type": "image"},
556
+ ... {"type": "text", "text": "What is shown in this image?"}
557
+ ... ]}
558
+ ... ]
559
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" # An example image URL
560
+ >>> image = Image.open(requests.get(url, stream=True).raw)
561
+
562
+ >>> inputs = processor(text=messages, images=[image], return_tensors="pt")
563
+ >>> # To use GPU: inputs = {k: v.to("cuda") for k, v in inputs.items()} # Move inputs to the same device as the model
564
+
565
+ >>> # Generate text based on the input
566
+ >>> gen_kwargs = {"max_new_tokens": 100, "do_sample": False} # Generation parameters
567
+ >>> with torch.no_grad(): # Disable gradient calculations for inference
568
+ ... outputs = model.generate(**inputs, **gen_kwargs)
569
+
570
+ >>> # Decode the generated tokens, removing the prompt
571
+ >>> input_token_len = inputs["input_ids"].shape[1]
572
+ >>> generated_ids = outputs[:, input_token_len:]
573
+ >>> response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
574
+
575
+ >>> print(response)
576
+ ```"""
577
+
578
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
579
+ output_hidden_states = (
580
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
581
+ )
582
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
583
+
584
+ if inputs_embeds is None:
585
+ # 1. Extract the input embeddings
586
+ # In case image_token_index is not in the embeddings (extra token but embedding don't have it)
587
+ for_inputs_embeds_ids = input_ids.clone()
588
+ for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
589
+ inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
590
+
591
+ # 2. Merge text and images
592
+ if pixel_values is not None and input_ids.shape[1] != 1 and len(pixel_values) > 0:
593
+ image_features, image_feature_shapes = self.vision_tower(pixel_values)
594
+ image_features = self.mm_projector(image_features)
595
+
596
+ # reformat image sequence
597
+ # special_tokens_ids: <start_of_image>, <end_of_image>, <end_of_line>, <start_of_level0>, <end_of_level0>, <start_of_level1>, <end_of_level1>
598
+ special_tokens_ids = torch.Tensor([self.config.special_tokens_ids[key] for key in
599
+ ["<start_of_image>", "<end_of_image>", "<end_of_line>",
600
+ "<start_of_level0>",
601
+ "<end_of_level0>", "<start_of_level1>",
602
+ "<end_of_level1>"]]).long().to(image_features.device)
603
+
604
+ semantic_sizes = [h * w for (h, w), _ in image_feature_shapes]
605
+ pixel_sizes = [h * w for _, (h, w) in image_feature_shapes]
606
+
607
+ # Split the image features into semantic features and reshape them to (h, w, -1)
608
+ semantic_features = torch.split(image_features[:, :sum(semantic_sizes), :], semantic_sizes, dim=1)
609
+ h_semantics = [feat.view(h, w, -1) for feat, ((h, w), _) in
610
+ zip(semantic_features, image_feature_shapes)]
611
+
612
+ # Split the image features into pixel features and reshape them to (h, w, -1)
613
+ det_features = torch.split(
614
+ image_features[:, sum(semantic_sizes): sum(semantic_sizes) + sum(pixel_sizes), :], pixel_sizes,
615
+ dim=1)
616
+ h_pixels = [feat.view(h, w, -1) for feat, (_, (h, w)) in zip(det_features, image_feature_shapes)]
617
+
618
+ special_tokens_features = self.language_model.model.embed_tokens(special_tokens_ids)
619
+
620
+ image_features = []
621
+ feature_lens = []
622
+ for h_semantic, h_pixel in zip(h_semantics, h_pixels):
623
+ h_semantic = self._reformat_image_sequence(h_semantic, special_tokens_features.clone(), level=0)
624
+ h_pixel = self._reformat_image_sequence(h_pixel, special_tokens_features.clone(), level=1)
625
+
626
+ image_feature = torch.cat([special_tokens_features[0].unsqueeze(0), h_semantic, h_pixel,
627
+ special_tokens_features[1].unsqueeze(0)], dim=0)
628
+ image_features.append(image_feature)
629
+ feature_lens.append(image_feature.shape[0])
630
+
631
+ feature_lens = torch.as_tensor(feature_lens)
632
+ image_features = torch.cat(image_features, dim=0)
633
+
634
+ inputs_embeds = inputs_embeds.to(self.dtype)
635
+ inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
636
+ image_features,
637
+ feature_lens,
638
+ inputs_embeds,
639
+ input_ids,
640
+ attention_mask,
641
+ position_ids,
642
+ labels=labels,
643
+ )
644
+
645
+ # pixel_values is not None but is empty ---> text only cases
646
+ elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
647
+ # there are no images
648
+ pass
649
+
650
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
651
+ # generation with cache
652
+ elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
653
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
654
+ # that are set to 0
655
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
656
+
657
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
658
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
659
+
660
+ # Get the target length
661
+ target_length = input_ids.shape[1]
662
+ past_length = first_layer_past_key_value.shape[-1]
663
+
664
+ extended_attention_mask = torch.ones(
665
+ (attention_mask.shape[0], past_length),
666
+ dtype=attention_mask.dtype,
667
+ device=attention_mask.device,
668
+ )
669
+
670
+ # Filter out only the tokens that can be un-attended, this can happen
671
+ # if one uses ILLUME + Fused modules where the cache on the
672
+ # first iteration is already big enough, or if one passes custom cache
673
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
674
+ new_batch_index = batch_index[valid_indices]
675
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
676
+
677
+ # Zero-out the places where we don't need to attend
678
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
679
+
680
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
681
+
682
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
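+ # position of the single new token = number of attended tokens - 1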
683
+
684
+ outputs = self.language_model(
685
+ attention_mask=attention_mask.to(inputs_embeds.device) if attention_mask is not None else attention_mask,
686
+ position_ids=position_ids,
687
+ past_key_values=past_key_values,
688
+ inputs_embeds=inputs_embeds,
689
+ use_cache=use_cache,
690
+ output_attentions=output_attentions,
691
+ output_hidden_states=output_hidden_states,
692
+ return_dict=return_dict,
693
+ )
694
+
695
+ logits = outputs[0]
696
+
697
+ loss = None
698
+ if labels is not None:
699
+ # Shift so that tokens < n predict n
700
+ if attention_mask is not None:
701
+ shift_attention_mask = attention_mask[..., 1:]
702
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
703
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
704
+ else:
705
+ shift_logits = logits[..., :-1, :].contiguous()
706
+ shift_labels = labels[..., 1:].contiguous()
707
+ # Flatten the tokens
708
+ loss_fct = nn.CrossEntropyLoss()
709
+ loss = loss_fct(
710
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
711
+ )
712
+
713
+ if not return_dict:
714
+ output = (logits,) + outputs[1:]
715
+ return (loss,) + output if loss is not None else output
716
+
717
+ return ILLUMECausalLMOutputWithPast(
718
+ loss=loss,
719
+ logits=logits,
720
+ past_key_values=outputs.past_key_values,
721
+ hidden_states=outputs.hidden_states,
722
+ attentions=outputs.attentions,
723
+ )
724
+
725
+ def prepare_inputs_for_generation(
726
+ self,
727
+ input_ids,
728
+ past_key_values=None,
729
+ inputs_embeds=None,
730
+ pixel_values=None,
731
+ attention_mask=None,
732
+ **kwargs,
733
+ ):
734
+ if past_key_values is not None:
735
+ if isinstance(past_key_values, Cache):
736
+ cache_length = past_key_values.get_seq_length()
737
+ past_length = past_key_values.seen_tokens
738
+ else:
739
+ cache_length = past_length = past_key_values[0][0].shape[2]
740
+
741
+ # Keep only the unprocessed tokens:
742
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
743
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
744
+ # input)
745
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
746
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
747
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
748
+ # input_ids based on the past_length.
749
+ elif past_length < input_ids.shape[1]:
750
+ input_ids = input_ids[:, past_length:]
751
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
752
+ elif self.config.image_token_index in input_ids:
753
+ input_ids = input_ids[:, input_ids.shape[1] - 1:]
754
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
755
+ # older attention values, as their corresponding values are not part of the input.
756
+ if cache_length < past_length and attention_mask is not None:
757
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]):]
758
+
759
+ position_ids = kwargs.get("position_ids", None)
760
+ if attention_mask is not None and position_ids is None:
761
+ # create position_ids on the fly for batch generation
762
+ position_ids = attention_mask.long().cumsum(-1) - 1
763
+ position_ids.masked_fill_(attention_mask == 0, 1)
764
+ if past_key_values:
765
+ position_ids = position_ids[:, -input_ids.shape[1]:]
766
+
767
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
768
+ if inputs_embeds is not None and past_key_values is None:
769
+ model_inputs = {"inputs_embeds": inputs_embeds}
770
+ else:
771
+ model_inputs = {"input_ids": input_ids}
772
+
773
+ model_inputs.update(
774
+ {
775
+ "position_ids": position_ids,
776
+ "past_key_values": past_key_values,
777
+ "use_cache": kwargs.get("use_cache"),
778
+ "attention_mask": attention_mask,
779
+ "pixel_values": pixel_values,
780
+ }
781
+ )
782
+ return model_inputs
783
+
784
+ def _reorder_cache(self, *args, **kwargs):
785
+ return self.language_model._reorder_cache(*args, **kwargs)
786
+
787
+ def prepare_logit_processor(self,
788
+ guidance_scale=2.0,
789
+ negative_prompt_ids=None,
790
+ negative_prompt_attention_mask=None,
791
+ resolution=None,
792
+ temperature=1.0, image_semantic_temperature=1.0, image_pixel_temperature=None,
793
+ top_k=128, image_semantic_top_k=2048, image_pixel_top_k=None,
794
+ top_p=1.0, image_semantic_top_p=1.0, image_pixel_top_p=None,
795
+ images=None,
796
+ ):
797
+ if resolution is not None:
798
+ token_nums, _, h1, w1, h2, w2 = calculate_image_token_num(*resolution)
799
+ else:
800
+ h1, w1, h2, w2 = 0, 0, 0, 0
801
+
802
+ if image_pixel_temperature is None:
803
+ image_pixel_temperature = image_semantic_temperature
804
+ if image_pixel_top_k is None:
805
+ image_pixel_top_k = image_semantic_top_k * 3
806
+ if image_pixel_top_p is None:
807
+ image_pixel_top_p = image_semantic_top_p
808
+
809
+ return InterleavedLogitsProcessor(
810
+ guidance_scale=guidance_scale,
811
+ uncond=negative_prompt_ids,
812
+ attention_mask=negative_prompt_attention_mask,
813
+ model=self,
814
+ # DualVQ parameters
815
+ level0_range=level0_range,
816
+ level1_range=level1_range,
817
+ num_level0_rows=h1, num_level0_tokens=w1,
818
+ num_level1_rows=h2, num_level1_tokens=w2,
819
+ special_tokens=special_tokens_dict,
820
+ # Dynamic Sampling parameters
821
+ default_temp=temperature, level0_temp=image_semantic_temperature, level1_temp=image_pixel_temperature,
822
+ default_top_k=top_k, level0_top_k=image_semantic_top_k, level1_top_k=image_pixel_top_k,
823
+ default_top_p=top_p, level0_top_p=image_semantic_top_p, level1_top_p=image_pixel_top_p,
824
+ images=images
825
+ )
826
+
827
+ def generate(
828
+ self,
829
+ *args,
830
+ temperature: float = 1.0, top_k: int = 128, top_p: float = 1.0,
831
+ pixel_values: Optional[torch.Tensor] = None,
832
+ # image generation or image editing hyperparameters.
833
+ guidance_scale=1.0, target_image_resolution=None,
834
+ image_semantic_temperature: float = 1.0,
835
+ image_semantic_top_k: int = 2048,
836
+ image_semantic_top_p: float = 1.0,
837
+ image_pixel_temperature: float = 1.0,
838
+ image_pixel_top_k: int = 2048 * 3,
839
+ image_pixel_top_p: float = 1.0,
840
+ negative_image_prompt_ids: Optional[torch.Tensor] = None,
841
+ negative_image_prompt_attention_mask: Optional[torch.Tensor] = None,
842
+ disable_logit_processor=False,
843
+ logits_processor=None,
844
+ **kwargs,
845
+ ):
846
+ if target_image_resolution is not None:
847
+ # check if target_image_resolution is valid.
848
+ if not isinstance(target_image_resolution, tuple) or len(target_image_resolution) != 2:
849
+ raise ValueError("target_image_resolution must be a tuple of two integers.")
850
+ if not all(isinstance(dim, int) and dim > 0 for dim in target_image_resolution):
851
+ raise ValueError("target_image_resolution must contain positive integers.")
852
+
853
+ if target_image_resolution not in DEFAULT_RESOLUTIONS:
854
+ raise ValueError(
855
+ "target_image_resolution must be in one of the following ratios: " + str(DEFAULT_RESOLUTIONS))
856
+
857
+ if logits_processor is None:
858
+ logits_processor = LogitsProcessorList([])
859
+
860
+ if not disable_logit_processor:
861
+ illume_logit_processor = self.prepare_logit_processor(
862
+ negative_prompt_ids=negative_image_prompt_ids,
863
+ negative_prompt_attention_mask=negative_image_prompt_attention_mask,
864
+ temperature=temperature,
865
+ top_k=top_k,
866
+ top_p=top_p,
867
+ guidance_scale=guidance_scale,
868
+ resolution=target_image_resolution,
869
+ image_semantic_temperature=image_semantic_temperature,
870
+ image_pixel_temperature=image_pixel_temperature,
871
+ image_semantic_top_k=image_semantic_top_k,
872
+ image_pixel_top_k=image_pixel_top_k,
873
+ image_semantic_top_p=image_semantic_top_p,
874
+ image_pixel_top_p=image_pixel_top_p,
875
+ images=pixel_values,
876
+ )
877
+ logits_processor.append(illume_logit_processor)
878
+
879
+ return super(ILLUMEForConditionalGeneration, self).generate(
880
+ *args,
881
+ pixel_values=pixel_values,
882
+ logits_processor=logits_processor,
883
+ **kwargs)
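
For orientation, a minimal sketch of how this `generate` entry point could be called for text-to-image decoding. The checkpoint path, the AutoModel/AutoTokenizer setup and the `max_new_tokens` value are illustrative assumptions and not part of this diff; only the image-generation keyword arguments mirror the signature above.

    # hypothetical usage sketch -- model id and tokenizer setup are placeholders
    import torch
    from transformers import AutoModel, AutoTokenizer

    model = AutoModel.from_pretrained("path/to/illume-checkpoint", trust_remote_code=True,
                                      torch_dtype=torch.bfloat16).eval()
    tokenizer = AutoTokenizer.from_pretrained("path/to/illume-checkpoint", trust_remote_code=True)

    inputs = tokenizer("Generate an image of a red bicycle.", return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=4096,
        guidance_scale=2.0,                      # classifier-free guidance; a negative prompt can go via negative_image_prompt_ids
        target_image_resolution=(512, 512),      # assumed to be one of DEFAULT_RESOLUTIONS
        image_semantic_temperature=1.0, image_semantic_top_k=2048, image_semantic_top_p=1.0,
        image_pixel_temperature=1.0, image_pixel_top_k=2048 * 3, image_pixel_top_p=1.0,
    )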
modeling_movqgan.py ADDED
@@ -0,0 +1,828 @@
1
+ """ MoVQ model """
2
+
3
+ import math
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.utils.checkpoint import checkpoint
11
+ from transformers.modeling_utils import PreTrainedModel
12
+
13
+ from .configuration_movqgan import MoVQConfig
14
+
15
+ try:
16
+ import xformers.ops as xops
17
+
18
+ is_xformers_available = True
19
+ except Exception as e:
20
+ is_xformers_available = False
21
+
22
+ if torch.__version__ > "2.1.2":
23
+ IS_SDPA_AVAILABLE = True
24
+ else:
25
+ IS_SDPA_AVAILABLE = False
26
+
27
+
28
+ class MoVQActivation(nn.Module):
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def __call__(self, x: torch.Tensor):
34
+ return x * torch.sigmoid(x)
35
+
36
+
37
+ class MoVQUpsample(nn.Module):
38
+
39
+ def __init__(self, in_channels: int):
40
+ super().__init__()
41
+ self.conv = nn.Conv2d(
42
+ in_channels,
43
+ in_channels,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1,
47
+ )
48
+
49
+ def forward(self, x: torch.Tensor):
50
+ x = F.interpolate(x.float(), scale_factor=2.0, mode="nearest").to(x.dtype)
51
+ x = self.conv(x)
52
+ return x
53
+
54
+
55
+ class DCDownBlock2d(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int = None, downsample: bool = True,
57
+ shortcut: bool = True) -> None:
58
+ super().__init__()
59
+ out_channels = out_channels if out_channels else in_channels
60
+
61
+ self.downsample = downsample
62
+ self.factor = 2
63
+ self.stride = 1 if downsample else 2
64
+ self.group_size = in_channels * self.factor ** 2 // out_channels
65
+ self.shortcut = shortcut
66
+
67
+ out_ratio = self.factor ** 2
68
+ if downsample:
69
+ assert out_channels % out_ratio == 0
70
+ out_channels = out_channels // out_ratio
71
+
72
+ self.conv = nn.Conv2d(
73
+ in_channels,
74
+ out_channels,
75
+ kernel_size=3,
76
+ stride=self.stride,
77
+ padding=1,
78
+ )
79
+
80
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
81
+ x = self.conv(hidden_states)
82
+ if self.downsample:
83
+ x = F.pixel_unshuffle(x, self.factor)
84
+
85
+ if self.shortcut:
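+ # shortcut path: pixel-unshuffle the input and average channel groups so its shape matches the conv branch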
86
+ y = F.pixel_unshuffle(hidden_states, self.factor)
87
+ y = y.unflatten(1, (-1, self.group_size))
88
+ y = y.mean(dim=2)
89
+ hidden_states = x + y
90
+ else:
91
+ hidden_states = x
92
+
93
+ return hidden_states # x + y
94
+
95
+
96
+ class DCUpBlock2d(nn.Module):
97
+ def __init__(
98
+ self,
99
+ in_channels: int,
100
+ out_channels: int = None,
101
+ interpolate: bool = False,
102
+ shortcut: bool = True,
103
+ interpolation_mode: str = "nearest",
104
+ ) -> None:
105
+ super().__init__()
106
+ out_channels = out_channels if out_channels else in_channels
107
+
108
+ self.interpolate = interpolate
109
+ self.interpolation_mode = interpolation_mode
110
+ self.shortcut = shortcut
111
+ self.factor = 2
112
+ self.repeats = out_channels * self.factor ** 2 // in_channels
113
+
114
+ out_ratio = self.factor ** 2
115
+
116
+ if not interpolate:
117
+ out_channels = out_channels * out_ratio
118
+
119
+ self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
120
+
121
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
122
+ if self.interpolate:
123
+ x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
124
+ x = self.conv(x)
125
+ else:
126
+ x = self.conv(hidden_states)
127
+ x = F.pixel_shuffle(x, self.factor)
128
+
129
+ if self.shortcut:
130
+ y = hidden_states.repeat_interleave(self.repeats, dim=1)
131
+ y = F.pixel_shuffle(y, self.factor)
132
+ hidden_states = x + y
133
+ else:
134
+ hidden_states = x
135
+
136
+ return hidden_states
137
+
138
+
139
+ class MoVQDownsample(nn.Module):
140
+
141
+ def __init__(self, in_channels: int):
142
+ super().__init__()
143
+ self.conv = nn.Conv2d(
144
+ in_channels,
145
+ in_channels,
146
+ kernel_size=3,
147
+ stride=2,
148
+ padding=0,
149
+ )
150
+
151
+ def forward(self, x: torch.Tensor):
152
+ pad = (0, 1, 0, 1)
153
+ x = F.pad(x, pad, mode="constant", value=0)
154
+ x = self.conv(x)
155
+ return x
156
+
157
+
158
+ class MoVQSpatialNorm(nn.Module):
159
+
160
+ def __init__(
161
+ self,
162
+ f_channels: int,
163
+ zq_channels: int,
164
+ norm_layer: nn.Module = nn.GroupNorm,
165
+ add_conv: bool = False,
166
+ num_groups: int = 32,
167
+ eps: float = 1e-6,
168
+ affine: bool = True,
169
+ ):
170
+ super().__init__()
171
+ self.norm_layer = norm_layer(
172
+ num_channels=f_channels,
173
+ num_groups=num_groups,
174
+ eps=eps,
175
+ affine=affine,
176
+ )
177
+
178
+ self.add_conv = add_conv
179
+ if self.add_conv:
180
+ self.conv = nn.Conv2d(
181
+ zq_channels,
182
+ zq_channels,
183
+ kernel_size=3,
184
+ stride=1,
185
+ padding=1,
186
+ )
187
+
188
+ self.conv_y = nn.Conv2d(
189
+ zq_channels,
190
+ f_channels,
191
+ kernel_size=1,
192
+ stride=1,
193
+ padding=0,
194
+ )
195
+ self.conv_b = nn.Conv2d(
196
+ zq_channels,
197
+ f_channels,
198
+ kernel_size=1,
199
+ stride=1,
200
+ padding=0,
201
+ )
202
+
203
+ def forward(self, x: torch.Tensor, zq: torch.Tensor):
204
+ zq = F.interpolate(zq.float(), size=x.shape[-2:], mode="nearest").to(zq.dtype)
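+ # upsample zq to x's spatial size, then use it to scale and shift the group-normalized features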
205
+
206
+ if self.add_conv:
207
+ zq = self.conv(zq)
208
+
209
+ x = self.norm_layer(x)
210
+ x = x * self.conv_y(zq) + self.conv_b(zq)
211
+ return x
212
+
213
+
214
+ class MoVQResnetBlock(nn.Module):
215
+
216
+ def __init__(
217
+ self,
218
+ in_channels: int,
219
+ out_channels: Optional[int] = None,
220
+ conv_shortcut: bool = False,
221
+ dropout: float = 0.0,
222
+ zq_ch: Optional[int] = None,
223
+ add_conv: bool = False,
224
+ ):
225
+ super().__init__()
226
+ self.in_channels = in_channels
227
+ out_channels = in_channels if out_channels is None else out_channels
228
+ self.out_channels = out_channels
229
+ self.use_conv_shortcut = conv_shortcut
230
+ self.zq_ch = zq_ch
231
+
232
+ if zq_ch is None:
233
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
234
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
235
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
236
+ else:
237
+ self.norm1 = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
238
+ self.norm2 = MoVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv)
239
+
240
+ self.conv1 = nn.Conv2d(
241
+ in_channels,
242
+ out_channels,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1,
246
+ )
247
+
248
+ self.dropout = nn.Dropout(dropout)
249
+ self.conv2 = nn.Conv2d(
250
+ out_channels,
251
+ out_channels,
252
+ kernel_size=3,
253
+ stride=1,
254
+ padding=1,
255
+ )
256
+
257
+ self.act = MoVQActivation()
258
+
259
+ if self.in_channels != self.out_channels:
260
+ if self.use_conv_shortcut:
261
+ self.conv_shortcut = nn.Conv2d(
262
+ in_channels,
263
+ out_channels,
264
+ kernel_size=3,
265
+ stride=1,
266
+ padding=1,
267
+ )
268
+ else:
269
+ self.nin_shortcut = nn.Conv2d(
270
+ in_channels,
271
+ out_channels,
272
+ kernel_size=1,
273
+ stride=1,
274
+ padding=0,
275
+ )
276
+
277
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
278
+ norm_args = tuple() if self.zq_ch is None else (zq,)
279
+
280
+ h = self.norm1(x, *norm_args)
281
+ h = self.act(h)
282
+ h = self.conv1(h)
283
+
284
+ h = self.norm2(h, *norm_args)
285
+ h = self.act(h)
286
+ h = self.dropout(h)
287
+ h = self.conv2(h)
288
+
289
+ if self.in_channels != self.out_channels:
290
+ if self.use_conv_shortcut:
291
+ x = self.conv_shortcut(x)
292
+ else:
293
+ x = self.nin_shortcut(x)
294
+
295
+ return x + h
296
+
297
+
298
+ class MoVQAttnBlock(nn.Module):
299
+
300
+ def __init__(
301
+ self,
302
+ in_channels: int,
303
+ zq_ch: Optional[int] = None,
304
+ add_conv: bool = False,
305
+ num_heads=1,
306
+ ):
307
+ super().__init__()
308
+ self.in_channels = in_channels
309
+ self.zq_ch = zq_ch
310
+ self.num_heads = num_heads
311
+
312
+ if zq_ch is None:
313
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
314
+ self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
315
+ else:
316
+ self.norm = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
317
+
318
+ self.q = nn.Conv2d(
319
+ in_channels,
320
+ in_channels,
321
+ kernel_size=1,
322
+ stride=1,
323
+ padding=0,
324
+ )
325
+ self.k = nn.Conv2d(
326
+ in_channels,
327
+ in_channels,
328
+ kernel_size=1,
329
+ stride=1,
330
+ padding=0,
331
+ )
332
+ self.v = nn.Conv2d(
333
+ in_channels,
334
+ in_channels,
335
+ kernel_size=1,
336
+ stride=1,
337
+ padding=0,
338
+ )
339
+ self.proj_out = nn.Conv2d(
340
+ in_channels,
341
+ in_channels,
342
+ kernel_size=1,
343
+ stride=1,
344
+ padding=0,
345
+ )
346
+
347
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
348
+ # x: [b, c1, h1, w1]
349
+ # zq: [b, c2, h2, w2]
350
+ # attention_mask: [b, 1, h3, w3]
351
+ norm_args = tuple() if self.zq_ch is None else (zq,)
352
+
353
+ # if context is not None:
354
+ # context = F.interpolate(context.float(), size=x.shape[-2:], mode="nearest").to(context.dtype)
355
+ # x = x + self.conv_context(context)
356
+
357
+ nx = self.norm(x, *norm_args)
358
+ q = self.q(nx)
359
+ k = self.k(nx)
360
+ v = self.v(nx)
361
+
362
+ b, c, h, w = q.shape
363
+ if is_xformers_available:
364
+ # If xformers is available, create attn_bias for xops.memory_efficient_attention.
365
+ attn_bias = None
366
+
367
+ v = xops.memory_efficient_attention(
368
+ rearrange(q, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
369
+ rearrange(k, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
370
+ rearrange(v, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
371
+ scale=1.0 / math.sqrt(c // self.num_heads),
372
+ attn_bias=attn_bias,
373
+ )
374
+ v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w).contiguous()
375
+ elif IS_SDPA_AVAILABLE:
376
+ # compute attention
377
+ q = rearrange(q, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
378
+ k = rearrange(k, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
379
+ v = rearrange(v, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
380
+
381
+ attn_bias = None
382
+
383
+ v = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
384
+ v = v.transpose(1, 2)
385
+ v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w)
386
+ else:
387
+ # compute attention
388
+ q = rearrange(q, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
389
+ k = rearrange(k, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
390
+ v = rearrange(v, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
391
+
392
+ # score = torch.bmm(q.permute(0, 2, 1), k)
393
+ score = torch.einsum('b n c k, b n c l -> b n k l', q, k)
394
+ score = score / math.sqrt(c // self.num_heads)
395
+
396
+ score = F.softmax(score, dim=-1)  # normalize over key positions (last axis of the einsum output)
397
+
398
+ # attend to values
399
+ # v = v.reshape(b, c, h * w)
400
+ # v = torch.bmm(v, score.permute(0, 2, 1))
401
+ v = torch.einsum('b n c l, b n k l -> b n c k', v, score)
402
+ v = v.reshape(b, c, h, w)
403
+
404
+ v = self.proj_out(v)
405
+
406
+ return x + v
407
+
408
+
409
+ class MoVQVectorQuantizer(nn.Module):
410
+
411
+ def __init__(self, config: MoVQConfig):
412
+ super().__init__()
413
+ self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
414
+ self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
415
+
416
+ def forward(self, x: torch.Tensor):
417
+ # b t c h w -> b t h w c
418
+ b, t, c, h, w = x.shape
419
+ x = x.permute(0, 1, 3, 4, 2).contiguous()
420
+ x_flattened = x.view(-1, c)
421
+
422
+ codebook = self.embedding.weight
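+ # squared L2 distance from every feature vector to every codebook entry (expansion of ||x - e||^2)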
423
+
424
+ d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \
425
+ torch.sum(codebook ** 2, dim=1) - 2 * \
426
+ torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0))
427
+
428
+ indices = torch.argmin(d, dim=1)
429
+ indices = indices.view(b, t, h, w)
430
+ return indices
431
+
432
+
433
+ class MoVQPretrainedModel(PreTrainedModel):
434
+ """
435
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
436
+ models.
437
+ """
438
+
439
+ config_class = MoVQConfig
440
+ base_model_prefix = "movq"
441
+ main_input_name = "pixel_values"
442
+ _no_split_modules = ["MoVQResnetBlock", "MoVQAttnBlock"]
443
+
444
+ def _init_weights(self, module):
445
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)):
446
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
447
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
448
+ elif isinstance(module, nn.Linear):
449
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
450
+ if module.bias is not None:
451
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
452
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
453
+ nn.init.uniform_(module.bias, -bound, bound)
454
+ elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
455
+ nn.init.constant_(module.weight, 1)
456
+ nn.init.constant_(module.bias, 0)
457
+
458
+
459
+ class MoVQEncoder(nn.Module):
460
+ def __init__(self, config: MoVQConfig):
461
+ super().__init__()
462
+ self.config = config
463
+ self.ch = config.ch
464
+ self.num_resolutions = len(config.ch_mult)
465
+ self.num_res_blocks = config.num_res_blocks
466
+ self.in_channels = config.in_channels
467
+
468
+ # downsampling
469
+ self.conv_in = nn.Conv2d(
470
+ self.in_channels,
471
+ self.ch,
472
+ kernel_size=3,
473
+ stride=1,
474
+ padding=1
475
+ )
476
+
477
+ in_ch_mult = (1,) + tuple(config.ch_mult)
478
+ self.down = nn.ModuleList()
479
+ for i_level in range(self.num_resolutions):
480
+ block = nn.ModuleList()
481
+ attn = nn.ModuleList()
482
+ block_in = config.ch * in_ch_mult[i_level]
483
+ block_out = config.ch * config.ch_mult[i_level]
484
+ for i_block in range(self.num_res_blocks):
485
+ block.append(
486
+ MoVQResnetBlock(
487
+ in_channels=block_in,
488
+ out_channels=block_out,
489
+ dropout=config.dropout,
490
+ )
491
+ )
492
+ block_in = block_out
493
+ if i_level in config.attn_resolutions:
494
+ attn.append(MoVQAttnBlock(block_in))
495
+
496
+ down = nn.Module()
497
+ down.block = block
498
+ down.attn = attn
499
+ if i_level != self.num_resolutions - 1:
500
+ if config.use_dc_up_down_blocks:
501
+ down.downsample = DCDownBlock2d(block_in)
502
+ else:
503
+ down.downsample = MoVQDownsample(block_in)
504
+
505
+ self.down.append(down)
506
+
507
+ # middle
508
+ self.mid = nn.Module()
509
+ self.mid.block_1 = MoVQResnetBlock(
510
+ in_channels=block_in,
511
+ out_channels=block_in,
512
+ dropout=config.dropout,
513
+ )
514
+ self.mid.attn_1 = MoVQAttnBlock(block_in)
515
+ self.mid.block_2 = MoVQResnetBlock(
516
+ in_channels=block_in,
517
+ out_channels=block_in,
518
+ dropout=config.dropout,
519
+ )
520
+
521
+ # end
522
+
523
+ self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
524
+
525
+ self.act = MoVQActivation()
526
+
527
+ out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels
528
+ self.conv_out = nn.Conv2d(
529
+ block_in,
530
+ out_z_channels,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1,
534
+ )
535
+
536
+ self.out_shortcut_average_group_size = block_in // out_z_channels
537
+
538
+ def forward(self, x: torch.Tensor):
539
+
540
+ # downsampling
541
+ h = self.conv_in(x)
542
+ for i_level in range(self.num_resolutions):
543
+ for i_block in range(self.num_res_blocks):
544
+ h = self.down[i_level].block[i_block](h)
545
+ if len(self.down[i_level].attn) > 0:
546
+ h = self.down[i_level].attn[i_block](h)
547
+
548
+ if i_level != self.num_resolutions - 1:
549
+ h = self.down[i_level].downsample(h)
550
+
551
+ h = self.mid.block_1(h)
552
+ h = self.mid.attn_1(h)
553
+ h = self.mid.block_2(h)
554
+
555
+ # end
556
+ h = self.norm_out(h)
557
+ h = self.act(h)
558
+
559
+ if self.config.use_dc_up_down_blocks:
560
+ x = h.unflatten(1, (-1, self.out_shortcut_average_group_size))
561
+ x = x.mean(dim=2)
562
+ h = self.conv_out(h) + x
563
+ else:
564
+ h = self.conv_out(h)
565
+ return h
566
+
567
+
568
+ class MoVQDecoder(nn.Module):
569
+ def __init__(self, config: MoVQConfig):
570
+ super().__init__()
571
+ self.config = config
572
+ self.ch = config.ch
573
+ self.num_resolutions = len(config.ch_mult)
574
+ self.num_res_blocks = config.num_res_blocks
575
+
576
+ in_ch_mult = (1,) + tuple(config.ch_mult)
577
+ zq_ch = config.embed_dim
578
+
579
+ block_in = config.ch * config.ch_mult[-1]
580
+
581
+ self.in_shortcut_repeats = block_in // config.embed_dim
582
+
583
+ self.conv_in = nn.Conv2d(
584
+ config.z_channels,
585
+ block_in,
586
+ kernel_size=3,
587
+ stride=1,
588
+ padding=1,
589
+ )
590
+
591
+ # middle
592
+ self.mid = nn.Module()
593
+ self.mid.block_1 = MoVQResnetBlock(
594
+ in_channels=block_in,
595
+ out_channels=block_in,
596
+ dropout=config.dropout,
597
+ zq_ch=zq_ch,
598
+ )
599
+ self.mid.attn_1 = MoVQAttnBlock(block_in, zq_ch)
600
+ self.mid.block_2 = MoVQResnetBlock(
601
+ in_channels=block_in,
602
+ out_channels=block_in,
603
+ dropout=config.dropout,
604
+ zq_ch=zq_ch,
605
+ )
606
+
607
+ # upsampling
608
+ self.up = nn.ModuleList()
609
+ for i_level in reversed(range(self.num_resolutions)):
610
+ block = nn.ModuleList()
611
+ attn = nn.ModuleList()
612
+ block_out = config.ch * config.ch_mult[i_level]
613
+ for i_block in range(self.num_res_blocks + 1):
614
+ block.append(
615
+ MoVQResnetBlock(
616
+ in_channels=block_in,
617
+ out_channels=block_out,
618
+ dropout=config.dropout,
619
+ zq_ch=zq_ch,
620
+ )
621
+ )
622
+ block_in = block_out
623
+ if i_level in config.attn_resolutions:
624
+ attn.append(MoVQAttnBlock(block_in, zq_ch))
625
+
626
+ up = nn.Module()
627
+ up.block = block
628
+ up.attn = attn
629
+ if i_level != 0:
630
+ if config.use_dc_up_down_blocks:
631
+ up.upsample = DCUpBlock2d(block_in)
632
+ else:
633
+ up.upsample = MoVQUpsample(block_in)
634
+
635
+ self.up.insert(0, up)
636
+
637
+ self.act = MoVQActivation()
638
+
639
+ self.norm_out = MoVQSpatialNorm(block_in, zq_ch)
640
+ self.conv_out = nn.Conv2d(
641
+ block_in,
642
+ config.out_channels,
643
+ kernel_size=3,
644
+ stride=1,
645
+ padding=1,
646
+ )
647
+
648
+ @property
649
+ def last_layer(self):
650
+ return self.conv_out.weight
651
+
652
+ def forward(self, z: torch.Tensor, zq: torch.Tensor):
653
+ h = z
654
+
655
+ if self.config.use_dc_up_down_blocks:
656
+ h = h.repeat_interleave(self.in_shortcut_repeats, dim=1)
657
+ h = self.conv_in(z) + h
658
+ else:
659
+ h = self.conv_in(h)
660
+
661
+ # middle
662
+ h = self.mid.block_1(h, zq)
663
+ h = self.mid.attn_1(h, zq)
664
+ h = self.mid.block_2(h, zq)
665
+
666
+ # upsampling
667
+ for i_level in reversed(range(self.num_resolutions)):
668
+ for i_block in range(self.num_res_blocks + 1):
669
+ h = self.up[i_level].block[i_block](h, zq)
670
+ if len(self.up[i_level].attn) > 0:
671
+ h = self.up[i_level].attn[i_block](h, zq)
672
+
673
+ if i_level != 0:
674
+ h = self.up[i_level].upsample(h)
675
+
676
+ h = self.norm_out(h, zq)
677
+ h = self.act(h)
678
+ h = self.conv_out(h)
679
+
680
+ return h
681
+
682
+
683
+ class Decoder(nn.Module):
684
+ def __init__(self, config: MoVQConfig):
685
+ super().__init__()
686
+ self.config = config
687
+ self.ch = config.ch
688
+ self.num_resolutions = len(config.ch_mult)
689
+ self.num_res_blocks = config.num_res_blocks
690
+
691
+ in_ch_mult = (1,) + tuple(config.ch_mult)
692
+
693
+ block_in = config.ch * config.ch_mult[-1]
694
+
695
+ self.conv_in = nn.Conv2d(
696
+ config.z_channels,
697
+ block_in,
698
+ kernel_size=3,
699
+ stride=1,
700
+ padding=1,
701
+ )
702
+
703
+ # middle
704
+ self.mid = nn.Module()
705
+ self.mid.block_1 = MoVQResnetBlock(
706
+ in_channels=block_in,
707
+ out_channels=block_in,
708
+ dropout=config.dropout,
709
+ )
710
+ self.mid.attn_1 = MoVQAttnBlock(block_in)
711
+ self.mid.block_2 = MoVQResnetBlock(
712
+ in_channels=block_in,
713
+ out_channels=block_in,
714
+ dropout=config.dropout,
715
+ )
716
+
717
+ # upsampling
718
+ self.up = nn.ModuleList()
719
+ for i_level in reversed(range(self.num_resolutions)):
720
+ block = nn.ModuleList()
721
+ attn = nn.ModuleList()
722
+ block_out = config.ch * config.ch_mult[i_level]
723
+ for i_block in range(self.num_res_blocks + 1):
724
+ block.append(
725
+ MoVQResnetBlock(
726
+ in_channels=block_in,
727
+ out_channels=block_out,
728
+ dropout=config.dropout,
729
+ )
730
+ )
731
+ block_in = block_out
732
+ if i_level in config.attn_resolutions:
733
+ attn.append(MoVQAttnBlock(block_in))
734
+
735
+ up = nn.Module()
736
+ up.block = block
737
+ up.attn = attn
738
+ if i_level != 0:
739
+ up.upsample = MoVQUpsample(block_in)
740
+
741
+ self.up.insert(0, up)
742
+
743
+ self.act = MoVQActivation()
744
+
745
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
746
+ self.norm_out = nn.GroupNorm(num_channels=block_in, **norm_kwargs)
747
+ self.conv_out = nn.Conv2d(
748
+ block_in,
749
+ config.out_channels,
750
+ kernel_size=3,
751
+ stride=1,
752
+ padding=1,
753
+ )
754
+
755
+ @property
756
+ def last_layer(self):
757
+ return self.conv_out.weight
758
+
759
+ def forward(self, z: torch.Tensor, zq: torch.Tensor):
760
+ h = z
761
+ h = self.conv_in(h)
762
+
763
+ # middle
764
+ h = self.mid.block_1(h)
765
+ h = self.mid.attn_1(h)
766
+ h = self.mid.block_2(h)
767
+
768
+ # upsampling
769
+ for i_level in reversed(range(self.num_resolutions)):
770
+ for i_block in range(self.num_res_blocks + 1):
771
+ h = self.up[i_level].block[i_block](h)
772
+ if len(self.up[i_level].attn) > 0:
773
+ h = self.up[i_level].attn[i_block](h)
774
+
775
+ if i_level != 0:
776
+ h = self.up[i_level].upsample(h)
777
+
778
+ h = self.norm_out(h)
779
+ h = self.act(h)
780
+ h = self.conv_out(h)
781
+
782
+ return h
783
+
784
+
785
+ class MoVQModel(MoVQPretrainedModel):
786
+
787
+ def __init__(self, config):
788
+ super().__init__(config)
789
+ self.config = config
790
+
791
+ self.encoder = MoVQEncoder(config)
792
+ self.decoder = MoVQDecoder(config)
793
+ self.quantize = MoVQVectorQuantizer(config)
794
+
795
+ self.quant_conv = nn.Conv2d(config.z_channels, config.embed_dim, 1)
796
+ self.post_quant_conv = nn.Conv2d(config.embed_dim, config.z_channels, 1)
797
+
798
+ self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1)
799
+
800
+ self.post_init()
801
+
802
+ def encode(self, x: torch.Tensor):
803
+ h = self.encoder(x)
804
+ h = self.quant_conv(h)
805
+ codes = self.quantize(h)
806
+ return codes
807
+
808
+ def decode(self, x: torch.Tensor):
809
+ quant = self.quantize.embedding(x)  # x holds code indices of shape (b, h, w); embedding keeps that grid, giving (b, h, w, c)
810
+ b, h, w, c = quant.shape
811
+ quant = quant.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
812
+ quant2 = self.post_quant_conv(quant)
813
+ image = self.decoder(quant2, quant)
814
+ image = image.reshape(
815
+ b,
816
+ self.config.out_channels,
817
+ h * self.spatial_scale_factor,
818
+ w * self.spatial_scale_factor,
819
+ )
820
+ return image
821
+
822
+ @property
823
+ def device(self):
824
+ return next(self.parameters()).device
825
+
826
+ @property
827
+ def dtype(self):
828
+ return next(self.parameters()).dtype
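
As a rough reference, a sketch of decoding discrete codes back to pixels with this module. The checkpoint path and the 32x32 code grid are illustrative assumptions (the grid size depends on the input resolution and the configured ch_mult), so treat this as a sketch rather than a verified snippet.

    # hypothetical usage sketch -- checkpoint path and grid size are placeholders
    import torch

    movq = MoVQModel.from_pretrained("path/to/movqgan").eval()

    # a (batch, height, width) grid of codebook indices, e.g. produced by `encode`
    codes = torch.randint(0, movq.config.codebook_size, (1, 32, 32))
    with torch.no_grad():
        recon = movq.decode(codes)   # (batch, out_channels, 32 * scale, 32 * scale) image tensor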
modeling_qwen2vit.py ADDED
@@ -0,0 +1,841 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Qwen2-VL model."""
21
+
22
+ import math
23
+ from dataclasses import dataclass
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
25
+
26
+ from torch import Tensor
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+
33
+ from transformers.activations import ACT2FN
34
+ from transformers.cache_utils import Cache, StaticCache
35
+ from transformers.modeling_attn_mask_utils import (
36
+ AttentionMaskConverter,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPast,
40
+ ModelOutput,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.utils import (
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ is_torch_npu_available,
47
+ is_flash_attn_2_available,
48
+ is_flash_attn_greater_or_equal_2_10,
49
+ logging,
50
+ replace_return_docstrings,
51
+ )
52
+ from .configuration_qwen2vit import Qwen2VLConfig, Qwen2VLVisionConfig
53
+ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
54
+
55
+ from einops import rearrange
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CONFIG_FOR_DOC = "Qwen2VLConfig"
60
+
61
+ try:
62
+ import xformers.ops as xops
63
+
64
+ is_xformers_available = True
65
+ except Exception as e:
66
+ is_xformers_available = False
67
+
68
+ if is_flash_attn_2_available():
69
+ from flash_attn import flash_attn_varlen_func
70
+
71
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
72
+ else:
73
+ flash_attn_varlen_func = None
74
+
75
+
76
+ def init_weights(m):
77
+ if isinstance(m, nn.Linear):
78
+ # we use xavier_uniform following official JAX ViT:
79
+ torch.nn.init.xavier_uniform_(m.weight)
80
+ if m.bias is not None:
81
+ nn.init.constant_(m.bias, 0)
82
+ elif isinstance(m, nn.LayerNorm):
83
+ nn.init.constant_(m.bias, 0)
84
+ nn.init.constant_(m.weight, 1.0)
85
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
86
+ w = m.weight.data
87
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
88
+
89
+
90
+ @dataclass
91
+ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
92
+ """
93
+ Base class for Qwen2VL causal language model (or autoregressive) outputs.
94
+
95
+ Args:
96
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
97
+ Language modeling loss (for next-token prediction).
98
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
99
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
100
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
101
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
102
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
103
+
104
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
105
+ `past_key_values` input) to speed up sequential decoding.
106
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
107
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
108
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
109
+
110
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
111
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
112
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
113
+ sequence_length)`.
114
+
115
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
116
+ heads.
117
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
118
+ The rope index difference between sequence length and multimodal rope.
119
+ """
120
+
121
+ loss: Optional[torch.FloatTensor] = None
122
+ logits: torch.FloatTensor = None
123
+ past_key_values: Optional[List[torch.FloatTensor]] = None
124
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
125
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
126
+ rope_deltas: Optional[torch.LongTensor] = None
127
+
128
+
129
+ class Qwen2VLRotaryEmbedding(nn.Module):
130
+ def __init__(
131
+ self,
132
+ dim=None,
133
+ max_position_embeddings=2048,
134
+ base=10000,
135
+ device=None,
136
+ scaling_factor=1.0,
137
+ rope_type="default",
138
+ config: Optional[Qwen2VLConfig] = None,
139
+ ):
140
+ super().__init__()
141
+ # TODO (joao): remove the `if` below, only used for BC
142
+ self.rope_kwargs = {}
143
+ if config is None:
144
+ logger.warning_once(
145
+ "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
146
+ "`config` argument. All other arguments will be removed in v4.46"
147
+ )
148
+ self.rope_kwargs = {
149
+ "rope_type": rope_type,
150
+ "factor": scaling_factor,
151
+ "dim": dim,
152
+ "base": base,
153
+ "max_position_embeddings": max_position_embeddings,
154
+ }
155
+ self.rope_type = rope_type
156
+ self.max_seq_len_cached = max_position_embeddings
157
+ self.original_max_seq_len = max_position_embeddings
158
+ else:
159
+ # BC: "rope_type" was originally "type"
160
+ if config.rope_scaling is not None:
161
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
162
+ else:
163
+ self.rope_type = "default"
164
+ self.max_seq_len_cached = config.max_position_embeddings
165
+ self.original_max_seq_len = config.max_position_embeddings
166
+
167
+ self.config = config
168
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
169
+
170
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
171
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
172
+ self.original_inv_freq = self.inv_freq
173
+
174
+ def _dynamic_frequency_update(self, position_ids, device):
175
+ """
176
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
177
+ 1 - growing beyond the cached sequence length (allow scaling)
178
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179
+ """
180
+ seq_len = torch.max(position_ids) + 1
181
+ if seq_len > self.max_seq_len_cached: # growth
182
+ inv_freq, self.attention_scaling = self.rope_init_fn(
183
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
184
+ )
185
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
186
+ self.max_seq_len_cached = seq_len
187
+
188
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
189
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
190
+ self.max_seq_len_cached = self.original_max_seq_len
191
+
192
+ @torch.no_grad()
193
+ def forward(self, x, position_ids):
194
+ if "dynamic" in self.rope_type:
195
+ self._dynamic_frequency_update(position_ids, device=x.device)
196
+
197
+ # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
198
+ # So we expand the inv_freq to shape (3, ...)
199
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
200
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
201
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
202
+ device_type = x.device.type
203
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
204
+ with torch.autocast(device_type=device_type, enabled=False):
205
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
206
+ emb = torch.cat((freqs, freqs), dim=-1)
207
+ cos = emb.cos()
208
+ sin = emb.sin()
209
+
210
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
211
+ cos = cos * self.attention_scaling
212
+ sin = sin * self.attention_scaling
213
+
214
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
215
+
216
+
217
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
218
+ def rotate_half(x):
219
+ """Rotates half the hidden dims of the input."""
220
+ x1 = x[..., : x.shape[-1] // 2]
221
+ x2 = x[..., x.shape[-1] // 2:]
222
+ return torch.cat((-x2, x1), dim=-1)
223
+
224
+
225
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
226
+ orig_dtype = tensor.dtype
227
+ tensor = tensor.float()
228
+ cos = freqs.cos()
229
+ sin = freqs.sin()
230
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
231
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
232
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
233
+ output = output.to(orig_dtype)
234
+ return output
235
+
236
+
237
+ def apply_rotary_pos_emb_vision_batch(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
238
+ orig_dtype = tensor.dtype
239
+ tensor = tensor.float()
240
+ cos = freqs.cos()
241
+ sin = freqs.sin()
242
+ cos = cos.repeat(1, 1, 1, 2).float()
243
+ sin = sin.repeat(1, 1, 1, 2).float()
244
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
245
+ output = output.to(orig_dtype)
246
+ return output
247
+
248
+
249
+ class VisionRotaryEmbedding(nn.Module):
250
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
251
+ super().__init__()
252
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
253
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
254
+
255
+ def forward(self, seqlen: int, scale_factor: float = 1.0) -> torch.Tensor:
256
+ # Use scale_factor to dynamically adjust inv_freq
257
+ scaled_inv_freq = self.inv_freq * scale_factor
258
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
259
+ freqs = torch.outer(seq, scaled_inv_freq)
260
+ return freqs
261
+
262
+
263
+ class PatchEmbed(nn.Module):
264
+ def __init__(
265
+ self,
266
+ patch_size: int = 14,
267
+ temporal_patch_size: int = 2,
268
+ in_channels: int = 3,
269
+ embed_dim: int = 1152,
270
+ ) -> None:
271
+ super().__init__()
272
+ self.patch_size = patch_size
273
+ self.temporal_patch_size = temporal_patch_size
274
+ self.in_channels = in_channels
275
+ self.embed_dim = embed_dim
276
+
277
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
278
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
279
+
280
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
281
+ target_dtype = self.proj.weight.dtype
282
+ if is_torch_npu_available():
283
+ # if True:
284
+ hidden_states = F.linear(hidden_states, self.proj.weight.view(self.embed_dim, -1))
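+ # on NPU the patchify conv (kernel == stride, no bias) is applied as an equivalent linear layer over flattened patches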
285
+ else:
286
+ hidden_states = hidden_states.view(
287
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
288
+ )
289
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
290
+ return hidden_states
291
+
292
+
293
+ class PatchMerger(nn.Module):
294
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
295
+ super().__init__()
296
+ self.hidden_size = context_dim * (spatial_merge_size ** 2)
297
+ self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
298
+ self.mlp = nn.Sequential(
299
+ nn.Linear(self.hidden_size, self.hidden_size),
300
+ nn.GELU(),
301
+ nn.Linear(self.hidden_size, dim),
302
+ )
303
+
304
+ def forward(self, x: torch.Tensor, grid_thw) -> torch.Tensor:
305
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
306
+ return x
307
+
308
+
309
+ class VisionMlp(nn.Module):
310
+ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
311
+ super().__init__()
312
+ self.fc1 = nn.Linear(dim, hidden_dim)
313
+ self.act = ACT2FN[hidden_act]
314
+ self.fc2 = nn.Linear(hidden_dim, dim)
315
+
316
+ def forward(self, x) -> torch.Tensor:
317
+ return self.fc2(self.act(self.fc1(x)))
318
+
319
+
320
+ class VisionAttention(nn.Module):
321
+ def __init__(self, dim: int, num_heads: int = 16, ) -> None:
322
+ super().__init__()
323
+ self.num_heads = num_heads
324
+ self.head_dim = dim // num_heads
325
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
326
+ self.proj = nn.Linear(dim, dim)
327
+
328
+ def forward(
329
+ self,
330
+ hidden_states: torch.Tensor,
331
+ cu_seqlens: torch.Tensor,
332
+ rotary_pos_emb: torch.Tensor = None
333
+ ) -> torch.Tensor:
334
+ seq_length = hidden_states.shape[0]
335
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
336
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
337
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
338
+
339
+ attention_mask = torch.full(
340
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
341
+ )
342
+ for i in range(1, len(cu_seqlens)):
343
+ attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
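+ # block-diagonal mask from cu_seqlens: tokens attend only within their own image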
344
+
345
+ q = q.transpose(0, 1)
346
+ k = k.transpose(0, 1)
347
+ v = v.transpose(0, 1)
348
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
349
+ attn_weights = attn_weights + attention_mask
350
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
351
+ attn_output = torch.matmul(attn_weights, v)
352
+ attn_output = attn_output.transpose(0, 1)
353
+ attn_output = attn_output.reshape(seq_length, -1)
354
+ attn_output = self.proj(attn_output)
355
+ return attn_output
356
+
357
+
358
+ class BatchVisionAttention(nn.Module):
359
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
360
+ super().__init__()
361
+ self.num_heads = num_heads
362
+ self.head_dim = dim // num_heads
363
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
364
+ self.proj = nn.Linear(dim, dim)
365
+
366
+ def forward(
367
+ self,
368
+ hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
369
+ attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
370
+ rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
371
+ ) -> torch.Tensor:
372
+ batch_size, seq_len, _ = hidden_states.shape
373
+
374
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0,
375
+ 3, 1,
376
+ 4).unbind(
377
+ 0)
378
+ # [batch_size, num_heads, seq_len, head_dim]
379
+
380
+ if rotary_pos_emb is not None:
381
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
382
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
383
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
384
+
385
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
386
+ if attention_mask is not None:
387
+ attn_weights = attn_weights + attention_mask
388
+
389
+ # Softmax
390
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
391
+
392
+ attn_output = torch.matmul(attn_weights, v) # [batch_size, num_heads, seq_len, head_dim]
393
+ attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
394
+ return self.proj(attn_output)
395
+
396
+
397
+ class VisionXformerAttention(nn.Module):
398
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
399
+ super().__init__()
400
+ self.num_heads = num_heads
401
+ self.head_dim = dim // num_heads
402
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
403
+ self.proj = nn.Linear(dim, dim)
404
+
405
+ def forward(
406
+ self,
407
+ hidden_states: torch.Tensor,
408
+ cu_seqlens: torch.Tensor,
409
+ rotary_pos_emb: torch.Tensor = None
410
+ ) -> torch.Tensor:
411
+ seq_length = hidden_states.shape[0]
412
+
413
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
414
+
415
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
416
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
417
+
418
+ seqlens = [cu_seqlens[0]] + [cu_seqlens[i] - cu_seqlens[i - 1] for i in range(1, len(cu_seqlens))]
419
+ attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(seqlens)
420
+
421
+ attn_output = xops.memory_efficient_attention(
422
+ q, k, v.unsqueeze(0),
423
+ attn_bias=attn_bias,
424
+ scale=1.0 / math.sqrt(self.head_dim)
425
+ )
426
+ attn_output = attn_output.reshape(seq_length, -1)
427
+ attn_output = self.proj(attn_output)
428
+ return attn_output
429
+
430
+
431
+ class BatchVisionXformerAttention(nn.Module):
432
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
433
+ super().__init__()
434
+ self.num_heads = num_heads
435
+ self.head_dim = dim // num_heads
436
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
437
+ self.proj = nn.Linear(dim, dim)
438
+
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
443
+ rotary_pos_emb: torch.Tensor = None
444
+ ) -> torch.Tensor:
445
+ seq_length = hidden_states.shape[0]
446
+ batch_size, seq_len, _ = hidden_states.shape
447
+
448
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0,
449
+ 3, 1,
450
+ 4).unbind(
451
+ 0)
452
+ # [batch_size, num_heads, seq_len, head_dim]
453
+
454
+ if rotary_pos_emb is not None:
455
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
456
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
457
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
458
+
459
+ attn_output = xops.memory_efficient_attention(
460
+ q, k, v,
461
+ attn_bias=attention_mask,
462
+ scale=1.0 / math.sqrt(self.head_dim)
463
+ )
464
+ attn_output = attn_output.reshape(batch_size, seq_len, -1)
465
+ return self.proj(attn_output)
466
+
467
+
468
+ class VisionFlashAttention2(nn.Module):
469
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
470
+ super().__init__()
471
+ self.num_heads = num_heads
472
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
473
+ self.proj = nn.Linear(dim, dim)
474
+
475
+ def forward(
476
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
477
+ ) -> torch.Tensor:
478
+ seq_length = hidden_states.shape[0]
479
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
480
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
481
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
482
+
483
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
484
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
485
+ seq_length, -1
486
+ )
487
+ attn_output = self.proj(attn_output)
488
+ return attn_output
489
+
490
+
491
+ class BatchVisionFlashAttention2(nn.Module):
492
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
493
+ super().__init__()
494
+ self.num_heads = num_heads
495
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
496
+ self.proj = nn.Linear(dim, dim)
497
+
498
+ def forward(
499
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
500
+ ) -> torch.Tensor:
501
+ batch_size, seq_len, _ = hidden_states.shape
502
+
503
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1,
504
+ 4).unbind(0)
505
+
506
+ if rotary_pos_emb is not None:
507
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
508
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
509
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
510
+
511
+ q = rearrange(q, 'b h l d -> b l h d')
512
+ k = rearrange(k, 'b h l d -> b l h d')
513
+ v = rearrange(v, 'b h l d -> b l h d')
514
+
515
+ attn_output = _flash_attention_forward(q, k, v).reshape(batch_size, seq_len, -1)
516
+ attn_output = self.proj(attn_output)
517
+ return attn_output
518
+
519
+
520
+ class VisionSdpaAttention(nn.Module):
521
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
522
+ super().__init__()
523
+ self.num_heads = num_heads
524
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
525
+ self.proj = nn.Linear(dim, dim)
526
+
527
+ def forward(
528
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
529
+ ) -> torch.Tensor:
530
+ seq_length = hidden_states.shape[0]
531
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
532
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
533
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
534
+
535
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
536
+ for i in range(1, len(cu_seqlens)):
537
+ attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
538
+ q = q.transpose(0, 1)
539
+ k = k.transpose(0, 1)
540
+ v = v.transpose(0, 1)
541
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
542
+ attn_output = attn_output.transpose(0, 1)
543
+ attn_output = attn_output.reshape(seq_length, -1)
544
+ attn_output = self.proj(attn_output)
545
+ return attn_output
546
+
547
+
548
+ class BatchVisionSdpaAttention(nn.Module):
549
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
550
+ super().__init__()
551
+ self.num_heads = num_heads
552
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
553
+ self.proj = nn.Linear(dim, dim)
554
+
555
+ def forward(
556
+ self,
557
+ hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
558
+ attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
559
+ rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
560
+ ) -> torch.Tensor:
561
+ batch_size, seq_len, _ = hidden_states.shape
562
+ q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1,
563
+ 4).unbind(0)
564
+ # [batch_size, num_heads, seq_len, head_dim]
565
+
566
+ if rotary_pos_emb is not None:
567
+ rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
568
+ q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
569
+ k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
570
+
571
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
572
+ attn_output = attn_output.transpose(1, 2)
573
+ attn_output = attn_output.reshape(batch_size, seq_len, -1)
574
+ attn_output = self.proj(attn_output)
575
+ return attn_output
576
+
577
+
578
+ QWEN2_VL_VISION_ATTENTION_CLASSES = {
579
+ "eager": VisionAttention,
580
+ "flash_attention_2": VisionFlashAttention2,
581
+ "sdpa": VisionSdpaAttention,
582
+ "xformers": VisionXformerAttention,
583
+ }
584
+
585
+ QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES = {
586
+ "eager": BatchVisionAttention,
587
+ "flash_attention_2": VisionFlashAttention2,
588
+ "sdpa": BatchVisionSdpaAttention,
589
+ }
590
+
591
+
592
+ class Qwen2VLVisionBlock(nn.Module):
593
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
594
+ super().__init__()
595
+
596
+ self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
597
+ self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
598
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
599
+
600
+ self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
601
+ config.embed_dim, num_heads=config.num_heads,
602
+ )
603
+ self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
604
+
605
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb, grid_thw) -> torch.Tensor:
606
+ hidden_states = hidden_states + self.attn(
607
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
608
+ )
609
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
610
+ return hidden_states
611
+
612
+
613
+ class Qwen2VLBatchVisionBlock(nn.Module):
614
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
615
+ super().__init__()
616
+ self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
617
+ self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
618
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
619
+
620
+ self.attn = QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES[attn_implementation](
621
+ config.embed_dim, num_heads=config.num_heads,
622
+ )
623
+ self.mlp = VisionMlp(config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
624
+
625
+ def forward(
626
+ self,
627
+ hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
628
+ attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
629
+ rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
630
+ ) -> torch.Tensor:
631
+ # Attention
632
+ hidden_states = hidden_states + self.attn(
633
+ self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
634
+ )
635
+ # MLP
636
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
637
+ return hidden_states
638
+
639
+
640
+ # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
641
+ def _prepare_4d_causal_attention_mask_with_cache_position(
642
+ attention_mask: torch.Tensor,
643
+ sequence_length: int,
644
+ target_length: int,
645
+ dtype: torch.dtype,
646
+ device: torch.device,
647
+ min_dtype: float,
648
+ cache_position: torch.Tensor,
649
+ batch_size: int,
650
+ ):
651
+ """
652
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
653
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
654
+
655
+ Args:
656
+ attention_mask (`torch.Tensor`):
657
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
658
+ sequence_length (`int`):
659
+ The sequence length being processed.
660
+ target_length (`int`):
661
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
662
+ dtype (`torch.dtype`):
663
+ The dtype to use for the 4D attention mask.
664
+ device (`torch.device`):
665
+ The device to place the 4D attention mask on.
666
+ min_dtype (`float`):
667
+ The minimum value representable with the dtype `dtype`.
668
+ cache_position (`torch.Tensor`):
669
+ Indices depicting the position of the input sequence tokens in the sequence.
670
+ batch_size (`int`):
671
+ Batch size.
672
+ """
673
+ if attention_mask is not None and attention_mask.dim() == 4:
674
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
675
+ causal_mask = attention_mask
676
+ else:
677
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
678
+ if sequence_length != 1:
679
+ causal_mask = torch.triu(causal_mask, diagonal=1)
680
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
681
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
682
+ if attention_mask is not None:
683
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
684
+ mask_length = attention_mask.shape[-1]
685
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
686
+ padding_mask = padding_mask == 0
687
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
688
+ padding_mask, min_dtype
689
+ )
690
+
691
+ return causal_mask
692
+
693
+
694
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
695
+ class Qwen2RMSNorm(nn.Module):
696
+ def __init__(self, hidden_size, eps=1e-6):
697
+ """
698
+ Qwen2RMSNorm is equivalent to T5LayerNorm
699
+ """
700
+ super().__init__()
701
+ self.weight = nn.Parameter(torch.ones(hidden_size))
702
+ self.variance_epsilon = eps
703
+
704
+ def forward(self, hidden_states):
705
+ input_dtype = hidden_states.dtype
706
+ hidden_states = hidden_states.to(torch.float32)
707
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
708
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
709
+ return self.weight * hidden_states.to(input_dtype)
710
+
711
+ def extra_repr(self):
712
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
713
+
714
+
715
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
716
+ class Qwen2MLP(nn.Module):
717
+ def __init__(self, config):
718
+ super().__init__()
719
+ self.hidden_size = config.hidden_size
720
+ self.intermediate_size = config.intermediate_size
721
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
722
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
723
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
724
+ self.act_fn = ACT2FN[config.hidden_act]
725
+
726
+ def forward(self, hidden_state):
727
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
728
+
729
+
730
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
731
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
732
+ """
733
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
734
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
735
+ """
736
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
737
+ if n_rep == 1:
738
+ return hidden_states
739
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
740
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
741
+
742
+
743
+ class Qwen2VLPreTrainedModel(PreTrainedModel):
744
+ config_class = Qwen2VLConfig
745
+ base_model_prefix = "model"
746
+ supports_gradient_checkpointing = True
747
+ _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
748
+ _skip_keys_device_placement = "past_key_values"
749
+ _supports_flash_attn_2 = True
750
+ _supports_sdpa = True
751
+ _supports_cache_class = True
752
+ _supports_static_cache = True
753
+
754
+ def _init_weights(self, module):
755
+ std = self.config.initializer_range
756
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
757
+ module.weight.data.normal_(mean=0.0, std=std)
758
+ if module.bias is not None:
759
+ module.bias.data.zero_()
760
+ elif isinstance(module, nn.Embedding):
761
+ module.weight.data.normal_(mean=0.0, std=std)
762
+ if module.padding_idx is not None:
763
+ module.weight.data[module.padding_idx].zero_()
764
+
765
+
766
+ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
767
+ config_class = Qwen2VLVisionConfig
768
+ _no_split_modules = ["Qwen2VLVisionBlock"]
769
+
770
+ def __init__(self, config) -> None:
771
+ super().__init__(config)
772
+ self.spatial_merge_size = config.spatial_merge_size
773
+
774
+ self.patch_embed = PatchEmbed(
775
+ patch_size=config.patch_size,
776
+ temporal_patch_size=config.temporal_patch_size,
777
+ in_channels=config.in_channels,
778
+ embed_dim=config.embed_dim,
779
+ )
780
+
781
+ head_dim = config.embed_dim // config.num_heads
782
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
783
+
784
+ self.blocks = nn.ModuleList(
785
+ [Qwen2VLVisionBlock(config, config.attn_implementation) for _ in range(config.depth)]
786
+ )
787
+ self.merger = PatchMerger(
788
+ dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
789
+ )
790
+
791
+ def get_dtype(self) -> torch.dtype:
792
+ return self.blocks[0].mlp.fc2.weight.dtype
793
+
794
+ def get_device(self) -> torch.device:
795
+ return self.blocks[0].mlp.fc2.weight.device
796
+
797
+ def rot_pos_emb(self, grid_thw):
798
+ pos_ids = []
799
+ for t, h, w in grid_thw:
800
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
801
+ hpos_ids = hpos_ids.reshape(
802
+ h // self.spatial_merge_size,
803
+ self.spatial_merge_size,
804
+ w // self.spatial_merge_size,
805
+ self.spatial_merge_size,
806
+ )
807
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
808
+ hpos_ids = hpos_ids.flatten()
809
+
810
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
811
+ wpos_ids = wpos_ids.reshape(
812
+ h // self.spatial_merge_size,
813
+ self.spatial_merge_size,
814
+ w // self.spatial_merge_size,
815
+ self.spatial_merge_size,
816
+ )
817
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
818
+ wpos_ids = wpos_ids.flatten()
819
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
820
+ pos_ids = torch.cat(pos_ids, dim=0)
821
+ max_grid_size = grid_thw[:, 1:].max()
822
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
823
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
824
+ return rotary_pos_emb
825
+
826
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
827
+ hidden_states = self.patch_embed(hidden_states)
828
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
829
+
830
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
831
+ dim=0, dtype=torch.int32
832
+ )
833
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
834
+
835
+ for blk in self.blocks:
836
+ hidden_states = blk(hidden_states,
837
+ cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb,
838
+ grid_thw=grid_thw)
839
+
840
+ hidden_states = self.merger(hidden_states, grid_thw)
841
+ return hidden_states
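
Note on the packed attention paths above: the vision tower flattens the patches of all images into one long sequence, and `cu_seqlens` holds the cumulative patch counts, so each image occupies one contiguous block. The eager/SDPA variants turn this into a block-diagonal boolean mask, while the flash-attention variant passes `cu_seqlens` and `max_seqlen` to `flash_attn_varlen_func` directly. A minimal standalone sketch of the mask construction (illustrative only, not part of the shipped files; the patch counts are made up):

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # True means "may attend"; patches of one image only see patches of the same image.
    seq_length = int(cu_seqlens[-1])
    mask = torch.zeros(seq_length, seq_length, dtype=torch.bool)
    for i in range(1, len(cu_seqlens)):
        start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        mask[start:end, start:end] = True
    return mask

# Two images with 4 and 6 patches packed into a single sequence of length 10.
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)
mask = block_diagonal_mask(cu_seqlens)
assert mask[:4, :4].all() and not mask[:4, 4:].any()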
modeling_rope_utils.py ADDED
@@ -0,0 +1,561 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Optional, Tuple
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import is_torch_available, logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ if is_torch_available():
27
+ import torch
28
+
29
+
30
+ def _compute_default_rope_parameters(
31
+ config: Optional[PretrainedConfig] = None,
32
+ device: Optional["torch.device"] = None,
33
+ seq_len: Optional[int] = None,
34
+ **rope_kwargs,
35
+ ) -> Tuple["torch.Tensor", float]:
36
+ """
37
+ Computes the inverse frequencies according to the original RoPE implementation
38
+ Args:
39
+ config ([`~transformers.PretrainedConfig`]):
40
+ The model configuration.
41
+ device (`torch.device`):
42
+ The device to use for initialization of the inverse frequencies.
43
+ seq_len (`int`, *optional*):
44
+ The current sequence length. Unused for this type of RoPE.
45
+ rope_kwargs (`Dict`, *optional*):
46
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
47
+ Returns:
48
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
49
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
50
+ """
51
+ if config is not None and len(rope_kwargs) > 0:
52
+ raise ValueError(
53
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
54
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
55
+ )
56
+ if len(rope_kwargs) > 0:
57
+ base = rope_kwargs["base"]
58
+ dim = rope_kwargs["dim"]
59
+ elif config is not None:
60
+ base = config.rope_theta
61
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
62
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
63
+ dim = int(head_dim * partial_rotary_factor)
64
+
65
+ attention_factor = 1.0 # Unused in this type of RoPE
66
+
67
+ # Compute the inverse frequencies
68
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
69
+ return inv_freq, attention_factor
70
+
71
+
72
+ def _compute_linear_scaling_rope_parameters(
73
+ config: Optional[PretrainedConfig] = None,
74
+ device: Optional["torch.device"] = None,
75
+ seq_len: Optional[int] = None,
76
+ **rope_kwargs,
77
+ ) -> Tuple["torch.Tensor", float]:
78
+ """
79
+ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
80
+ Args:
81
+ config ([`~transformers.PretrainedConfig`]):
82
+ The model configuration.
83
+ device (`torch.device`):
84
+ The device to use for initialization of the inverse frequencies.
85
+ seq_len (`int`, *optional*):
86
+ The current sequence length. Unused for this type of RoPE.
87
+ rope_kwargs (`Dict`, *optional*):
88
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
89
+ Returns:
90
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
91
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
92
+ """
93
+ if config is not None and len(rope_kwargs) > 0:
94
+ raise ValueError(
95
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
96
+ f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
97
+ )
98
+ if len(rope_kwargs) > 0:
99
+ factor = rope_kwargs["factor"]
100
+ elif config is not None:
101
+ factor = config.rope_scaling["factor"]
102
+
103
+ # Gets the default RoPE parameters
104
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
105
+
106
+ # Then applies linear scaling to the frequencies.
107
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
108
+ # applying scaling to the inverse frequencies is equivalent.
109
+ inv_freq /= factor
110
+ return inv_freq, attention_factor
111
+
112
+
113
+ def _compute_dynamic_ntk_parameters(
114
+ config: Optional[PretrainedConfig] = None,
115
+ device: Optional["torch.device"] = None,
116
+ seq_len: Optional[int] = None,
117
+ **rope_kwargs,
118
+ ) -> Tuple["torch.Tensor", float]:
119
+ """
120
+ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
121
+ Args:
122
+ config ([`~transformers.PretrainedConfig`]):
123
+ The model configuration.
124
+ device (`torch.device`):
125
+ The device to use for initialization of the inverse frequencies.
126
+ seq_len (`int`, *optional*):
127
+ The current sequence length, used to update the dynamic RoPE at inference time.
128
+ rope_kwargs (`Dict`, *optional*):
129
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
130
+ Returns:
131
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
132
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
133
+ """
134
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
135
+ if config is not None and len(rope_kwargs) > 0:
136
+ raise ValueError(
137
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
138
+ f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
139
+ )
140
+ if len(rope_kwargs) > 0:
141
+ base = rope_kwargs["base"]
142
+ dim = rope_kwargs["dim"]
143
+ max_position_embeddings = rope_kwargs["max_position_embeddings"]
144
+ factor = rope_kwargs["factor"]
145
+ elif config is not None:
146
+ base = config.rope_theta
147
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
148
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
149
+ dim = int(head_dim * partial_rotary_factor)
150
+ max_position_embeddings = config.max_position_embeddings
151
+ factor = config.rope_scaling["factor"]
152
+
153
+ attention_factor = 1.0 # Unused in this type of RoPE
154
+
155
+ # seq_len: default to max_position_embeddings, e.g. at init time
156
+ seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
157
+
158
+ # Compute the inverse frequencies
159
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
160
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
161
+ return inv_freq, attention_factor
162
+
163
+
164
+ def _compute_yarn_parameters(
165
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
166
+ ) -> Tuple["torch.Tensor", float]:
167
+ """
168
+ Computes the inverse frequencies with YaRN scaling. Please refer to the
169
+ [original paper](https://arxiv.org/abs/2309.00071)
170
+ Args:
171
+ config ([`~transformers.PretrainedConfig`]):
172
+ The model configuration.
173
+ device (`torch.device`):
174
+ The device to use for initialization of the inverse frequencies.
175
+ seq_len (`int`, *optional*):
176
+ The current sequence length. Unused for this type of RoPE.
177
+ rope_kwargs (`Dict`, *optional*):
178
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
179
+ Returns:
180
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
181
+ post-processing scaling factor applied to the computed cos/sin.
182
+ """
183
+ # No need to keep BC with yarn, unreleased when this new pattern was created.
184
+ if len(rope_kwargs) > 0:
185
+ raise ValueError(
186
+ f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
187
+ )
188
+
189
+ base = config.rope_theta
190
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
191
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
192
+ dim = int(head_dim * partial_rotary_factor)
193
+ max_position_embeddings = config.max_position_embeddings
194
+ factor = config.rope_scaling["factor"]
195
+
196
+ # Sets the attention factor as suggested in the paper
197
+ attention_factor = config.rope_scaling.get("attention_factor")
198
+ if attention_factor is None:
199
+ attention_factor = 0.1 * math.log(factor) + 1.0
200
+
201
+ # Optional config options
202
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
203
+ beta_fast = config.rope_scaling.get("beta_fast") or 32
204
+ beta_slow = config.rope_scaling.get("beta_slow") or 1
205
+
206
+ # Compute the inverse frequencies
207
+ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
208
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
209
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
210
+
211
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
212
+ """Find dimension range bounds based on rotations"""
213
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
214
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
215
+ return max(low, 0), min(high, dim - 1)
216
+
217
+ def linear_ramp_factor(min, max, dim):
218
+ if min == max:
219
+ max += 0.001 # Prevent singularity
220
+
221
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
222
+ ramp_func = torch.clamp(linear_func, 0, 1)
223
+ return ramp_func
224
+
225
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
226
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
227
+ pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
228
+ inv_freq_extrapolation = 1.0 / pos_freqs
229
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
230
+
231
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
232
+
233
+ # Get n-dimensional rotational scaling corrected for extrapolation
234
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
235
+ inv_freq = (
236
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
237
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
238
+ )
239
+
240
+ return inv_freq, attention_factor
241
+
242
+
243
+ def _compute_longrope_parameters(
244
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
245
+ ) -> Tuple["torch.Tensor", float]:
246
+ """
247
+ Computes the inverse frequencies with LongRoPE scaling. Please refer to the
248
+ [original implementation](https://github.com/microsoft/LongRoPE)
249
+ Args:
250
+ config ([`~transformers.PretrainedConfig`]):
251
+ The model configuration.
252
+ device (`torch.device`):
253
+ The device to use for initialization of the inverse frequencies.
254
+ seq_len (`int`, *optional*):
255
+ The current sequence length. Unused for this type of RoPE.
256
+ rope_kwargs (`Dict`, *optional*):
257
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
258
+ Returns:
259
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
260
+ post-processing scaling factor applied to the computed cos/sin.
261
+ """
262
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
263
+ # No need to keep BC with longrope, unreleased when this new pattern was created.
264
+ if len(rope_kwargs) > 0:
265
+ raise ValueError(
266
+ "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
267
+ f"{rope_kwargs}"
268
+ )
269
+
270
+ base = config.rope_theta
271
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
272
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
273
+ dim = int(head_dim * partial_rotary_factor)
274
+ long_factor = config.rope_scaling["long_factor"]
275
+ short_factor = config.rope_scaling["short_factor"]
276
+ factor = config.rope_scaling.get("factor")
277
+ attention_factor = config.rope_scaling.get("attention_factor")
278
+
279
+ # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
280
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
281
+ # values to compute the default attention scaling factor, instead of using `factor`.
282
+ if hasattr(config, "original_max_position_embeddings"):
283
+ max_position_embeddings = config.original_max_position_embeddings
284
+ expanded_max_position_embeddings = config.max_position_embeddings
285
+ factor = expanded_max_position_embeddings / max_position_embeddings
286
+ else:
287
+ max_position_embeddings = config.max_position_embeddings
288
+ expanded_max_position_embeddings = max_position_embeddings * factor
289
+
290
+ # Sets the attention factor as suggested in the paper
291
+ if attention_factor is None:
292
+ if factor <= 1.0:
293
+ attention_factor = 1.0
294
+ else:
295
+ attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
296
+
297
+ # Compute the inverse frequencies -- scaled based on the target sequence length
298
+ if expanded_max_position_embeddings > max_position_embeddings:
299
+ ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
300
+ else:
301
+ ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
302
+ inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
303
+ inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
304
+
305
+ return inv_freq, attention_factor
306
+
307
+
308
+ def _compute_llama3_parameters(
309
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
310
+ ) -> Tuple["torch.Tensor", float]:
311
+ """
312
+ Computes the inverse frequencies for llama 3.1.
313
+
314
+ Args:
315
+ config ([`~transformers.PretrainedConfig`]):
316
+ The model configuration.
317
+ device (`torch.device`):
318
+ The device to use for initialization of the inverse frequencies.
319
+ seq_len (`int`, *optional*):
320
+ The current sequence length. Unused for this type of RoPE.
321
+ rope_kwargs (`Dict`, *optional*):
322
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
323
+ Returns:
324
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
325
+ post-processing scaling factor applied to the computed cos/sin.
326
+ """
327
+ # Gets the default RoPE parameters
328
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
329
+
330
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
331
+ low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
332
+ high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
333
+ old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
334
+
335
+ low_freq_wavelen = old_context_len / low_freq_factor
336
+ high_freq_wavelen = old_context_len / high_freq_factor
337
+
338
+ wavelen = 2 * math.pi / inv_freq
339
+ # wavelen < high_freq_wavelen: do nothing
340
+ # wavelen > low_freq_wavelen: divide by factor
341
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
342
+ # otherwise: interpolate between the two, using a smooth factor
343
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
344
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
345
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
346
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
347
+
348
+ return inv_freq_llama, attention_factor
349
+
350
+
351
+ # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
352
+ # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
353
+ # parameterizations, as long as the callable has the same signature.
354
+ ROPE_INIT_FUNCTIONS = {
355
+ "default": _compute_default_rope_parameters,
356
+ "linear": _compute_linear_scaling_rope_parameters,
357
+ "dynamic": _compute_dynamic_ntk_parameters,
358
+ "yarn": _compute_yarn_parameters,
359
+ "longrope": _compute_longrope_parameters,
360
+ "llama3": _compute_llama3_parameters,
361
+ }
362
+
363
+
364
+ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
365
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
366
+ # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
367
+ if "type" in received_keys:
368
+ received_keys -= {"type"}
369
+ required_keys.add("rope_type")
370
+
371
+ missing_keys = required_keys - received_keys
372
+ if missing_keys:
373
+ raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
374
+
375
+ if optional_keys is not None:
376
+ unused_keys = received_keys - required_keys - optional_keys
377
+ else:
378
+ unused_keys = received_keys - required_keys
379
+ if unused_keys:
380
+ logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
381
+
382
+
383
+ def _validate_default_rope_parameters(config: PretrainedConfig):
384
+ rope_scaling = config.rope_scaling
385
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
386
+ required_keys = {"rope_type"}
387
+ received_keys = set(rope_scaling.keys())
388
+ _check_received_keys(rope_type, received_keys, required_keys)
389
+
390
+
391
+ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
392
+ rope_scaling = config.rope_scaling
393
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
394
+ required_keys = {"rope_type", "factor"}
395
+ received_keys = set(rope_scaling.keys())
396
+ _check_received_keys(rope_type, received_keys, required_keys)
397
+
398
+ factor = rope_scaling["factor"]
399
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
400
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
401
+
402
+
403
+ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
404
+ rope_scaling = config.rope_scaling
405
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
406
+ required_keys = {"rope_type", "factor"}
407
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
408
+ optional_keys = {"original_max_position_embeddings"}
409
+ received_keys = set(rope_scaling.keys())
410
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
411
+
412
+ factor = rope_scaling["factor"]
413
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
414
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
415
+
416
+
417
+ def _validate_yarn_parameters(config: PretrainedConfig):
418
+ rope_scaling = config.rope_scaling
419
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
420
+ required_keys = {"rope_type", "factor"}
421
+ optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
422
+ received_keys = set(rope_scaling.keys())
423
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
424
+
425
+ factor = rope_scaling["factor"]
426
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
427
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
428
+
429
+ attention_factor = rope_scaling.get("attention_factor")
430
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
431
+ logger.warning(
432
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
433
+ )
434
+ beta_fast = rope_scaling.get("beta_fast")
435
+ if beta_fast is not None and not isinstance(beta_fast, float):
436
+ logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
437
+ beta_slow = rope_scaling.get("beta_slow")
438
+ if beta_slow is not None and not isinstance(beta_slow, float):
439
+ logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
440
+
441
+ if (beta_fast or 32) < (beta_slow or 1):
442
+ logger.warning(
443
+ f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
444
+ f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
445
+ )
446
+
447
+
448
+ def _validate_longrope_parameters(config: PretrainedConfig):
449
+ rope_scaling = config.rope_scaling
450
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
451
+ required_keys = {"rope_type", "short_factor", "long_factor"}
452
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
453
+ optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
454
+ received_keys = set(rope_scaling.keys())
455
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
456
+
457
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
458
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
459
+ dim = int(head_dim * partial_rotary_factor)
460
+
461
+ short_factor = rope_scaling.get("short_factor")
462
+ if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
463
+ logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
464
+ if not len(short_factor) == dim // 2:
465
+ logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
466
+
467
+ long_factor = rope_scaling.get("long_factor")
468
+ if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
469
+ logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
470
+ if not len(long_factor) == dim // 2:
471
+ logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
472
+
473
+ # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
474
+ # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
475
+ # unique to longrope (= undesirable)
476
+ if hasattr(config, "original_max_position_embeddings"):
477
+ logger.warning_once(
478
+ "This model has set a `original_max_position_embeddings` field, to be used together with "
479
+ "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
480
+ "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
481
+ "as it is compatible with most model architectures."
482
+ )
483
+ else:
484
+ factor = rope_scaling.get("factor")
485
+ if factor is None:
486
+ logger.warning("Missing required keys in `rope_scaling`: 'factor'")
487
+ elif not isinstance(factor, float) or factor < 1.0:
488
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
489
+
490
+ attention_factor = rope_scaling.get("attention_factor")
491
+ if attention_factor is not None:
492
+ if not isinstance(attention_factor, float) or attention_factor < 0.0:
493
+ logger.warning(
494
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
495
+ )
496
+
497
+
498
+ def _validate_llama3_parameters(config: PretrainedConfig):
499
+ rope_scaling = config.rope_scaling
500
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
501
+ required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
502
+ received_keys = set(rope_scaling.keys())
503
+ _check_received_keys(rope_type, received_keys, required_keys)
504
+
505
+ factor = rope_scaling["factor"]
506
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
507
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
508
+
509
+ low_freq_factor = rope_scaling["low_freq_factor"]
510
+ high_freq_factor = rope_scaling["high_freq_factor"]
511
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
512
+ logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
513
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
514
+ logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
515
+ if high_freq_factor <= low_freq_factor:
516
+ logger.warning(
517
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
518
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
519
+ )
520
+
521
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
522
+ if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
523
+ logger.warning(
524
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
525
+ f"{original_max_position_embeddings}"
526
+ )
527
+ if original_max_position_embeddings >= config.max_position_embeddings:
528
+ logger.warning(
529
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
530
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
531
+ )
532
+
533
+
534
+ # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
535
+ ROPE_VALIDATION_FUNCTIONS = {
536
+ "default": _validate_default_rope_parameters,
537
+ "linear": _validate_linear_scaling_rope_parameters,
538
+ "dynamic": _validate_dynamic_scaling_rope_parameters,
539
+ "yarn": _validate_yarn_parameters,
540
+ "longrope": _validate_longrope_parameters,
541
+ "llama3": _validate_llama3_parameters,
542
+ }
543
+
544
+
545
+ def rope_config_validation(config: PretrainedConfig):
546
+ """
547
+ Validate the RoPE config arguments, given a `PretrainedConfig` object
548
+ """
549
+ rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
550
+ if rope_scaling is None:
551
+ return
552
+
553
+ # BC: "rope_type" was originally "type"
554
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
555
+ validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
556
+ if validation_fn is not None:
557
+ validation_fn(config)
558
+ else:
559
+ logger.warning(
560
+ f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
561
+ )
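
All of the RoPE helpers in modeling_rope_utils.py ultimately return a set of inverse frequencies plus an attention scaling factor. A small standalone sketch of the default and linear-scaling variants (the values of `base`, `dim` and `factor` below are examples, not this model's configuration):

import torch

base, dim, factor = 10000.0, 128, 2.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
inv_freq_linear = inv_freq / factor  # linear scaling simply divides the frequencies

# Dividing inv_freq by `factor` is equivalent to stretching the position ids by `factor`,
# because the rotary angles are position_id * inv_freq.
positions = torch.arange(8).float()
assert torch.allclose(torch.outer(positions, inv_freq_linear),
                      torch.outer(positions / factor, inv_freq))

This is the same equivalence noted in the comment inside `_compute_linear_scaling_rope_parameters`.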
processing_illume.py ADDED
@@ -0,0 +1,329 @@
1
+ """
2
+ Processor class for ILLUME_plus with dualvitok and dualvitok-sdxl-decoder.
3
+ """
4
+
5
+ import json
6
+ from typing import List, Union
7
+
8
+ from transformers import AutoProcessor, AutoImageProcessor
9
+
10
+ try:
11
+ from typing import Unpack
12
+ except ImportError:
13
+ from typing_extensions import Unpack
14
+
15
+ from transformers.feature_extraction_utils import BatchFeature
16
+ from .image_utils import ImageInput
17
+ from transformers.processing_utils import (
18
+ ProcessingKwargs,
19
+ ProcessorMixin,
20
+ )
21
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
22
+ from transformers.utils import logging
23
+
24
+ from PIL import Image
25
+ import re # Added for parsing image tokens
26
+ from typing import List, Tuple
27
+ import torch
28
+
29
+ from .configuration_illume import ILLUMEConfig
30
+ from .image_processing_movqgan import MoVQImageProcessor
31
+ from .image_processing_dualvitok import DualViTokImageProcessor
32
+ from .aspect_ratio_utils import AspectRatioCrop, RATIOS, unpad_and_resize_back
33
+ from .inference_utils import parse_interleaved_text_image, calculate_image_token_num
34
+ from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ class ILLUMEProcessorKwargs(ProcessingKwargs, total=False):
40
+ _defaults = {
41
+ "text_kwargs": {
42
+ "padding": False,
43
+ },
44
+ }
45
+
46
+
47
+ class ILLUMEProcessor(ProcessorMixin):
48
+ r"""
49
+ Constructs an ILLUME processor which wraps an ILLUME image processor and a Qwen2 tokenizer into a single processor.
50
+ [`ILLUMEProcessor`] offers all the functionalities of [`ILLUMEImageProcessor`] and [`Qwen2TokenizerFast`]. See the
51
+ [`~ILLUMEProcessor.__call__`] and [`~ILLUMEProcessor.decode`] for more information.
52
+ Args:
53
+ image_processor ([`ILLUMEImageProcessor`], *optional*):
54
+ The image processor is a required input.
55
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
56
+ The tokenizer is a required input.
57
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
58
+ in a chat into a tokenizable string.
59
+ """
60
+
61
+ attributes = ["image_processor", "tokenizer"]
62
+ valid_kwargs = ["chat_template"]
63
+ image_processor_class = "AutoImageProcessor"
64
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
65
+ _re_placeholder = re.compile(r"<image_out>|<image>")
66
+
67
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None,
68
+ crop_percent_thresh=0.2, anyres_indicator_base=64, **kwargs):
69
+ super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
70
+ self.vision_tokenizer = None
71
+ self.diffusion_vision_detokenizer = None
72
+ self.crop_percent_thresh = crop_percent_thresh
73
+ self.anyres_indicator_base = anyres_indicator_base
74
+
75
+ def set_vision_tokenizer(self, tokenizer):
76
+ if self.vision_tokenizer and tokenizer:
77
+ logger.info('Vision tokenizer is already set; keeping the existing one.')
78
+ return
79
+ self.vision_tokenizer = tokenizer
80
+ logger.info('Setting vision tokenizer!')
81
+
82
+ def load_diffusion_vision_detokenizer(self, diffusion_decoder,
83
+ torch_dtype=torch.float16,
84
+ add_watermarker=False,
85
+ device='cuda',
86
+ ):
87
+ if self.diffusion_vision_detokenizer:
88
+ logger.info('Diffusion vision detokenizer is already set; keeping the existing one.')
89
+ return
90
+
91
+ if self.vision_tokenizer is None:
92
+ raise ValueError("Vision tokenizer is not set. Please set the vision tokenizer by using `processor.set_vision_tokenizer`")
93
+
94
+ self.diffusion_vision_detokenizer = StableDiffusionXLDecoderPipeline.from_pretrained(diffusion_decoder,
95
+ torch_dtype=torch_dtype,
96
+ add_watermarker=add_watermarker,
97
+ vq_model=self.vision_tokenizer,
98
+ vq_image_processor=self.image_processor).to(device)
99
+ logger.info('Setting diffusion vision detokenizer!')
100
+
101
+ def get_ratio_tag_from_ratio(self, ratio):
102
+ h, w = ratio
103
+ h_indicator, w_indicator = h // self.anyres_indicator_base, w // self.anyres_indicator_base
104
+ ratio_tag = f"<height_{h_indicator}><width_{w_indicator}>"
105
+ return ratio_tag
106
+
107
+ @torch.no_grad()
108
+ def _encode_with_dualvitok(self, img):
109
+ # img is a PIL image or np.ndarray
110
+ px = self.image_processor(img, return_tensors="pt")["pixel_values"].to(self.vision_tokenizer.device)
111
+ (_, _, idx_sem, _), (_, _, idx_pix) = self.vision_tokenizer.encode(px)
112
+ return idx_sem[0].cpu().tolist(), idx_pix[0].cpu().tolist()
113
+
114
+ def transform_image_nearest_resolution_ratio(self, image, ratios=RATIOS):
115
+ arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
116
+ image, original_size, target_size, flag_matched = arc(image, is_inference=True)
117
+ return image
118
+
119
+ def convert_image_to_token_string(self, image, ratios=RATIOS):
120
+ arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
121
+ image, original_size, target_size, flag_matched = arc(image, is_inference=True)
122
+ ratio_tag = self.get_ratio_tag_from_ratio(target_size)
123
+
124
+ image_embed_inds = self._encode_with_dualvitok(image)
125
+ return ratio_tag + self.encode_image_token_into_code(image_embed_inds)
126
+
127
+ def unpad_and_resize_back(self, padded_image, original_width, original_height):
128
+ return unpad_and_resize_back(padded_image, original_width, original_height)
129
+
130
+ def encode_image_token_into_code(self, image_embed_inds,
131
+ add_token_name="<|image_level{}_{}|>",
132
+ selected_vision_tokenizer_levels=None):
133
+ '''
134
+ Args:
135
+ image_embed_inds: 3D list, vision token ids for each tokenizer level
136
+ add_token_name: tag name for vision tokens
137
+ Returns:
138
+ image_token_return: str
139
+ '''
140
+
141
+ if selected_vision_tokenizer_levels is not None:
142
+ image_embed_inds_new = []
143
+ for level in selected_vision_tokenizer_levels:
144
+ image_embed_inds_new.append(image_embed_inds[level])
145
+ image_embed_inds = image_embed_inds_new
146
+
147
+ image_token_name_list = []
148
+ for level, image_embed_ind in enumerate(image_embed_inds):
149
+ image_token_name = []
150
+ for row in image_embed_ind:
151
+ image_token_name.append([add_token_name.format(level, ind) for ind in row])
152
+
153
+ image_token_name_list.append("<start_of_level{}>".format(level))
154
+ for row in image_token_name:
155
+ row.append("<end_of_line>")
156
+
157
+ for row in image_token_name:
158
+ image_token_name_list.extend(row)
159
+
160
+ image_token_name_list.append("<end_of_level{}>".format(level))
161
+
162
+ image_token_return = "".join(image_token_name_list)
163
+ image_token_return = "<start_of_image>" + image_token_return + "<end_of_image>"
164
+ return image_token_return
165
+
166
+ @torch.no_grad()
167
+ def decode_images(self, image_inds_list, target_resolution=(512, 512), return_type='pil',
168
+ use_diffusion=False, diffusion_cfg_scale=2.0, diffusion_num_inference_steps=20, **kwargs):
169
+
170
+ token_nums, _, h1, w1, h2, w2 = calculate_image_token_num(*target_resolution)
171
+
172
+ decoded_images = []
173
+ for image_inds in image_inds_list:
174
+ semantic_code = torch.as_tensor([image_inds[0]])
175
+ texture_code = torch.as_tensor([image_inds[1]])
176
+ if use_diffusion:
177
+ if self.diffusion_vision_detokenizer is None:
178
+ raise RuntimeError(
179
+ "diffusion_vision_detokenizer is not set. Please set the diffusion decoder by using `pipe.load_diffusion_vision_detokenizer`")
180
+
181
+ semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
182
+ texture_code = texture_code.view(texture_code.shape[0], h2, w2)
183
+
184
+ diffusion_outputs = self.diffusion_vision_detokenizer(
185
+ vq_indices=(semantic_code, texture_code),
186
+ height=target_resolution[0] * 2,
187
+ width=target_resolution[1] * 2,
188
+ guidance_scale=diffusion_cfg_scale,
189
+ num_inference_steps=diffusion_num_inference_steps,
190
+ output_type=return_type,
191
+ **kwargs
192
+ )
193
+ samples = diffusion_outputs.images
194
+ image = samples[0]
195
+ else:
196
+ if self.vision_tokenizer is None:
197
+ raise RuntimeError(
198
+ "vision_detokenizer is not set. Please set the vision decoder by using `pipe.set_vision_detokenizer`")
199
+
200
+ semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
201
+ texture_code = texture_code.view(texture_code.shape[0], h2, w2)
202
+
203
+ samples = self.vision_tokenizer.decode_code(semantic_code, texture_code)
204
+
205
+ if return_type == 'pil':
206
+ sample = \
207
+ torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).cpu().to(torch.uint8).numpy()[0]
208
+ image = Image.fromarray(sample)
209
+ else: # return numpy range -1 to 1.
210
+ image = samples.permute(0, 2, 3, 1).cpu().numpy()[0]
211
+ decoded_images.append(image)
212
+
213
+ return decoded_images
214
+
215
+ def parse_text_image(self, text, image_placeholder='<image_out>'):
216
+ generated_text, image_embed_inds_list, list_image_token_parts = parse_interleaved_text_image(text, num_levels=2,
217
+ image_placeholder=image_placeholder)
218
+ return generated_text, image_embed_inds_list, list_image_token_parts
219
+
220
+ def _encode_out_placeholder(self, img):
221
+ """
222
+ Encode one image with DualViTok and return a string
223
+ that can replace the <image_out> marker in the text.
224
+ """
225
+ return self.convert_image_to_token_string(img)
226
+
227
+ def __call__(
228
+ self,
229
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
230
+ images: ImageInput = None,
231
+ **kwargs: Unpack[ILLUMEProcessorKwargs],
232
+ ) -> BatchFeature:
233
+ """
234
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
235
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
236
+ the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
237
+ DualViTokImageProcessor's [`~DualViTokImageProcessor.__call__`] if `vision_infos` is not `None`.
238
+
239
+ Args:
240
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
241
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
242
+ tensor. Both channels-first and channels-last formats are supported.
243
+ text (`str`, `List[str]`, `List[List[str]]`):
244
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
245
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
246
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
247
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
248
+ If set, will return tensors of a particular framework. Acceptable values are:
249
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
250
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
251
+ - `'np'`: Return NumPy `np.ndarray` objects.
252
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
253
+
254
+ Returns:
255
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
256
+
257
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
258
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
259
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
260
+ `None`).
261
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
262
+ """
+ output_kwargs = self._merge_kwargs(
+ ILLUMEProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if not isinstance(text, list):
+ text = [text]
+
+ if isinstance(images, str):
+ images = [images]
+ elif images and isinstance(images[0], list):
+ # flatten List[List[PIL.Image.Image]]
+ images = [item for sublist in images for item in sublist]
+
+ _ = output_kwargs["text_kwargs"].pop("padding_side", None)
+ try:
+ text = self.apply_chat_template(text, add_generation_prompt=True, padding=True)
+ except Exception as e:
+ logger.info('apply_chat_template failed; assuming the input texts already contain chat templates.')
+
+ if images is None:
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ return BatchFeature(data={**text_inputs})
+ else:
+ imgs_in, new_text, used = [], [], 0
+
+ if not isinstance(text, list):
+ text = [text]
+
+ for s in text: # walk each prompt
+ out, i = [], 0
+ for m in self._re_placeholder.finditer(s): # every placeholder
+ out.append(s[i:m.start()])
+ if used >= len(images):
+ raise ValueError("not enough images for placeholders")
+ img = images[used]
+ used += 1
+ if m.group() == "<image_out>":
+ out.append(self.convert_image_to_token_string(img)) # replace <image_out> with its discrete token string
+ else: # <image>
+ out.append("<image>")
+ imgs_in.append(img) # keep for pixel features
+ i = m.end()
+ out.append(s[i:])
+ new_text.append("".join(out))
+
+ if used != len(images):
+ raise ValueError(f"too many images for placeholders: used {used} of {len(images)}. Text: {text}")
+
+ text_inputs = self.tokenizer(new_text, **output_kwargs["text_kwargs"])
+ image_inputs = self.image_processor.preprocess(imgs_in, **output_kwargs["images_kwargs"]) if imgs_in else {}
+
+ return BatchFeature(data={**text_inputs, **image_inputs})
+
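A minimal usage sketch of the call path above (not shipped in this repo; the checkpoint path and the AutoProcessor/trust_remote_code loading are assumptions): each `<image>` or `<image_out>` placeholder consumes one entry of `images` in order, with `<image>` routed to the image processor and `<image_out>` rewritten in place as discrete vision-tokenizer tokens.

from PIL import Image
from transformers import AutoProcessor

# Hypothetical load; assumes the custom processor class is registered via trust_remote_code.
processor = AutoProcessor.from_pretrained("path/to/this/repo", trust_remote_code=True)

prompt = "Describe <image>, then redraw it in watercolor style as <image_out>."
img = Image.open("example.jpg")

# Two placeholders -> two images (here the same one twice).
batch = processor(text=prompt, images=[img, img], return_tensors="pt")
print(batch.keys())  # expected: input_ids, attention_mask, pixel_values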
+ def batch_decode(self, sequences, *args, **kwargs):
+ return [self.decode(seq, *args, **kwargs)
+ for seq in sequences]
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
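A sketch of the post-processing side (an assumed workflow, not code from this file): after generation, the ids are decoded back to text and `parse_text_image` splits the interleaved output into plain text and per-image token indices. `generate_ids` and the downstream image-reconstruction step are placeholders here.

# generate_ids: output of model.generate(...), placeholder name.
text = processor.decode(generate_ids[0], skip_special_tokens=False)
clean_text, image_embed_inds_list, image_token_parts = processor.parse_text_image(
    text, image_placeholder='<image_out>')
# Each entry of image_embed_inds_list would then be fed to the vision tokenizer's
# decoder (see the image-decoding helper earlier in this file) to reconstruct images.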
pytorch_model-00001-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7c85c4110b8939d12d79ac65e4d0985cec132c92ef481e36b3923e5e090ba21
+ size 4970246339
pytorch_model-00002-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:813720e5102d637c585854db25e80984d03fb73490e72795247d917933d8f821
+ size 4932780328
pytorch_model-00003-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab7b0e070fc80137d7d9e43951d89a4ea82c961fb0a3dffeba4e73be033d07e1
+ size 4991527403
pytorch_model-00004-of-00004.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:89e7232078f06cb9df7470b148d01e6de8826c936d49ed56401c7e91cddea297
+ size 3699762846
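The four shards above are Git LFS pointer files (only the object hash and byte size live in the repo), and pytorch_model.bin.index.json below maps every parameter name to the shard that stores it. A minimal sketch of how such an index can be used to pull a single tensor without loading every shard; the loading pattern is a generic assumption, not code shipped in this repo.

import json
import torch

# Read the shard index uploaded below.
with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

name = "language_model.lm_head.weight"
shard_file = index["weight_map"][name]              # e.g. "pytorch_model-00004-of-00004.bin"
state = torch.load(shard_file, map_location="cpu")  # load only the shard that holds it
tensor = state[name]
print(tensor.shape)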
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,889 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 18594005440
4
+ },
5
+ "weight_map": {
6
+ "language_model.lm_head.weight": "pytorch_model-00004-of-00004.bin",
7
+ "language_model.model.embed_tokens.weight": "pytorch_model-00001-of-00004.bin",
8
+ "language_model.model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
9
+ "language_model.model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00004.bin",
10
+ "language_model.model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00004.bin",
11
+ "language_model.model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00004.bin",
12
+ "language_model.model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
13
+ "language_model.model.layers.0.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
14
+ "language_model.model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
15
+ "language_model.model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
16
+ "language_model.model.layers.0.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
17
+ "language_model.model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
18
+ "language_model.model.layers.0.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
19
+ "language_model.model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
20
+ "language_model.model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
21
+ "language_model.model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00004.bin",
22
+ "language_model.model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00004.bin",
23
+ "language_model.model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00004.bin",
24
+ "language_model.model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
25
+ "language_model.model.layers.1.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
26
+ "language_model.model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
27
+ "language_model.model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
28
+ "language_model.model.layers.1.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
29
+ "language_model.model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
30
+ "language_model.model.layers.1.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
31
+ "language_model.model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
32
+ "language_model.model.layers.10.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
33
+ "language_model.model.layers.10.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
34
+ "language_model.model.layers.10.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
35
+ "language_model.model.layers.10.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
36
+ "language_model.model.layers.10.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
37
+ "language_model.model.layers.10.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
38
+ "language_model.model.layers.10.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
39
+ "language_model.model.layers.10.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
40
+ "language_model.model.layers.10.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
41
+ "language_model.model.layers.10.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
42
+ "language_model.model.layers.10.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
43
+ "language_model.model.layers.10.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
44
+ "language_model.model.layers.11.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
45
+ "language_model.model.layers.11.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
46
+ "language_model.model.layers.11.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
47
+ "language_model.model.layers.11.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
48
+ "language_model.model.layers.11.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
49
+ "language_model.model.layers.11.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
50
+ "language_model.model.layers.11.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
51
+ "language_model.model.layers.11.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
52
+ "language_model.model.layers.11.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
53
+ "language_model.model.layers.11.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
54
+ "language_model.model.layers.11.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
55
+ "language_model.model.layers.11.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
56
+ "language_model.model.layers.12.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
57
+ "language_model.model.layers.12.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
58
+ "language_model.model.layers.12.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
59
+ "language_model.model.layers.12.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
60
+ "language_model.model.layers.12.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
61
+ "language_model.model.layers.12.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
62
+ "language_model.model.layers.12.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
63
+ "language_model.model.layers.12.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
64
+ "language_model.model.layers.12.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
65
+ "language_model.model.layers.12.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
66
+ "language_model.model.layers.12.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
67
+ "language_model.model.layers.12.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
68
+ "language_model.model.layers.13.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
69
+ "language_model.model.layers.13.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
70
+ "language_model.model.layers.13.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
71
+ "language_model.model.layers.13.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
72
+ "language_model.model.layers.13.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
73
+ "language_model.model.layers.13.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
74
+ "language_model.model.layers.13.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
75
+ "language_model.model.layers.13.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
76
+ "language_model.model.layers.13.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
77
+ "language_model.model.layers.13.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
78
+ "language_model.model.layers.13.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
79
+ "language_model.model.layers.13.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
80
+ "language_model.model.layers.14.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
81
+ "language_model.model.layers.14.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
82
+ "language_model.model.layers.14.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
83
+ "language_model.model.layers.14.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
84
+ "language_model.model.layers.14.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
85
+ "language_model.model.layers.14.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
86
+ "language_model.model.layers.14.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
87
+ "language_model.model.layers.14.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
88
+ "language_model.model.layers.14.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
89
+ "language_model.model.layers.14.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
90
+ "language_model.model.layers.14.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
91
+ "language_model.model.layers.14.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
92
+ "language_model.model.layers.15.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
93
+ "language_model.model.layers.15.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
94
+ "language_model.model.layers.15.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
95
+ "language_model.model.layers.15.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
96
+ "language_model.model.layers.15.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
97
+ "language_model.model.layers.15.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
98
+ "language_model.model.layers.15.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
99
+ "language_model.model.layers.15.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
100
+ "language_model.model.layers.15.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
101
+ "language_model.model.layers.15.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
102
+ "language_model.model.layers.15.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
103
+ "language_model.model.layers.15.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
104
+ "language_model.model.layers.16.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
105
+ "language_model.model.layers.16.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
106
+ "language_model.model.layers.16.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
107
+ "language_model.model.layers.16.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
108
+ "language_model.model.layers.16.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
109
+ "language_model.model.layers.16.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
110
+ "language_model.model.layers.16.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
111
+ "language_model.model.layers.16.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
112
+ "language_model.model.layers.16.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
113
+ "language_model.model.layers.16.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
114
+ "language_model.model.layers.16.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
115
+ "language_model.model.layers.16.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
116
+ "language_model.model.layers.17.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
117
+ "language_model.model.layers.17.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
118
+ "language_model.model.layers.17.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
119
+ "language_model.model.layers.17.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
120
+ "language_model.model.layers.17.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
121
+ "language_model.model.layers.17.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
122
+ "language_model.model.layers.17.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
123
+ "language_model.model.layers.17.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
124
+ "language_model.model.layers.17.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
125
+ "language_model.model.layers.17.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
126
+ "language_model.model.layers.17.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
127
+ "language_model.model.layers.17.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
128
+ "language_model.model.layers.18.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
129
+ "language_model.model.layers.18.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
130
+ "language_model.model.layers.18.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
131
+ "language_model.model.layers.18.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
132
+ "language_model.model.layers.18.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
133
+ "language_model.model.layers.18.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
134
+ "language_model.model.layers.18.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
135
+ "language_model.model.layers.18.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
136
+ "language_model.model.layers.18.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
137
+ "language_model.model.layers.18.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
138
+ "language_model.model.layers.18.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
139
+ "language_model.model.layers.18.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
140
+ "language_model.model.layers.19.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
141
+ "language_model.model.layers.19.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
142
+ "language_model.model.layers.19.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
143
+ "language_model.model.layers.19.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
144
+ "language_model.model.layers.19.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
145
+ "language_model.model.layers.19.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
146
+ "language_model.model.layers.19.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
147
+ "language_model.model.layers.19.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
148
+ "language_model.model.layers.19.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
149
+ "language_model.model.layers.19.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
150
+ "language_model.model.layers.19.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
151
+ "language_model.model.layers.19.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
152
+ "language_model.model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
153
+ "language_model.model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00004.bin",
154
+ "language_model.model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00004.bin",
155
+ "language_model.model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00004.bin",
156
+ "language_model.model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
157
+ "language_model.model.layers.2.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
158
+ "language_model.model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
159
+ "language_model.model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
160
+ "language_model.model.layers.2.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
161
+ "language_model.model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
162
+ "language_model.model.layers.2.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
163
+ "language_model.model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
164
+ "language_model.model.layers.20.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
165
+ "language_model.model.layers.20.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
166
+ "language_model.model.layers.20.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
167
+ "language_model.model.layers.20.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
168
+ "language_model.model.layers.20.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
169
+ "language_model.model.layers.20.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
170
+ "language_model.model.layers.20.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
171
+ "language_model.model.layers.20.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
172
+ "language_model.model.layers.20.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
173
+ "language_model.model.layers.20.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
174
+ "language_model.model.layers.20.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
175
+ "language_model.model.layers.20.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
176
+ "language_model.model.layers.21.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
177
+ "language_model.model.layers.21.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
178
+ "language_model.model.layers.21.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
179
+ "language_model.model.layers.21.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
180
+ "language_model.model.layers.21.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
181
+ "language_model.model.layers.21.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
182
+ "language_model.model.layers.21.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
183
+ "language_model.model.layers.21.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
184
+ "language_model.model.layers.21.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
185
+ "language_model.model.layers.21.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
186
+ "language_model.model.layers.21.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
187
+ "language_model.model.layers.21.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
188
+ "language_model.model.layers.22.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
189
+ "language_model.model.layers.22.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
190
+ "language_model.model.layers.22.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
191
+ "language_model.model.layers.22.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
192
+ "language_model.model.layers.22.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
193
+ "language_model.model.layers.22.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
194
+ "language_model.model.layers.22.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
195
+ "language_model.model.layers.22.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
196
+ "language_model.model.layers.22.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
197
+ "language_model.model.layers.22.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
198
+ "language_model.model.layers.22.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
199
+ "language_model.model.layers.22.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
200
+ "language_model.model.layers.23.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
201
+ "language_model.model.layers.23.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
202
+ "language_model.model.layers.23.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
203
+ "language_model.model.layers.23.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
204
+ "language_model.model.layers.23.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
205
+ "language_model.model.layers.23.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
206
+ "language_model.model.layers.23.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
207
+ "language_model.model.layers.23.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
208
+ "language_model.model.layers.23.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
209
+ "language_model.model.layers.23.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
210
+ "language_model.model.layers.23.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
211
+ "language_model.model.layers.23.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
212
+ "language_model.model.layers.24.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
213
+ "language_model.model.layers.24.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
214
+ "language_model.model.layers.24.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
215
+ "language_model.model.layers.24.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
216
+ "language_model.model.layers.24.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
217
+ "language_model.model.layers.24.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
218
+ "language_model.model.layers.24.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
219
+ "language_model.model.layers.24.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
220
+ "language_model.model.layers.24.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
221
+ "language_model.model.layers.24.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
222
+ "language_model.model.layers.24.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
223
+ "language_model.model.layers.24.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
224
+ "language_model.model.layers.25.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
225
+ "language_model.model.layers.25.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
226
+ "language_model.model.layers.25.mlp.gate_proj.weight": "pytorch_model-00004-of-00004.bin",
227
+ "language_model.model.layers.25.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
228
+ "language_model.model.layers.25.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
229
+ "language_model.model.layers.25.self_attn.k_proj.bias": "pytorch_model-00004-of-00004.bin",
230
+ "language_model.model.layers.25.self_attn.k_proj.weight": "pytorch_model-00004-of-00004.bin",
231
+ "language_model.model.layers.25.self_attn.o_proj.weight": "pytorch_model-00004-of-00004.bin",
232
+ "language_model.model.layers.25.self_attn.q_proj.bias": "pytorch_model-00004-of-00004.bin",
233
+ "language_model.model.layers.25.self_attn.q_proj.weight": "pytorch_model-00004-of-00004.bin",
234
+ "language_model.model.layers.25.self_attn.v_proj.bias": "pytorch_model-00004-of-00004.bin",
235
+ "language_model.model.layers.25.self_attn.v_proj.weight": "pytorch_model-00004-of-00004.bin",
236
+ "language_model.model.layers.26.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
237
+ "language_model.model.layers.26.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
238
+ "language_model.model.layers.26.mlp.gate_proj.weight": "pytorch_model-00004-of-00004.bin",
239
+ "language_model.model.layers.26.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
240
+ "language_model.model.layers.26.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
241
+ "language_model.model.layers.26.self_attn.k_proj.bias": "pytorch_model-00004-of-00004.bin",
242
+ "language_model.model.layers.26.self_attn.k_proj.weight": "pytorch_model-00004-of-00004.bin",
243
+ "language_model.model.layers.26.self_attn.o_proj.weight": "pytorch_model-00004-of-00004.bin",
244
+ "language_model.model.layers.26.self_attn.q_proj.bias": "pytorch_model-00004-of-00004.bin",
245
+ "language_model.model.layers.26.self_attn.q_proj.weight": "pytorch_model-00004-of-00004.bin",
246
+ "language_model.model.layers.26.self_attn.v_proj.bias": "pytorch_model-00004-of-00004.bin",
247
+ "language_model.model.layers.26.self_attn.v_proj.weight": "pytorch_model-00004-of-00004.bin",
248
+ "language_model.model.layers.27.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
249
+ "language_model.model.layers.27.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
250
+ "language_model.model.layers.27.mlp.gate_proj.weight": "pytorch_model-00004-of-00004.bin",
251
+ "language_model.model.layers.27.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
252
+ "language_model.model.layers.27.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
253
+ "language_model.model.layers.27.self_attn.k_proj.bias": "pytorch_model-00004-of-00004.bin",
254
+ "language_model.model.layers.27.self_attn.k_proj.weight": "pytorch_model-00004-of-00004.bin",
255
+ "language_model.model.layers.27.self_attn.o_proj.weight": "pytorch_model-00004-of-00004.bin",
256
+ "language_model.model.layers.27.self_attn.q_proj.bias": "pytorch_model-00004-of-00004.bin",
257
+ "language_model.model.layers.27.self_attn.q_proj.weight": "pytorch_model-00004-of-00004.bin",
258
+ "language_model.model.layers.27.self_attn.v_proj.bias": "pytorch_model-00004-of-00004.bin",
259
+ "language_model.model.layers.27.self_attn.v_proj.weight": "pytorch_model-00004-of-00004.bin",
260
+ "language_model.model.layers.3.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
261
+ "language_model.model.layers.3.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
262
+ "language_model.model.layers.3.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
263
+ "language_model.model.layers.3.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
264
+ "language_model.model.layers.3.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
265
+ "language_model.model.layers.3.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
266
+ "language_model.model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
267
+ "language_model.model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
268
+ "language_model.model.layers.3.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
269
+ "language_model.model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
270
+ "language_model.model.layers.3.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
271
+ "language_model.model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
272
+ "language_model.model.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
273
+ "language_model.model.layers.4.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
274
+ "language_model.model.layers.4.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
275
+ "language_model.model.layers.4.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
276
+ "language_model.model.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
277
+ "language_model.model.layers.4.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
278
+ "language_model.model.layers.4.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
279
+ "language_model.model.layers.4.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
280
+ "language_model.model.layers.4.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
281
+ "language_model.model.layers.4.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
282
+ "language_model.model.layers.4.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
283
+ "language_model.model.layers.4.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
284
+ "language_model.model.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
285
+ "language_model.model.layers.5.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
286
+ "language_model.model.layers.5.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
287
+ "language_model.model.layers.5.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
288
+ "language_model.model.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
289
+ "language_model.model.layers.5.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
290
+ "language_model.model.layers.5.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
291
+ "language_model.model.layers.5.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
292
+ "language_model.model.layers.5.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
293
+ "language_model.model.layers.5.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
294
+ "language_model.model.layers.5.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
295
+ "language_model.model.layers.5.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
296
+ "language_model.model.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
297
+ "language_model.model.layers.6.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
298
+ "language_model.model.layers.6.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
299
+ "language_model.model.layers.6.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
300
+ "language_model.model.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
301
+ "language_model.model.layers.6.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
302
+ "language_model.model.layers.6.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
303
+ "language_model.model.layers.6.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
304
+ "language_model.model.layers.6.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
305
+ "language_model.model.layers.6.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
306
+ "language_model.model.layers.6.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
307
+ "language_model.model.layers.6.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
308
+ "language_model.model.layers.7.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
309
+ "language_model.model.layers.7.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
310
+ "language_model.model.layers.7.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
311
+ "language_model.model.layers.7.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
312
+ "language_model.model.layers.7.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
313
+ "language_model.model.layers.7.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
314
+ "language_model.model.layers.7.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
315
+ "language_model.model.layers.7.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
316
+ "language_model.model.layers.7.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
317
+ "language_model.model.layers.7.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
318
+ "language_model.model.layers.7.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
319
+ "language_model.model.layers.7.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
320
+ "language_model.model.layers.8.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
321
+ "language_model.model.layers.8.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
322
+ "language_model.model.layers.8.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
323
+ "language_model.model.layers.8.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
324
+ "language_model.model.layers.8.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
325
+ "language_model.model.layers.8.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
326
+ "language_model.model.layers.8.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
327
+ "language_model.model.layers.8.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
328
+ "language_model.model.layers.8.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
329
+ "language_model.model.layers.8.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
330
+ "language_model.model.layers.8.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
331
+ "language_model.model.layers.8.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
332
+ "language_model.model.layers.9.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
333
+ "language_model.model.layers.9.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
334
+ "language_model.model.layers.9.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
335
+ "language_model.model.layers.9.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
336
+ "language_model.model.layers.9.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
337
+ "language_model.model.layers.9.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
338
+ "language_model.model.layers.9.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
339
+ "language_model.model.layers.9.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
340
+ "language_model.model.layers.9.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
341
+ "language_model.model.layers.9.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
342
+ "language_model.model.layers.9.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
343
+ "language_model.model.layers.9.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
344
+ "language_model.model.norm.weight": "pytorch_model-00004-of-00004.bin",
345
+ "mm_projector.projector_1.0.bias": "pytorch_model-00001-of-00004.bin",
346
+ "mm_projector.projector_1.0.weight": "pytorch_model-00001-of-00004.bin",
347
+ "mm_projector.projector_1.2.bias": "pytorch_model-00001-of-00004.bin",
348
+ "mm_projector.projector_1.2.weight": "pytorch_model-00001-of-00004.bin",
349
+ "mm_projector.projector_2.0.bias": "pytorch_model-00001-of-00004.bin",
350
+ "mm_projector.projector_2.0.weight": "pytorch_model-00001-of-00004.bin",
351
+ "mm_projector.projector_2.2.bias": "pytorch_model-00001-of-00004.bin",
352
+ "mm_projector.projector_2.2.weight": "pytorch_model-00001-of-00004.bin",
353
+ "vision_tower.pixel_encoder.conv_in.bias": "pytorch_model-00001-of-00004.bin",
354
+ "vision_tower.pixel_encoder.conv_in.weight": "pytorch_model-00001-of-00004.bin",
355
+ "vision_tower.pixel_encoder.conv_out.bias": "pytorch_model-00001-of-00004.bin",
356
+ "vision_tower.pixel_encoder.conv_out.weight": "pytorch_model-00001-of-00004.bin",
357
+ "vision_tower.pixel_encoder.down.0.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
358
+ "vision_tower.pixel_encoder.down.0.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
359
+ "vision_tower.pixel_encoder.down.0.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
360
+ "vision_tower.pixel_encoder.down.0.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
361
+ "vision_tower.pixel_encoder.down.0.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
362
+ "vision_tower.pixel_encoder.down.0.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
363
+ "vision_tower.pixel_encoder.down.0.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
364
+ "vision_tower.pixel_encoder.down.0.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
365
+ "vision_tower.pixel_encoder.down.0.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
366
+ "vision_tower.pixel_encoder.down.0.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
367
+ "vision_tower.pixel_encoder.down.0.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
368
+ "vision_tower.pixel_encoder.down.0.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
369
+ "vision_tower.pixel_encoder.down.0.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
370
+ "vision_tower.pixel_encoder.down.0.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
371
+ "vision_tower.pixel_encoder.down.0.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
372
+ "vision_tower.pixel_encoder.down.0.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
373
+ "vision_tower.pixel_encoder.down.0.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
374
+ "vision_tower.pixel_encoder.down.0.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
375
+ "vision_tower.pixel_encoder.down.1.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
376
+ "vision_tower.pixel_encoder.down.1.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
377
+ "vision_tower.pixel_encoder.down.1.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
378
+ "vision_tower.pixel_encoder.down.1.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
379
+ "vision_tower.pixel_encoder.down.1.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
380
+ "vision_tower.pixel_encoder.down.1.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
381
+ "vision_tower.pixel_encoder.down.1.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
382
+ "vision_tower.pixel_encoder.down.1.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
383
+ "vision_tower.pixel_encoder.down.1.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
384
+ "vision_tower.pixel_encoder.down.1.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
385
+ "vision_tower.pixel_encoder.down.1.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
386
+ "vision_tower.pixel_encoder.down.1.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
387
+ "vision_tower.pixel_encoder.down.1.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
388
+ "vision_tower.pixel_encoder.down.1.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
389
+ "vision_tower.pixel_encoder.down.1.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
390
+ "vision_tower.pixel_encoder.down.1.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
391
+ "vision_tower.pixel_encoder.down.1.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
392
+ "vision_tower.pixel_encoder.down.1.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
393
+ "vision_tower.pixel_encoder.down.2.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
394
+ "vision_tower.pixel_encoder.down.2.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
395
+ "vision_tower.pixel_encoder.down.2.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
396
+ "vision_tower.pixel_encoder.down.2.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
397
+ "vision_tower.pixel_encoder.down.2.block.0.nin_shortcut.bias": "pytorch_model-00001-of-00004.bin",
398
+ "vision_tower.pixel_encoder.down.2.block.0.nin_shortcut.weight": "pytorch_model-00001-of-00004.bin",
399
+ "vision_tower.pixel_encoder.down.2.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
400
+ "vision_tower.pixel_encoder.down.2.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
401
+ "vision_tower.pixel_encoder.down.2.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
402
+ "vision_tower.pixel_encoder.down.2.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
403
+ "vision_tower.pixel_encoder.down.2.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
404
+ "vision_tower.pixel_encoder.down.2.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
405
+ "vision_tower.pixel_encoder.down.2.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
406
+ "vision_tower.pixel_encoder.down.2.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
407
+ "vision_tower.pixel_encoder.down.2.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
408
+ "vision_tower.pixel_encoder.down.2.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
409
+ "vision_tower.pixel_encoder.down.2.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
410
+ "vision_tower.pixel_encoder.down.2.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
411
+ "vision_tower.pixel_encoder.down.2.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
412
+ "vision_tower.pixel_encoder.down.2.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
413
+ "vision_tower.pixel_encoder.down.3.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
414
+ "vision_tower.pixel_encoder.down.3.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
415
+ "vision_tower.pixel_encoder.down.3.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
416
+ "vision_tower.pixel_encoder.down.3.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
417
+ "vision_tower.pixel_encoder.down.3.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
418
+ "vision_tower.pixel_encoder.down.3.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
419
+ "vision_tower.pixel_encoder.down.3.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
420
+ "vision_tower.pixel_encoder.down.3.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
421
+ "vision_tower.pixel_encoder.down.3.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
422
+ "vision_tower.pixel_encoder.down.3.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
423
+ "vision_tower.pixel_encoder.down.3.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
424
+ "vision_tower.pixel_encoder.down.3.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
425
+ "vision_tower.pixel_encoder.down.3.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
426
+ "vision_tower.pixel_encoder.down.3.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
427
+ "vision_tower.pixel_encoder.down.3.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
428
+ "vision_tower.pixel_encoder.down.3.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
429
+ "vision_tower.pixel_encoder.down.3.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
430
+ "vision_tower.pixel_encoder.down.3.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
431
+ "vision_tower.pixel_encoder.down.4.attn.0.k.bias": "pytorch_model-00001-of-00004.bin",
432
+ "vision_tower.pixel_encoder.down.4.attn.0.k.weight": "pytorch_model-00001-of-00004.bin",
433
+ "vision_tower.pixel_encoder.down.4.attn.0.norm.bias": "pytorch_model-00001-of-00004.bin",
434
+ "vision_tower.pixel_encoder.down.4.attn.0.norm.weight": "pytorch_model-00001-of-00004.bin",
435
+ "vision_tower.pixel_encoder.down.4.attn.0.proj_out.bias": "pytorch_model-00001-of-00004.bin",
436
+ "vision_tower.pixel_encoder.down.4.attn.0.proj_out.weight": "pytorch_model-00001-of-00004.bin",
437
+ "vision_tower.pixel_encoder.down.4.attn.0.q.bias": "pytorch_model-00001-of-00004.bin",
438
+ "vision_tower.pixel_encoder.down.4.attn.0.q.weight": "pytorch_model-00001-of-00004.bin",
439
+ "vision_tower.pixel_encoder.down.4.attn.0.v.bias": "pytorch_model-00001-of-00004.bin",
440
+ "vision_tower.pixel_encoder.down.4.attn.0.v.weight": "pytorch_model-00001-of-00004.bin",
441
+ "vision_tower.pixel_encoder.down.4.attn.1.k.bias": "pytorch_model-00001-of-00004.bin",
442
+ "vision_tower.pixel_encoder.down.4.attn.1.k.weight": "pytorch_model-00001-of-00004.bin",
443
+ "vision_tower.pixel_encoder.down.4.attn.1.norm.bias": "pytorch_model-00001-of-00004.bin",
444
+ "vision_tower.pixel_encoder.down.4.attn.1.norm.weight": "pytorch_model-00001-of-00004.bin",
445
+ "vision_tower.pixel_encoder.down.4.attn.1.proj_out.bias": "pytorch_model-00001-of-00004.bin",
446
+ "vision_tower.pixel_encoder.down.4.attn.1.proj_out.weight": "pytorch_model-00001-of-00004.bin",
447
+ "vision_tower.pixel_encoder.down.4.attn.1.q.bias": "pytorch_model-00001-of-00004.bin",
448
+ "vision_tower.pixel_encoder.down.4.attn.1.q.weight": "pytorch_model-00001-of-00004.bin",
449
+ "vision_tower.pixel_encoder.down.4.attn.1.v.bias": "pytorch_model-00001-of-00004.bin",
450
+ "vision_tower.pixel_encoder.down.4.attn.1.v.weight": "pytorch_model-00001-of-00004.bin",
451
+ "vision_tower.pixel_encoder.down.4.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
452
+ "vision_tower.pixel_encoder.down.4.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
453
+ "vision_tower.pixel_encoder.down.4.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
454
+ "vision_tower.pixel_encoder.down.4.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
455
+ "vision_tower.pixel_encoder.down.4.block.0.nin_shortcut.bias": "pytorch_model-00001-of-00004.bin",
456
+ "vision_tower.pixel_encoder.down.4.block.0.nin_shortcut.weight": "pytorch_model-00001-of-00004.bin",
457
+ "vision_tower.pixel_encoder.down.4.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
458
+ "vision_tower.pixel_encoder.down.4.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
459
+ "vision_tower.pixel_encoder.down.4.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
460
+ "vision_tower.pixel_encoder.down.4.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
461
+ "vision_tower.pixel_encoder.down.4.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
462
+ "vision_tower.pixel_encoder.down.4.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
463
+ "vision_tower.pixel_encoder.down.4.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
464
+ "vision_tower.pixel_encoder.down.4.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
465
+ "vision_tower.pixel_encoder.down.4.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
466
+ "vision_tower.pixel_encoder.down.4.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
467
+ "vision_tower.pixel_encoder.down.4.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
468
+ "vision_tower.pixel_encoder.down.4.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
469
+ "vision_tower.pixel_encoder.mid.attn_1.k.bias": "pytorch_model-00001-of-00004.bin",
470
+ "vision_tower.pixel_encoder.mid.attn_1.k.weight": "pytorch_model-00001-of-00004.bin",
471
+ "vision_tower.pixel_encoder.mid.attn_1.norm.bias": "pytorch_model-00001-of-00004.bin",
472
+ "vision_tower.pixel_encoder.mid.attn_1.norm.weight": "pytorch_model-00001-of-00004.bin",
473
+ "vision_tower.pixel_encoder.mid.attn_1.proj_out.bias": "pytorch_model-00001-of-00004.bin",
474
+ "vision_tower.pixel_encoder.mid.attn_1.proj_out.weight": "pytorch_model-00001-of-00004.bin",
475
+ "vision_tower.pixel_encoder.mid.attn_1.q.bias": "pytorch_model-00001-of-00004.bin",
476
+ "vision_tower.pixel_encoder.mid.attn_1.q.weight": "pytorch_model-00001-of-00004.bin",
477
+ "vision_tower.pixel_encoder.mid.attn_1.v.bias": "pytorch_model-00001-of-00004.bin",
478
+ "vision_tower.pixel_encoder.mid.attn_1.v.weight": "pytorch_model-00001-of-00004.bin",
479
+ "vision_tower.pixel_encoder.mid.block_1.conv1.bias": "pytorch_model-00001-of-00004.bin",
480
+ "vision_tower.pixel_encoder.mid.block_1.conv1.weight": "pytorch_model-00001-of-00004.bin",
481
+ "vision_tower.pixel_encoder.mid.block_1.conv2.bias": "pytorch_model-00001-of-00004.bin",
482
+ "vision_tower.pixel_encoder.mid.block_1.conv2.weight": "pytorch_model-00001-of-00004.bin",
483
+ "vision_tower.pixel_encoder.mid.block_1.norm1.bias": "pytorch_model-00001-of-00004.bin",
484
+ "vision_tower.pixel_encoder.mid.block_1.norm1.weight": "pytorch_model-00001-of-00004.bin",
485
+ "vision_tower.pixel_encoder.mid.block_1.norm2.bias": "pytorch_model-00001-of-00004.bin",
486
+ "vision_tower.pixel_encoder.mid.block_1.norm2.weight": "pytorch_model-00001-of-00004.bin",
487
+ "vision_tower.pixel_encoder.mid.block_2.conv1.bias": "pytorch_model-00001-of-00004.bin",
488
+ "vision_tower.pixel_encoder.mid.block_2.conv1.weight": "pytorch_model-00001-of-00004.bin",
489
+ "vision_tower.pixel_encoder.mid.block_2.conv2.bias": "pytorch_model-00001-of-00004.bin",
490
+ "vision_tower.pixel_encoder.mid.block_2.conv2.weight": "pytorch_model-00001-of-00004.bin",
491
+ "vision_tower.pixel_encoder.mid.block_2.norm1.bias": "pytorch_model-00001-of-00004.bin",
492
+ "vision_tower.pixel_encoder.mid.block_2.norm1.weight": "pytorch_model-00001-of-00004.bin",
493
+ "vision_tower.pixel_encoder.mid.block_2.norm2.bias": "pytorch_model-00001-of-00004.bin",
494
+ "vision_tower.pixel_encoder.mid.block_2.norm2.weight": "pytorch_model-00001-of-00004.bin",
495
+ "vision_tower.pixel_encoder.norm_out.bias": "pytorch_model-00001-of-00004.bin",
496
+ "vision_tower.pixel_encoder.norm_out.weight": "pytorch_model-00001-of-00004.bin",
497
+ "vision_tower.semantic_encoder.blocks.0.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
498
+ "vision_tower.semantic_encoder.blocks.0.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
499
+ "vision_tower.semantic_encoder.blocks.0.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
500
+ "vision_tower.semantic_encoder.blocks.0.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
501
+ "vision_tower.semantic_encoder.blocks.0.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
502
+ "vision_tower.semantic_encoder.blocks.0.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
503
+ "vision_tower.semantic_encoder.blocks.0.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
504
+ "vision_tower.semantic_encoder.blocks.0.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
505
+ "vision_tower.semantic_encoder.blocks.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
506
+ "vision_tower.semantic_encoder.blocks.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
507
+ "vision_tower.semantic_encoder.blocks.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
508
+ "vision_tower.semantic_encoder.blocks.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
509
+ "vision_tower.semantic_encoder.blocks.1.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
510
+ "vision_tower.semantic_encoder.blocks.1.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
511
+ "vision_tower.semantic_encoder.blocks.1.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
512
+ "vision_tower.semantic_encoder.blocks.1.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
513
+ "vision_tower.semantic_encoder.blocks.1.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
514
+ "vision_tower.semantic_encoder.blocks.1.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
515
+ "vision_tower.semantic_encoder.blocks.1.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
516
+ "vision_tower.semantic_encoder.blocks.1.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
517
+ "vision_tower.semantic_encoder.blocks.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
518
+ "vision_tower.semantic_encoder.blocks.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
519
+ "vision_tower.semantic_encoder.blocks.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
520
+ "vision_tower.semantic_encoder.blocks.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
521
+ "vision_tower.semantic_encoder.blocks.10.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
522
+ "vision_tower.semantic_encoder.blocks.10.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
523
+ "vision_tower.semantic_encoder.blocks.10.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
524
+ "vision_tower.semantic_encoder.blocks.10.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
525
+ "vision_tower.semantic_encoder.blocks.10.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
526
+ "vision_tower.semantic_encoder.blocks.10.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
527
+ "vision_tower.semantic_encoder.blocks.10.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
528
+ "vision_tower.semantic_encoder.blocks.10.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
529
+ "vision_tower.semantic_encoder.blocks.10.norm1.bias": "pytorch_model-00001-of-00004.bin",
530
+ "vision_tower.semantic_encoder.blocks.10.norm1.weight": "pytorch_model-00001-of-00004.bin",
531
+ "vision_tower.semantic_encoder.blocks.10.norm2.bias": "pytorch_model-00001-of-00004.bin",
532
+ "vision_tower.semantic_encoder.blocks.10.norm2.weight": "pytorch_model-00001-of-00004.bin",
533
+ "vision_tower.semantic_encoder.blocks.11.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
534
+ "vision_tower.semantic_encoder.blocks.11.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
535
+ "vision_tower.semantic_encoder.blocks.11.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
536
+ "vision_tower.semantic_encoder.blocks.11.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
537
+ "vision_tower.semantic_encoder.blocks.11.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
538
+ "vision_tower.semantic_encoder.blocks.11.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
539
+ "vision_tower.semantic_encoder.blocks.11.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
540
+ "vision_tower.semantic_encoder.blocks.11.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
541
+ "vision_tower.semantic_encoder.blocks.11.norm1.bias": "pytorch_model-00001-of-00004.bin",
542
+ "vision_tower.semantic_encoder.blocks.11.norm1.weight": "pytorch_model-00001-of-00004.bin",
543
+ "vision_tower.semantic_encoder.blocks.11.norm2.bias": "pytorch_model-00001-of-00004.bin",
544
+ "vision_tower.semantic_encoder.blocks.11.norm2.weight": "pytorch_model-00001-of-00004.bin",
545
+ "vision_tower.semantic_encoder.blocks.12.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
546
+ "vision_tower.semantic_encoder.blocks.12.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
547
+ "vision_tower.semantic_encoder.blocks.12.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
548
+ "vision_tower.semantic_encoder.blocks.12.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
549
+ "vision_tower.semantic_encoder.blocks.12.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
550
+ "vision_tower.semantic_encoder.blocks.12.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
551
+ "vision_tower.semantic_encoder.blocks.12.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
552
+ "vision_tower.semantic_encoder.blocks.12.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
553
+ "vision_tower.semantic_encoder.blocks.12.norm1.bias": "pytorch_model-00001-of-00004.bin",
554
+ "vision_tower.semantic_encoder.blocks.12.norm1.weight": "pytorch_model-00001-of-00004.bin",
555
+ "vision_tower.semantic_encoder.blocks.12.norm2.bias": "pytorch_model-00001-of-00004.bin",
556
+ "vision_tower.semantic_encoder.blocks.12.norm2.weight": "pytorch_model-00001-of-00004.bin",
557
+ "vision_tower.semantic_encoder.blocks.13.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
558
+ "vision_tower.semantic_encoder.blocks.13.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
559
+ "vision_tower.semantic_encoder.blocks.13.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
560
+ "vision_tower.semantic_encoder.blocks.13.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
561
+ "vision_tower.semantic_encoder.blocks.13.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
562
+ "vision_tower.semantic_encoder.blocks.13.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
563
+ "vision_tower.semantic_encoder.blocks.13.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
564
+ "vision_tower.semantic_encoder.blocks.13.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
565
+ "vision_tower.semantic_encoder.blocks.13.norm1.bias": "pytorch_model-00001-of-00004.bin",
566
+ "vision_tower.semantic_encoder.blocks.13.norm1.weight": "pytorch_model-00001-of-00004.bin",
567
+ "vision_tower.semantic_encoder.blocks.13.norm2.bias": "pytorch_model-00001-of-00004.bin",
568
+ "vision_tower.semantic_encoder.blocks.13.norm2.weight": "pytorch_model-00001-of-00004.bin",
569
+ "vision_tower.semantic_encoder.blocks.14.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
570
+ "vision_tower.semantic_encoder.blocks.14.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
571
+ "vision_tower.semantic_encoder.blocks.14.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
572
+ "vision_tower.semantic_encoder.blocks.14.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
573
+ "vision_tower.semantic_encoder.blocks.14.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
574
+ "vision_tower.semantic_encoder.blocks.14.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
575
+ "vision_tower.semantic_encoder.blocks.14.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
576
+ "vision_tower.semantic_encoder.blocks.14.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
577
+ "vision_tower.semantic_encoder.blocks.14.norm1.bias": "pytorch_model-00001-of-00004.bin",
578
+ "vision_tower.semantic_encoder.blocks.14.norm1.weight": "pytorch_model-00001-of-00004.bin",
579
+ "vision_tower.semantic_encoder.blocks.14.norm2.bias": "pytorch_model-00001-of-00004.bin",
580
+ "vision_tower.semantic_encoder.blocks.14.norm2.weight": "pytorch_model-00001-of-00004.bin",
581
+ "vision_tower.semantic_encoder.blocks.15.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
582
+ "vision_tower.semantic_encoder.blocks.15.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
583
+ "vision_tower.semantic_encoder.blocks.15.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
584
+ "vision_tower.semantic_encoder.blocks.15.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
585
+ "vision_tower.semantic_encoder.blocks.15.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
586
+ "vision_tower.semantic_encoder.blocks.15.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
587
+ "vision_tower.semantic_encoder.blocks.15.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
588
+ "vision_tower.semantic_encoder.blocks.15.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
589
+ "vision_tower.semantic_encoder.blocks.15.norm1.bias": "pytorch_model-00001-of-00004.bin",
590
+ "vision_tower.semantic_encoder.blocks.15.norm1.weight": "pytorch_model-00001-of-00004.bin",
591
+ "vision_tower.semantic_encoder.blocks.15.norm2.bias": "pytorch_model-00001-of-00004.bin",
592
+ "vision_tower.semantic_encoder.blocks.15.norm2.weight": "pytorch_model-00001-of-00004.bin",
593
+ "vision_tower.semantic_encoder.blocks.16.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
594
+ "vision_tower.semantic_encoder.blocks.16.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
595
+ "vision_tower.semantic_encoder.blocks.16.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
596
+ "vision_tower.semantic_encoder.blocks.16.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
597
+ "vision_tower.semantic_encoder.blocks.16.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
598
+ "vision_tower.semantic_encoder.blocks.16.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
599
+ "vision_tower.semantic_encoder.blocks.16.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
600
+ "vision_tower.semantic_encoder.blocks.16.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
601
+ "vision_tower.semantic_encoder.blocks.16.norm1.bias": "pytorch_model-00001-of-00004.bin",
602
+ "vision_tower.semantic_encoder.blocks.16.norm1.weight": "pytorch_model-00001-of-00004.bin",
603
+ "vision_tower.semantic_encoder.blocks.16.norm2.bias": "pytorch_model-00001-of-00004.bin",
604
+ "vision_tower.semantic_encoder.blocks.16.norm2.weight": "pytorch_model-00001-of-00004.bin",
605
+ "vision_tower.semantic_encoder.blocks.17.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
606
+ "vision_tower.semantic_encoder.blocks.17.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
607
+ "vision_tower.semantic_encoder.blocks.17.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
608
+ "vision_tower.semantic_encoder.blocks.17.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
609
+ "vision_tower.semantic_encoder.blocks.17.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
610
+ "vision_tower.semantic_encoder.blocks.17.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
611
+ "vision_tower.semantic_encoder.blocks.17.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
612
+ "vision_tower.semantic_encoder.blocks.17.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
613
+ "vision_tower.semantic_encoder.blocks.17.norm1.bias": "pytorch_model-00001-of-00004.bin",
614
+ "vision_tower.semantic_encoder.blocks.17.norm1.weight": "pytorch_model-00001-of-00004.bin",
615
+ "vision_tower.semantic_encoder.blocks.17.norm2.bias": "pytorch_model-00001-of-00004.bin",
616
+ "vision_tower.semantic_encoder.blocks.17.norm2.weight": "pytorch_model-00001-of-00004.bin",
617
+ "vision_tower.semantic_encoder.blocks.18.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
618
+ "vision_tower.semantic_encoder.blocks.18.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
619
+ "vision_tower.semantic_encoder.blocks.18.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
620
+ "vision_tower.semantic_encoder.blocks.18.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
621
+ "vision_tower.semantic_encoder.blocks.18.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
622
+ "vision_tower.semantic_encoder.blocks.18.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
623
+ "vision_tower.semantic_encoder.blocks.18.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
624
+ "vision_tower.semantic_encoder.blocks.18.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
625
+ "vision_tower.semantic_encoder.blocks.18.norm1.bias": "pytorch_model-00001-of-00004.bin",
626
+ "vision_tower.semantic_encoder.blocks.18.norm1.weight": "pytorch_model-00001-of-00004.bin",
627
+ "vision_tower.semantic_encoder.blocks.18.norm2.bias": "pytorch_model-00001-of-00004.bin",
628
+ "vision_tower.semantic_encoder.blocks.18.norm2.weight": "pytorch_model-00001-of-00004.bin",
629
+ "vision_tower.semantic_encoder.blocks.19.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
630
+ "vision_tower.semantic_encoder.blocks.19.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
631
+ "vision_tower.semantic_encoder.blocks.19.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
632
+ "vision_tower.semantic_encoder.blocks.19.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
633
+ "vision_tower.semantic_encoder.blocks.19.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
634
+ "vision_tower.semantic_encoder.blocks.19.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
635
+ "vision_tower.semantic_encoder.blocks.19.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
636
+ "vision_tower.semantic_encoder.blocks.19.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
637
+ "vision_tower.semantic_encoder.blocks.19.norm1.bias": "pytorch_model-00001-of-00004.bin",
638
+ "vision_tower.semantic_encoder.blocks.19.norm1.weight": "pytorch_model-00001-of-00004.bin",
639
+ "vision_tower.semantic_encoder.blocks.19.norm2.bias": "pytorch_model-00001-of-00004.bin",
640
+ "vision_tower.semantic_encoder.blocks.19.norm2.weight": "pytorch_model-00001-of-00004.bin",
641
+ "vision_tower.semantic_encoder.blocks.2.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
642
+ "vision_tower.semantic_encoder.blocks.2.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
643
+ "vision_tower.semantic_encoder.blocks.2.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
644
+ "vision_tower.semantic_encoder.blocks.2.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
645
+ "vision_tower.semantic_encoder.blocks.2.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
646
+ "vision_tower.semantic_encoder.blocks.2.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
647
+ "vision_tower.semantic_encoder.blocks.2.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
648
+ "vision_tower.semantic_encoder.blocks.2.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
649
+ "vision_tower.semantic_encoder.blocks.2.norm1.bias": "pytorch_model-00001-of-00004.bin",
650
+ "vision_tower.semantic_encoder.blocks.2.norm1.weight": "pytorch_model-00001-of-00004.bin",
651
+ "vision_tower.semantic_encoder.blocks.2.norm2.bias": "pytorch_model-00001-of-00004.bin",
652
+ "vision_tower.semantic_encoder.blocks.2.norm2.weight": "pytorch_model-00001-of-00004.bin",
653
+ "vision_tower.semantic_encoder.blocks.20.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
654
+ "vision_tower.semantic_encoder.blocks.20.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
655
+ "vision_tower.semantic_encoder.blocks.20.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
656
+ "vision_tower.semantic_encoder.blocks.20.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
657
+ "vision_tower.semantic_encoder.blocks.20.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
658
+ "vision_tower.semantic_encoder.blocks.20.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
659
+ "vision_tower.semantic_encoder.blocks.20.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
660
+ "vision_tower.semantic_encoder.blocks.20.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
661
+ "vision_tower.semantic_encoder.blocks.20.norm1.bias": "pytorch_model-00001-of-00004.bin",
662
+ "vision_tower.semantic_encoder.blocks.20.norm1.weight": "pytorch_model-00001-of-00004.bin",
663
+ "vision_tower.semantic_encoder.blocks.20.norm2.bias": "pytorch_model-00001-of-00004.bin",
664
+ "vision_tower.semantic_encoder.blocks.20.norm2.weight": "pytorch_model-00001-of-00004.bin",
665
+ "vision_tower.semantic_encoder.blocks.21.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
666
+ "vision_tower.semantic_encoder.blocks.21.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
667
+ "vision_tower.semantic_encoder.blocks.21.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
668
+ "vision_tower.semantic_encoder.blocks.21.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
669
+ "vision_tower.semantic_encoder.blocks.21.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
670
+ "vision_tower.semantic_encoder.blocks.21.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
671
+ "vision_tower.semantic_encoder.blocks.21.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
672
+ "vision_tower.semantic_encoder.blocks.21.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
673
+ "vision_tower.semantic_encoder.blocks.21.norm1.bias": "pytorch_model-00001-of-00004.bin",
674
+ "vision_tower.semantic_encoder.blocks.21.norm1.weight": "pytorch_model-00001-of-00004.bin",
675
+ "vision_tower.semantic_encoder.blocks.21.norm2.bias": "pytorch_model-00001-of-00004.bin",
676
+ "vision_tower.semantic_encoder.blocks.21.norm2.weight": "pytorch_model-00001-of-00004.bin",
677
+ "vision_tower.semantic_encoder.blocks.22.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
678
+ "vision_tower.semantic_encoder.blocks.22.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
679
+ "vision_tower.semantic_encoder.blocks.22.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
680
+ "vision_tower.semantic_encoder.blocks.22.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
681
+ "vision_tower.semantic_encoder.blocks.22.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
682
+ "vision_tower.semantic_encoder.blocks.22.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
683
+ "vision_tower.semantic_encoder.blocks.22.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
684
+ "vision_tower.semantic_encoder.blocks.22.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
685
+ "vision_tower.semantic_encoder.blocks.22.norm1.bias": "pytorch_model-00001-of-00004.bin",
686
+ "vision_tower.semantic_encoder.blocks.22.norm1.weight": "pytorch_model-00001-of-00004.bin",
687
+ "vision_tower.semantic_encoder.blocks.22.norm2.bias": "pytorch_model-00001-of-00004.bin",
688
+ "vision_tower.semantic_encoder.blocks.22.norm2.weight": "pytorch_model-00001-of-00004.bin",
689
+ "vision_tower.semantic_encoder.blocks.23.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
690
+ "vision_tower.semantic_encoder.blocks.23.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
691
+ "vision_tower.semantic_encoder.blocks.23.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
692
+ "vision_tower.semantic_encoder.blocks.23.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
693
+ "vision_tower.semantic_encoder.blocks.23.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
694
+ "vision_tower.semantic_encoder.blocks.23.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
695
+ "vision_tower.semantic_encoder.blocks.23.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
696
+ "vision_tower.semantic_encoder.blocks.23.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
697
+ "vision_tower.semantic_encoder.blocks.23.norm1.bias": "pytorch_model-00001-of-00004.bin",
698
+ "vision_tower.semantic_encoder.blocks.23.norm1.weight": "pytorch_model-00001-of-00004.bin",
699
+ "vision_tower.semantic_encoder.blocks.23.norm2.bias": "pytorch_model-00001-of-00004.bin",
700
+ "vision_tower.semantic_encoder.blocks.23.norm2.weight": "pytorch_model-00001-of-00004.bin",
701
+ "vision_tower.semantic_encoder.blocks.24.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
702
+ "vision_tower.semantic_encoder.blocks.24.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
703
+ "vision_tower.semantic_encoder.blocks.24.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
704
+ "vision_tower.semantic_encoder.blocks.24.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
705
+ "vision_tower.semantic_encoder.blocks.24.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
706
+ "vision_tower.semantic_encoder.blocks.24.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
707
+ "vision_tower.semantic_encoder.blocks.24.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
708
+ "vision_tower.semantic_encoder.blocks.24.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
709
+ "vision_tower.semantic_encoder.blocks.24.norm1.bias": "pytorch_model-00001-of-00004.bin",
710
+ "vision_tower.semantic_encoder.blocks.24.norm1.weight": "pytorch_model-00001-of-00004.bin",
711
+ "vision_tower.semantic_encoder.blocks.24.norm2.bias": "pytorch_model-00001-of-00004.bin",
712
+ "vision_tower.semantic_encoder.blocks.24.norm2.weight": "pytorch_model-00001-of-00004.bin",
713
+ "vision_tower.semantic_encoder.blocks.25.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
714
+ "vision_tower.semantic_encoder.blocks.25.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
715
+ "vision_tower.semantic_encoder.blocks.25.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
716
+ "vision_tower.semantic_encoder.blocks.25.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
717
+ "vision_tower.semantic_encoder.blocks.25.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
718
+ "vision_tower.semantic_encoder.blocks.25.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
719
+ "vision_tower.semantic_encoder.blocks.25.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
720
+ "vision_tower.semantic_encoder.blocks.25.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
721
+ "vision_tower.semantic_encoder.blocks.25.norm1.bias": "pytorch_model-00001-of-00004.bin",
722
+ "vision_tower.semantic_encoder.blocks.25.norm1.weight": "pytorch_model-00001-of-00004.bin",
723
+ "vision_tower.semantic_encoder.blocks.25.norm2.bias": "pytorch_model-00001-of-00004.bin",
724
+ "vision_tower.semantic_encoder.blocks.25.norm2.weight": "pytorch_model-00001-of-00004.bin",
725
+ "vision_tower.semantic_encoder.blocks.26.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
726
+ "vision_tower.semantic_encoder.blocks.26.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
727
+ "vision_tower.semantic_encoder.blocks.26.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
728
+ "vision_tower.semantic_encoder.blocks.26.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
729
+ "vision_tower.semantic_encoder.blocks.26.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
730
+ "vision_tower.semantic_encoder.blocks.26.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
731
+ "vision_tower.semantic_encoder.blocks.26.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
732
+ "vision_tower.semantic_encoder.blocks.26.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
733
+ "vision_tower.semantic_encoder.blocks.26.norm1.bias": "pytorch_model-00001-of-00004.bin",
734
+ "vision_tower.semantic_encoder.blocks.26.norm1.weight": "pytorch_model-00001-of-00004.bin",
735
+ "vision_tower.semantic_encoder.blocks.26.norm2.bias": "pytorch_model-00001-of-00004.bin",
736
+ "vision_tower.semantic_encoder.blocks.26.norm2.weight": "pytorch_model-00001-of-00004.bin",
737
+ "vision_tower.semantic_encoder.blocks.27.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
738
+ "vision_tower.semantic_encoder.blocks.27.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
739
+ "vision_tower.semantic_encoder.blocks.27.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
740
+ "vision_tower.semantic_encoder.blocks.27.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
741
+ "vision_tower.semantic_encoder.blocks.27.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
742
+ "vision_tower.semantic_encoder.blocks.27.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
743
+ "vision_tower.semantic_encoder.blocks.27.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
744
+ "vision_tower.semantic_encoder.blocks.27.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
745
+ "vision_tower.semantic_encoder.blocks.27.norm1.bias": "pytorch_model-00001-of-00004.bin",
746
+ "vision_tower.semantic_encoder.blocks.27.norm1.weight": "pytorch_model-00001-of-00004.bin",
747
+ "vision_tower.semantic_encoder.blocks.27.norm2.bias": "pytorch_model-00001-of-00004.bin",
748
+ "vision_tower.semantic_encoder.blocks.27.norm2.weight": "pytorch_model-00001-of-00004.bin",
749
+ "vision_tower.semantic_encoder.blocks.28.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
750
+ "vision_tower.semantic_encoder.blocks.28.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
751
+ "vision_tower.semantic_encoder.blocks.28.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
752
+ "vision_tower.semantic_encoder.blocks.28.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
753
+ "vision_tower.semantic_encoder.blocks.28.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
754
+ "vision_tower.semantic_encoder.blocks.28.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
755
+ "vision_tower.semantic_encoder.blocks.28.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
756
+ "vision_tower.semantic_encoder.blocks.28.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
757
+ "vision_tower.semantic_encoder.blocks.28.norm1.bias": "pytorch_model-00001-of-00004.bin",
758
+ "vision_tower.semantic_encoder.blocks.28.norm1.weight": "pytorch_model-00001-of-00004.bin",
759
+ "vision_tower.semantic_encoder.blocks.28.norm2.bias": "pytorch_model-00001-of-00004.bin",
760
+ "vision_tower.semantic_encoder.blocks.28.norm2.weight": "pytorch_model-00001-of-00004.bin",
761
+ "vision_tower.semantic_encoder.blocks.29.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
762
+ "vision_tower.semantic_encoder.blocks.29.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
763
+ "vision_tower.semantic_encoder.blocks.29.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
764
+ "vision_tower.semantic_encoder.blocks.29.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
765
+ "vision_tower.semantic_encoder.blocks.29.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
766
+ "vision_tower.semantic_encoder.blocks.29.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
767
+ "vision_tower.semantic_encoder.blocks.29.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
768
+ "vision_tower.semantic_encoder.blocks.29.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
769
+ "vision_tower.semantic_encoder.blocks.29.norm1.bias": "pytorch_model-00001-of-00004.bin",
770
+ "vision_tower.semantic_encoder.blocks.29.norm1.weight": "pytorch_model-00001-of-00004.bin",
771
+ "vision_tower.semantic_encoder.blocks.29.norm2.bias": "pytorch_model-00001-of-00004.bin",
772
+ "vision_tower.semantic_encoder.blocks.29.norm2.weight": "pytorch_model-00001-of-00004.bin",
773
+ "vision_tower.semantic_encoder.blocks.3.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
774
+ "vision_tower.semantic_encoder.blocks.3.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
775
+ "vision_tower.semantic_encoder.blocks.3.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
776
+ "vision_tower.semantic_encoder.blocks.3.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
777
+ "vision_tower.semantic_encoder.blocks.3.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
778
+ "vision_tower.semantic_encoder.blocks.3.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
779
+ "vision_tower.semantic_encoder.blocks.3.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
780
+ "vision_tower.semantic_encoder.blocks.3.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
781
+ "vision_tower.semantic_encoder.blocks.3.norm1.bias": "pytorch_model-00001-of-00004.bin",
782
+ "vision_tower.semantic_encoder.blocks.3.norm1.weight": "pytorch_model-00001-of-00004.bin",
783
+ "vision_tower.semantic_encoder.blocks.3.norm2.bias": "pytorch_model-00001-of-00004.bin",
784
+ "vision_tower.semantic_encoder.blocks.3.norm2.weight": "pytorch_model-00001-of-00004.bin",
785
+ "vision_tower.semantic_encoder.blocks.30.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
786
+ "vision_tower.semantic_encoder.blocks.30.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
787
+ "vision_tower.semantic_encoder.blocks.30.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
788
+ "vision_tower.semantic_encoder.blocks.30.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
789
+ "vision_tower.semantic_encoder.blocks.30.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
790
+ "vision_tower.semantic_encoder.blocks.30.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
791
+ "vision_tower.semantic_encoder.blocks.30.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
792
+ "vision_tower.semantic_encoder.blocks.30.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
793
+ "vision_tower.semantic_encoder.blocks.30.norm1.bias": "pytorch_model-00001-of-00004.bin",
794
+ "vision_tower.semantic_encoder.blocks.30.norm1.weight": "pytorch_model-00001-of-00004.bin",
795
+ "vision_tower.semantic_encoder.blocks.30.norm2.bias": "pytorch_model-00001-of-00004.bin",
796
+ "vision_tower.semantic_encoder.blocks.30.norm2.weight": "pytorch_model-00001-of-00004.bin",
797
+ "vision_tower.semantic_encoder.blocks.31.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
798
+ "vision_tower.semantic_encoder.blocks.31.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
799
+ "vision_tower.semantic_encoder.blocks.31.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
800
+ "vision_tower.semantic_encoder.blocks.31.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
801
+ "vision_tower.semantic_encoder.blocks.31.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
802
+ "vision_tower.semantic_encoder.blocks.31.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
803
+ "vision_tower.semantic_encoder.blocks.31.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
804
+ "vision_tower.semantic_encoder.blocks.31.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
805
+ "vision_tower.semantic_encoder.blocks.31.norm1.bias": "pytorch_model-00001-of-00004.bin",
806
+ "vision_tower.semantic_encoder.blocks.31.norm1.weight": "pytorch_model-00001-of-00004.bin",
807
+ "vision_tower.semantic_encoder.blocks.31.norm2.bias": "pytorch_model-00001-of-00004.bin",
808
+ "vision_tower.semantic_encoder.blocks.31.norm2.weight": "pytorch_model-00001-of-00004.bin",
809
+ "vision_tower.semantic_encoder.blocks.4.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
810
+ "vision_tower.semantic_encoder.blocks.4.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
811
+ "vision_tower.semantic_encoder.blocks.4.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
812
+ "vision_tower.semantic_encoder.blocks.4.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
813
+ "vision_tower.semantic_encoder.blocks.4.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
814
+ "vision_tower.semantic_encoder.blocks.4.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
815
+ "vision_tower.semantic_encoder.blocks.4.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
816
+ "vision_tower.semantic_encoder.blocks.4.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
817
+ "vision_tower.semantic_encoder.blocks.4.norm1.bias": "pytorch_model-00001-of-00004.bin",
818
+ "vision_tower.semantic_encoder.blocks.4.norm1.weight": "pytorch_model-00001-of-00004.bin",
819
+ "vision_tower.semantic_encoder.blocks.4.norm2.bias": "pytorch_model-00001-of-00004.bin",
820
+ "vision_tower.semantic_encoder.blocks.4.norm2.weight": "pytorch_model-00001-of-00004.bin",
821
+ "vision_tower.semantic_encoder.blocks.5.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
822
+ "vision_tower.semantic_encoder.blocks.5.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
823
+ "vision_tower.semantic_encoder.blocks.5.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
824
+ "vision_tower.semantic_encoder.blocks.5.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
825
+ "vision_tower.semantic_encoder.blocks.5.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
826
+ "vision_tower.semantic_encoder.blocks.5.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
827
+ "vision_tower.semantic_encoder.blocks.5.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
828
+ "vision_tower.semantic_encoder.blocks.5.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
829
+ "vision_tower.semantic_encoder.blocks.5.norm1.bias": "pytorch_model-00001-of-00004.bin",
830
+ "vision_tower.semantic_encoder.blocks.5.norm1.weight": "pytorch_model-00001-of-00004.bin",
831
+ "vision_tower.semantic_encoder.blocks.5.norm2.bias": "pytorch_model-00001-of-00004.bin",
832
+ "vision_tower.semantic_encoder.blocks.5.norm2.weight": "pytorch_model-00001-of-00004.bin",
833
+ "vision_tower.semantic_encoder.blocks.6.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
834
+ "vision_tower.semantic_encoder.blocks.6.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
835
+ "vision_tower.semantic_encoder.blocks.6.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
836
+ "vision_tower.semantic_encoder.blocks.6.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
837
+ "vision_tower.semantic_encoder.blocks.6.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
838
+ "vision_tower.semantic_encoder.blocks.6.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
839
+ "vision_tower.semantic_encoder.blocks.6.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
840
+ "vision_tower.semantic_encoder.blocks.6.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
841
+ "vision_tower.semantic_encoder.blocks.6.norm1.bias": "pytorch_model-00001-of-00004.bin",
842
+ "vision_tower.semantic_encoder.blocks.6.norm1.weight": "pytorch_model-00001-of-00004.bin",
843
+ "vision_tower.semantic_encoder.blocks.6.norm2.bias": "pytorch_model-00001-of-00004.bin",
844
+ "vision_tower.semantic_encoder.blocks.6.norm2.weight": "pytorch_model-00001-of-00004.bin",
845
+ "vision_tower.semantic_encoder.blocks.7.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
846
+ "vision_tower.semantic_encoder.blocks.7.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
847
+ "vision_tower.semantic_encoder.blocks.7.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
848
+ "vision_tower.semantic_encoder.blocks.7.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
849
+ "vision_tower.semantic_encoder.blocks.7.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
850
+ "vision_tower.semantic_encoder.blocks.7.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
851
+ "vision_tower.semantic_encoder.blocks.7.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
852
+ "vision_tower.semantic_encoder.blocks.7.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
853
+ "vision_tower.semantic_encoder.blocks.7.norm1.bias": "pytorch_model-00001-of-00004.bin",
854
+ "vision_tower.semantic_encoder.blocks.7.norm1.weight": "pytorch_model-00001-of-00004.bin",
855
+ "vision_tower.semantic_encoder.blocks.7.norm2.bias": "pytorch_model-00001-of-00004.bin",
856
+ "vision_tower.semantic_encoder.blocks.7.norm2.weight": "pytorch_model-00001-of-00004.bin",
857
+ "vision_tower.semantic_encoder.blocks.8.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
858
+ "vision_tower.semantic_encoder.blocks.8.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
859
+ "vision_tower.semantic_encoder.blocks.8.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
860
+ "vision_tower.semantic_encoder.blocks.8.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
861
+ "vision_tower.semantic_encoder.blocks.8.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
862
+ "vision_tower.semantic_encoder.blocks.8.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
863
+ "vision_tower.semantic_encoder.blocks.8.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
864
+ "vision_tower.semantic_encoder.blocks.8.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
865
+ "vision_tower.semantic_encoder.blocks.8.norm1.bias": "pytorch_model-00001-of-00004.bin",
866
+ "vision_tower.semantic_encoder.blocks.8.norm1.weight": "pytorch_model-00001-of-00004.bin",
867
+ "vision_tower.semantic_encoder.blocks.8.norm2.bias": "pytorch_model-00001-of-00004.bin",
868
+ "vision_tower.semantic_encoder.blocks.8.norm2.weight": "pytorch_model-00001-of-00004.bin",
869
+ "vision_tower.semantic_encoder.blocks.9.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
870
+ "vision_tower.semantic_encoder.blocks.9.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
871
+ "vision_tower.semantic_encoder.blocks.9.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
872
+ "vision_tower.semantic_encoder.blocks.9.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
873
+ "vision_tower.semantic_encoder.blocks.9.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
874
+ "vision_tower.semantic_encoder.blocks.9.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
875
+ "vision_tower.semantic_encoder.blocks.9.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
876
+ "vision_tower.semantic_encoder.blocks.9.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
877
+ "vision_tower.semantic_encoder.blocks.9.norm1.bias": "pytorch_model-00001-of-00004.bin",
878
+ "vision_tower.semantic_encoder.blocks.9.norm1.weight": "pytorch_model-00001-of-00004.bin",
879
+ "vision_tower.semantic_encoder.blocks.9.norm2.bias": "pytorch_model-00001-of-00004.bin",
880
+ "vision_tower.semantic_encoder.blocks.9.norm2.weight": "pytorch_model-00001-of-00004.bin",
881
+ "vision_tower.semantic_encoder.merger.ln_q.bias": "pytorch_model-00001-of-00004.bin",
882
+ "vision_tower.semantic_encoder.merger.ln_q.weight": "pytorch_model-00001-of-00004.bin",
883
+ "vision_tower.semantic_encoder.merger.mlp.0.bias": "pytorch_model-00001-of-00004.bin",
884
+ "vision_tower.semantic_encoder.merger.mlp.0.weight": "pytorch_model-00001-of-00004.bin",
885
+ "vision_tower.semantic_encoder.merger.mlp.2.bias": "pytorch_model-00001-of-00004.bin",
886
+ "vision_tower.semantic_encoder.merger.mlp.2.weight": "pytorch_model-00001-of-00004.bin",
887
+ "vision_tower.semantic_encoder.patch_embed.proj.weight": "pytorch_model-00001-of-00004.bin"
888
+ }
889
+ }
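For reference (illustrative, not part of the committed files): the index above follows the standard sharded-checkpoint layout, in which a top-level `weight_map` maps every parameter name to the shard file that stores it. A minimal sketch of loading a single tensor through such an index, assuming that layout and that the shards sit next to `pytorch_model.bin.index.json`, could look like this:

import json
import os

import torch

def load_param(checkpoint_dir, param_name):
    # Look up which shard holds the parameter, then load only that shard.
    with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")) as f:
        index = json.load(f)
    shard_file = index["weight_map"][param_name]  # e.g. "pytorch_model-00001-of-00004.bin"
    shard = torch.load(os.path.join(checkpoint_dir, shard_file), map_location="cpu")
    return shard[param_name]

# Hypothetical usage:
# w = load_param(".", "vision_tower.semantic_encoder.patch_embed.proj.weight")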
sdxl_decoder_pipe.py ADDED
@@ -0,0 +1,901 @@
1
+ # Modify from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
2
+ import inspect
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from einops import repeat, rearrange
12
+
13
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
14
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
15
+
16
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+
19
+ from diffusers.utils.torch_utils import randn_tensor
20
+ import PIL.Image
21
+
22
+ from diffusers.models.attention_processor import (
23
+ AttnProcessor2_0,
24
+ FusedAttnProcessor2_0,
25
+ XFormersAttnProcessor,
26
+ )
27
+
28
+ from diffusers.utils import (
29
+ USE_PEFT_BACKEND,
30
+ deprecate,
31
+ is_invisible_watermark_available,
32
+ is_torch_xla_available,
33
+ logging,
34
+ replace_example_docstring,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+
39
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
40
+ from diffusers.loaders import (
41
+ FromSingleFileMixin,
42
+ IPAdapterMixin,
43
+ StableDiffusionXLLoraLoaderMixin,
44
+ TextualInversionLoaderMixin,
45
+ )
46
+
47
+ if is_invisible_watermark_available():
48
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
49
+
50
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
51
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline, \
52
+ retrieve_timesteps, rescale_noise_cfg
53
+
54
+ from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode
55
+
56
+ if is_torch_xla_available():
57
+ import torch_xla.core.xla_model as xm
58
+
59
+ XLA_AVAILABLE = True
60
+ else:
61
+ XLA_AVAILABLE = False
62
+
63
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
64
+
65
+
66
+ @dataclass
67
+ class StableDiffusionXLDecoderPipelineOutput(StableDiffusionXLPipelineOutput):
68
+ images: Union[List[PIL.Image.Image], np.ndarray]
69
+ indices_semantic: Optional[torch.Tensor] = None
70
+ indices_pixel: Optional[torch.Tensor] = None
71
+
72
+
73
+ def expand_dims_like(x, y):
74
+ while x.dim() != y.dim():
75
+ x = x.unsqueeze(-1)
76
+ return x
77
+
78
+
79
+ class AbstractEmbModel(nn.Module):
80
+ def __init__(self):
81
+ super().__init__()
82
+ self._is_trainable = None
83
+ self._ucg_rate = None
84
+ self._input_key = None
85
+
86
+ @property
87
+ def is_trainable(self) -> bool:
88
+ return self._is_trainable
89
+
90
+ @property
91
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
92
+ return self._ucg_rate
93
+
94
+ @property
95
+ def input_key(self) -> str:
96
+ return self._input_key
97
+
98
+ @is_trainable.setter
99
+ def is_trainable(self, value: bool):
100
+ self._is_trainable = value
101
+
102
+ @ucg_rate.setter
103
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
104
+ self._ucg_rate = value
105
+
106
+ @input_key.setter
107
+ def input_key(self, value: str):
108
+ self._input_key = value
109
+
110
+ @is_trainable.deleter
111
+ def is_trainable(self):
112
+ del self._is_trainable
113
+
114
+ @ucg_rate.deleter
115
+ def ucg_rate(self):
116
+ del self._ucg_rate
117
+
118
+ @input_key.deleter
119
+ def input_key(self):
120
+ del self._input_key
121
+
122
+
123
+ class DualViTok2ImageEmbedder(AbstractEmbModel):
124
+ def __init__(
125
+ self,
126
+ image_processor=None,
127
+ vq_model=None,
128
+ device="cuda",
129
+ dtype=torch.float32,
130
+ freeze=True,
131
+ image_size=0,
132
+ resize_factor=1,
133
+ not_bicubic=True,
134
+ return_sequence=False,
135
+ grid_feature_scale=1,
136
+ texture_drop_prob=0,
137
+ semantic_drop_prob=0,
138
+ pixel_channel=32,
139
+ semantic_channel=32,
140
+ ):
141
+ super().__init__()
142
+ vq_model.to(device=device, dtype=dtype)
143
+ vq_model.eval()
144
+
145
+ self.processor = image_processor
146
+
147
+ self.model = vq_model
148
+ self.device = device
149
+ if freeze:
150
+ self.freeze()
151
+
152
+ if image_size > 0:
153
+ preprocessor = [
154
+ Resize(image_size) if not_bicubic else Resize(image_size, interpolation=InterpolationMode.BICUBIC)]
155
+ preprocessor += [
156
+ CenterCrop(image_size),
157
+ ]
158
+ self.preprocessor = Compose(preprocessor)
159
+ self.image_size = image_size
160
+ self.resize_factor = resize_factor
161
+ self.not_bicubic = not_bicubic
162
+ self.return_sequence = return_sequence
163
+ self.grid_feature_scale = grid_feature_scale
164
+ self.texture_drop_prob = texture_drop_prob
165
+ self.semantic_drop_prob = semantic_drop_prob
166
+ self.pixel_channel = pixel_channel
167
+ self.semantic_channel = semantic_channel
168
+
169
+ def freeze(self):
170
+ self.model = self.model.eval()
171
+ for param in self.parameters():
172
+ param.requires_grad = False
173
+
174
+ def vq_encode(self, image):
175
+ if image.ndim == 5:
176
+ assert image.size(1) == 1
177
+ image = image.squeeze(1)
178
+ bs, _, h, w = image.shape
179
+
180
+ if self.image_size > 0:
181
+ image = self.preprocessor(image)
182
+ else:
183
+ assert self.resize_factor > 0
184
+ preprocessor = Resize((int(h * self.resize_factor), int(w * self.resize_factor))) if self.not_bicubic else \
185
+ Resize((int(h * self.resize_factor), int(w * self.resize_factor)),
186
+ interpolation=InterpolationMode.BICUBIC)
187
+ image = preprocessor(image)
188
+
189
+ inputs = dict(image=image)
190
+ inputs = self.model.get_input(inputs)
191
+
192
+ (quant_semantic, diff_semantic, indices_semantic, target_semantic), \
193
+ (quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
194
+ return indices_semantic, indices_pixel
195
+
196
+ def vq_encode_code(self, image):
197
+ indices_semantic, indices_pixel = self.vq_encode(image)
199
+ return indices_semantic, indices_pixel
200
+
201
+ def vq_decode_code(self, indices_semantic, indices_pixel):
202
+ return self.model.decode_code(indices_semantic, indices_pixel)
203
+
204
+ def forward(self, image, return_indices=False):
205
+ if image.ndim == 5:
206
+ assert image.size(1) == 1
207
+ image = image.squeeze(1)
208
+ bs, _, h, w = image.shape
209
+
210
+ if self.image_size > 0:
211
+ image = self.preprocessor(image)
212
+ else:
213
+ assert self.resize_factor > 0
214
+ preprocessor = Resize((int(h * self.resize_factor), int(w * self.resize_factor))) if self.not_bicubic else \
215
+ Resize((int(h * self.resize_factor), int(w * self.resize_factor)),
216
+ interpolation=InterpolationMode.BICUBIC)
217
+ image = preprocessor(image)
218
+
219
+ inputs = dict(image=image)
220
+ inputs = self.model.get_input(inputs)
221
+
222
+ (quant_semantic, diff_semantic, indices_semantic, target_semantic), \
223
+ (quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
224
+
225
+ feature = self.model.merge_quants(quant_semantic, quant_pixel)
226
+
227
+ if self.return_sequence:
228
+ feature = rearrange(feature, 'b c h w -> b h w c')
229
+ _, this_h, this_w, _ = feature.shape
230
+ feature = feature.reshape(bs, this_h * this_w, -1)
231
+ else:
232
+ feature = feature * self.grid_feature_scale
233
+
234
+ if return_indices:
235
+ return feature, indices_semantic, indices_pixel
236
+
237
+ return feature
238
+
239
+ def encode(self, img):
240
+ return self(img)
241
+
242
+ def indices_to_codes(self, semantic_indices, texture_indices):
243
+ quant_semantic, quant_texture = self.model.indices_to_codes(semantic_indices, texture_indices)
244
+ return self.model.merge_quants(quant_semantic, quant_texture)
245
+
246
+
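As an illustration (not part of the committed file), the embedder above is meant to be driven roughly as follows, assuming a loaded DualViTok `vq_model` and its `image_processor`; it returns the fused conditioning grid plus the semantic/pixel token indices:

# Sketch only; `vq_model` and `image_processor` are assumed to be loaded elsewhere.
embedder = DualViTok2ImageEmbedder(image_processor=image_processor, vq_model=vq_model,
                                   device="cuda", dtype=torch.float16)
image = torch.randn(1, 3, 512, 512, device="cuda", dtype=torch.float16)  # stand-in for a [-1, 1] image batch
feature, sem_idx, pix_idx = embedder(image, return_indices=True)  # conditioning grid + token indices
rebuilt = embedder.indices_to_codes(sem_idx, pix_idx)             # the same grid recovered from the indices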
247
+ class StableDiffusionXLDecoderPipeline(
248
+ DiffusionPipeline,
249
+ StableDiffusionMixin,
250
+ FromSingleFileMixin,
251
+ StableDiffusionXLLoraLoaderMixin,
252
+ TextualInversionLoaderMixin,
253
+ ):
254
+ model_cpu_offload_seq = "vq_model_embedder->unet->vae"
255
+ _optional_components = [
256
+ "vq_model_embedder",
257
+ ]
258
+ _callback_tensor_inputs = [
259
+ "latents",
260
+ "prompt_embeds",
261
+ "negative_prompt_embeds",
262
+ "add_text_embeds",
263
+ "add_time_ids",
264
+ "negative_pooled_prompt_embeds",
265
+ "negative_add_time_ids",
266
+ ]
267
+
268
+ def __init__(
269
+ self,
270
+ vae: AutoencoderKL,
271
+ unet: UNet2DConditionModel,
272
+ scheduler: KarrasDiffusionSchedulers,
273
+ force_zeros_for_empty_prompt: bool = True,
274
+ add_watermarker: Optional[bool] = None,
275
+ vq_image_processor=None,
276
+ vq_model=None,
277
+ ):
278
+ super().__init__()
279
+
280
+ self.register_modules(
281
+ vae=vae,
282
+ unet=unet,
283
+ scheduler=scheduler,
284
+ )
285
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
286
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
287
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
288
+
289
+ self.default_sample_size = self.unet.config.sample_size
290
+
291
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
292
+
293
+ if add_watermarker:
294
+ self.watermark = StableDiffusionXLWatermarker()
295
+ else:
296
+ self.watermark = None
297
+
298
+ self.empty_prompt_embeds = torch.zeros([1, 77, 2048]).to(device=unet.device, dtype=unet.dtype)
299
+ self.empty_pooled_prompt_embeds = torch.zeros([1, 1280]).to(device=unet.device, dtype=unet.dtype)
300
+ self.dualvitok_channels = vq_model.pixel_channel + vq_model.semantic_channel
301
+
302
+ self.resolution_group = ['(1024, 1024)', '(768, 1024)', '(1024, 768)', '(512, 2048)', '(2048, 512)',
303
+ '(640, 1920)', '(1920, 640)', '(768, 1536)', '(1536, 768)', '(768, 1152)',
304
+ '(1152, 768)', '(512, 512)']
305
+
306
+ embedder_kwargs = dict(image_size=0,
307
+ resize_factor=1,
308
+ return_sequence=False,
309
+ grid_feature_scale=1)
310
+ if isinstance(vq_model, DualViTok2ImageEmbedder):
311
+ self.vq_model_embedder = vq_model
312
+ else:
313
+ self.vq_model_embedder = DualViTok2ImageEmbedder(vq_image_processor, vq_model, **embedder_kwargs)
314
+
315
+ def vq_encode(self, image):
316
+ return self.vq_model_embedder.encode(image)
317
+
318
+ def vq_encode_code(self, image):
319
+ return self.vq_model_embedder.vq_encode_code(image)
320
+
321
+ def vq_decode_code(self, *args, **kwargs):
322
+ return self.vq_model_embedder.vq_decode_code(*args, **kwargs)
323
+
324
+ def indices_to_codes(self, *args, **kwargs):
325
+ return self.vq_model_embedder.indices_to_codes(*args, **kwargs)
326
+
327
+ def _get_add_time_ids(
328
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None,
329
+ resolution_index=None,
330
+ ):
331
+ add_time_ids = [resolution_index] * 6
332
+
333
+ passed_add_embed_dim = (
334
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
335
+ )
336
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
337
+
338
+ if expected_add_embed_dim != passed_add_embed_dim:
339
+ raise ValueError(
340
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
341
+ )
342
+
343
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
344
+ return add_time_ids
345
+
346
+ def check_inputs(
347
+ self,
348
+ height,
349
+ width,
350
+ callback_steps,
351
+ callback_on_step_end_tensor_inputs=None,
352
+ ):
353
+ if height % 8 != 0 or width % 8 != 0:
354
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
355
+
356
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
357
+ raise ValueError(
358
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
359
+ f" {type(callback_steps)}."
360
+ )
361
+
362
+ if callback_on_step_end_tensor_inputs is not None and not all(
363
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
364
+ ):
365
+ raise ValueError(
366
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
367
+ )
368
+
369
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
370
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
371
+ shape = (
372
+ batch_size,
373
+ num_channels_latents,
374
+ int(height) // self.vae_scale_factor,
375
+ int(width) // self.vae_scale_factor,
376
+ )
377
+ if isinstance(generator, list) and len(generator) != batch_size:
378
+ raise ValueError(
379
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
380
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
381
+ )
382
+
383
+ if latents is None:
384
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
385
+ else:
386
+ latents = latents.to(device)
387
+
388
+ # scale the initial noise by the standard deviation required by the scheduler
389
+ latents = latents * self.scheduler.init_noise_sigma
390
+ return latents
391
+
392
+ def upcast_vae(self):
393
+ dtype = self.vae.dtype
394
+ self.vae.to(dtype=torch.float32)
395
+ use_torch_2_0_or_xformers = isinstance(
396
+ self.vae.decoder.mid_block.attentions[0].processor,
397
+ (
398
+ AttnProcessor2_0,
399
+ XFormersAttnProcessor,
400
+ FusedAttnProcessor2_0,
401
+ ),
402
+ )
403
+ # if xformers or torch_2_0 is used attention block does not need
404
+ # to be in float32 which can save lots of memory
405
+ if use_torch_2_0_or_xformers:
406
+ self.vae.post_quant_conv.to(dtype)
407
+ self.vae.decoder.conv_in.to(dtype)
408
+ self.vae.decoder.mid_block.to(dtype)
409
+
410
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
411
+ def get_guidance_scale_embedding(
412
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
413
+ ) -> torch.Tensor:
414
+ """
415
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
416
+
417
+ Args:
418
+ w (`torch.Tensor`):
419
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
420
+ embedding_dim (`int`, *optional*, defaults to 512):
421
+ Dimension of the embeddings to generate.
422
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
423
+ Data type of the generated embeddings.
424
+
425
+ Returns:
426
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
427
+ """
428
+ assert len(w.shape) == 1
429
+ w = w * 1000.0
430
+
431
+ half_dim = embedding_dim // 2
432
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
433
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
434
+ emb = w.to(dtype)[:, None] * emb[None, :]
435
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
436
+ if embedding_dim % 2 == 1: # zero pad
437
+ emb = torch.nn.functional.pad(emb, (0, 1))
438
+ assert emb.shape == (w.shape[0], embedding_dim)
439
+ return emb
440
+
441
+ @property
442
+ def guidance_scale(self):
443
+ return self._guidance_scale
444
+
445
+ @property
446
+ def guidance_rescale(self):
447
+ return self._guidance_rescale
448
+
449
+ @property
450
+ def clip_skip(self):
451
+ return self._clip_skip
452
+
453
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
454
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
455
+ # corresponds to doing no classifier free guidance.
456
+ @property
457
+ def do_classifier_free_guidance(self):
458
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
459
+
460
+ @property
461
+ def cross_attention_kwargs(self):
462
+ return self._cross_attention_kwargs
463
+
464
+ @property
465
+ def denoising_end(self):
466
+ return self._denoising_end
467
+
468
+ @property
469
+ def num_timesteps(self):
470
+ return self._num_timesteps
471
+
472
+ @property
473
+ def interrupt(self):
474
+ return self._interrupt
475
+
476
+ def prepare_extra_step_kwargs(self, generator, eta):
477
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
478
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
479
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
480
+ # and should be between [0, 1]
481
+
482
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
483
+ extra_step_kwargs = {}
484
+ if accepts_eta:
485
+ extra_step_kwargs["eta"] = eta
486
+
487
+ # check if the scheduler accepts generator
488
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
489
+ if accepts_generator:
490
+ extra_step_kwargs["generator"] = generator
491
+ return extra_step_kwargs
492
+
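As an illustration (not part of the committed file), a typical way to drive the pipeline defined here, assuming `vae`, `unet`, `scheduler`, `vq_image_processor`, and `vq_model` have been loaded elsewhere, would be:

# Sketch only; component loading and device placement are assumed.
pipe = StableDiffusionXLDecoderPipeline(
    vae=vae, unet=unet, scheduler=scheduler,
    vq_image_processor=vq_image_processor, vq_model=vq_model,
)
sem_idx, pix_idx = pipe.vq_encode_code(images)           # tokenize reference images in [-1, 1]
result = pipe(vq_indices=(sem_idx, pix_idx),             # decode the tokens back to pixels
              height=1024, width=1024,
              num_inference_steps=50, guidance_scale=2.0)
result.images[0].save("decoded.png")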
493
+ @torch.no_grad()
494
+ def __call__(
495
+ self,
496
+ vq_indices: Optional[List] = None,
497
+ vq_embeds: Optional[torch.Tensor] = None,
498
+ images: Optional[PipelineImageInput] = None,
499
+ height: Optional[int] = None,
500
+ width: Optional[int] = None,
501
+ num_inference_steps: int = 50,
502
+ timesteps: List[int] = None,
503
+ sigmas: List[float] = None,
504
+ denoising_end: Optional[float] = None,
505
+ guidance_scale: float = 2.0,
506
+ eta: float = 0.0,
507
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
508
+ latents: Optional[torch.Tensor] = None,
509
+ output_type: Optional[str] = "pil",
510
+ return_dict: bool = True,
511
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
512
+ guidance_rescale: float = 0.0,
513
+ original_size: Optional[Tuple[int, int]] = None,
514
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
515
+ target_size: Optional[Tuple[int, int]] = None,
516
+ negative_original_size: Optional[Tuple[int, int]] = None,
517
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
518
+ negative_target_size: Optional[Tuple[int, int]] = None,
519
+ clip_skip: Optional[int] = None,
520
+ callback_on_step_end: Optional[
521
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
522
+ ] = None,
523
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
524
+ **kwargs,
525
+ ):
526
+ r"""
527
+ Function invoked when calling the pipeline for generation.
528
+
529
+ Args:
530
+ vq_indices (`Optional[List]`, *optional*):
531
+ The VQ indices for semantic and pixel tokens. Should be a tuple of (semantic_indices, pixel_indices).
532
+ images (`Optional[PipelineImageInput]`, *optional*):
533
+ Input images in range [-1, 1] as torch.Tensor with shape (batch_size, channels, height, width).
534
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
535
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
536
+ Anything below 512 pixels won't work well for
537
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
538
+ and checkpoints that are not specifically fine-tuned on low resolutions.
539
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
540
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
541
+ Anything below 512 pixels won't work well for
542
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
543
+ and checkpoints that are not specifically fine-tuned on low resolutions.
544
+ num_inference_steps (`int`, *optional*, defaults to 50):
545
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
546
+ expense of slower inference.
547
+ timesteps (`List[int]`, *optional*):
548
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
549
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
550
+ passed will be used. Must be in descending order.
551
+ sigmas (`List[float]`, *optional*):
552
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
553
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
554
+ will be used.
555
+ denoising_end (`float`, *optional*):
556
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
557
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
558
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
559
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
560
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
561
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
562
+ guidance_scale (`float`, *optional*, defaults to 2.0):
563
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
564
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
565
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
566
+ 1`. A higher guidance scale encourages samples that adhere more closely to the conditioning,
567
+ usually at the expense of lower image quality.
568
+ eta (`float`, *optional*, defaults to 0.0):
569
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
570
+ [`schedulers.DDIMScheduler`], will be ignored for others.
571
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
572
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
573
+ to make generation deterministic.
574
+ latents (`torch.Tensor`, *optional*):
575
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
576
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
577
+ tensor will be generated by sampling using the supplied random `generator`.
578
+ output_type (`str`, *optional*, defaults to `"pil"`):
579
+ The output format of the generated image. Choose between
580
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
581
+ return_dict (`bool`, *optional*, defaults to `True`):
582
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
583
+ of a plain tuple.
584
+ cross_attention_kwargs (`dict`, *optional*):
585
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
586
+ `self.processor` in
587
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
588
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
589
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
590
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
591
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
592
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
593
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
594
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
595
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
596
+ explained in section 2.2 of
597
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
598
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
599
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
600
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
601
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
602
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
603
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
604
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
605
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
606
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
607
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
608
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
609
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
610
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
611
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
612
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
613
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
614
+ will be passed in the `callback_kwargs` argument. You can only include variables listed in the
615
+ `._callback_tensor_inputs` attribute of your pipeline class.
616
+
617
+ Examples:
618
+
619
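+ A minimal sketch; the pipeline class name (`StableDiffusionXLDecoderPipeline`), the checkpoint
+ path, and the chosen resolution are illustrative assumptions, not guarantees of this repository:
+
+ >>> import torch
+ >>> pipe = StableDiffusionXLDecoderPipeline.from_pretrained(
+ ... "path/to/checkpoint", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> # semantic_indices / pixel_indices are placeholders for previously obtained VQ token indices
+ >>> out = pipe(vq_indices=[semantic_indices, pixel_indices],
+ ... height=512, width=512, # (512, 512) is assumed to be in the pipeline's resolution group
+ ... num_inference_steps=50)
+ >>> image = out.images[0]
+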
+ Returns:
620
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
621
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
622
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
623
+ """
624
+
625
+ callback = kwargs.pop("callback", None)
626
+ callback_steps = kwargs.pop("callback_steps", None)
627
+
628
+ if callback is not None:
629
+ deprecate(
630
+ "callback",
631
+ "1.0.0",
632
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
633
+ )
634
+ if callback_steps is not None:
635
+ deprecate(
636
+ "callback_steps",
637
+ "1.0.0",
638
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
639
+ )
640
+
641
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
642
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
643
+
644
+ # 0. Default height and width to unet
645
+ height = height or self.default_sample_size * self.vae_scale_factor
646
+ width = width or self.default_sample_size * self.vae_scale_factor
647
+
648
+ original_size = original_size or (height, width)
649
+ target_size = target_size or (height, width)
650
+
651
+ # 1. Check inputs. Raise error if not correct
652
+ self.check_inputs(
653
+ height,
654
+ width,
655
+ callback_steps,
656
+ callback_on_step_end_tensor_inputs,
657
+ )
658
+
659
+ self._guidance_scale = guidance_scale
660
+ self._guidance_rescale = guidance_rescale
661
+ self._clip_skip = clip_skip
662
+ self._cross_attention_kwargs = cross_attention_kwargs
663
+ self._denoising_end = denoising_end
664
+ self._interrupt = False
665
+
666
+ # 2. encode vq_embeds
667
+ assert images is not None or vq_indices is not None or vq_embeds is not None
668
+ # fall back to the leading dimension of vq_embeds when only embeddings are provided
+ batch_size = len(images) if images is not None else (len(vq_indices[0]) if vq_indices is not None else len(vq_embeds))
669
+
670
+ if images is not None:
671
+ vq_embeds, indices_semantic, indices_pixel = self.vq_model_embedder(images, return_indices=True)
672
+ elif vq_indices is not None:
673
+ indices_semantic, indices_pixel = vq_indices[0], vq_indices[1]
674
+ vq_embeds = self.vq_model_embedder.indices_to_codes(vq_indices[0], vq_indices[1])
675
+ elif vq_embeds is not None:
676
+ if isinstance(vq_embeds, list):
677
+ vq_embeds = self.vq_model_embedder.merge_quants(vq_embeds)
678
+ indices_semantic, indices_pixel = None, None
679
+ else:
680
+ raise ValueError("No valid input provided")
681
+
682
+ device = self._execution_device
683
+
684
+ # 3. Encode input prompt
685
+ lora_scale = (
686
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
687
+ )
688
+
689
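+ # this decoder is conditioned on VQ embeddings rather than text, so cached empty-prompt
+ # embeddings are simply broadcast to the batch size in place of real prompt embeddings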
+ prompt_embeds = repeat(self.empty_prompt_embeds, '1 l c -> b l c', b=batch_size)
690
+ pooled_prompt_embeds = repeat(self.empty_pooled_prompt_embeds, '1 c -> b c', b=batch_size)
691
+
692
+ negative_prompt_embeds = prompt_embeds
693
+ negative_pooled_prompt_embeds = pooled_prompt_embeds
694
+
695
+ # 4. Prepare timesteps
696
+ timesteps, num_inference_steps = retrieve_timesteps(
697
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
698
+ )
699
+
700
+ # 5. Prepare latent variables
701
+ # num_channels_latents = self.unet.config.in_channels
702
+ num_channels_latents = 4
703
+ latents = self.prepare_latents(
704
+ batch_size,
705
+ num_channels_latents,
706
+ height,
707
+ width,
708
+ prompt_embeds.dtype,
709
+ device,
710
+ generator,
711
+ latents,
712
+ )
713
+
714
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
715
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
716
+
717
+ # 7. Prepare added time ids & embeddings
718
+ add_text_embeds = pooled_prompt_embeds
719
+ text_encoder_projection_dim = 1280
720
+
721
+ resolution = f'({width}, {height})'
722
+ assert resolution in self.resolution_group, f"Resolution {resolution} is not in the resolution group. Candidates: {self.resolution_group}"
723
+ resolution_index = self.resolution_group.index(resolution)
724
+ # resolution_index = None
725
+
726
+ add_time_ids = self._get_add_time_ids(
727
+ original_size,
728
+ crops_coords_top_left,
729
+ target_size,
730
+ dtype=prompt_embeds.dtype,
731
+ text_encoder_projection_dim=text_encoder_projection_dim,
732
+ resolution_index=resolution_index,
733
+ )
734
+ if negative_original_size is not None and negative_target_size is not None:
735
+ negative_add_time_ids = self._get_add_time_ids(
736
+ negative_original_size,
737
+ negative_crops_coords_top_left,
738
+ negative_target_size,
739
+ dtype=prompt_embeds.dtype,
740
+ text_encoder_projection_dim=text_encoder_projection_dim,
741
+ )
742
+ else:
743
+ negative_add_time_ids = add_time_ids
744
+
745
+ if self.do_classifier_free_guidance:
746
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
747
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
748
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
749
+
750
+ prompt_embeds = prompt_embeds.to(device)
751
+ add_text_embeds = add_text_embeds.to(device)
752
+ add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)
753
+
754
+ # 8. Denoising loop
755
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
756
+
757
+ # 8.1 Apply denoising_end
758
+ if (
759
+ self.denoising_end is not None
760
+ and isinstance(self.denoising_end, float)
761
+ and self.denoising_end > 0
762
+ and self.denoising_end < 1
763
+ ):
764
+ discrete_timestep_cutoff = int(
765
+ round(
766
+ self.scheduler.config.num_train_timesteps
767
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
768
+ )
769
+ )
770
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
771
+ timesteps = timesteps[:num_inference_steps]
772
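+ # worked example: with denoising_end=0.8 and num_train_timesteps=1000 the cutoff is
+ # int(round(1000 - 0.8 * 1000)) = 200, so only timesteps >= 200 (roughly the first 80%
+ # of the schedule) are executed and the returned latents still carry the remaining noise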
+
773
+ # 9. Optionally get Guidance Scale Embedding
774
+ timestep_cond = None
775
+ if self.unet.config.time_cond_proj_dim is not None:
776
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size)
777
+ timestep_cond = self.get_guidance_scale_embedding(
778
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
779
+ ).to(device=device, dtype=latents.dtype)
780
+
781
+ self._num_timesteps = len(timesteps)
782
+ # with self.progress_bar(total=num_inference_steps) as progress_bar:
783
+ for i, t in enumerate(timesteps):
784
+ if self.interrupt:
785
+ continue
786
+
787
+ # expand the latents if we are doing classifier free guidance
788
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
789
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
790
+
791
+ if vq_embeds.size(-1) != latent_model_input.size(-1):
792
+ vq_embeds = torch.nn.functional.interpolate(vq_embeds.to(latent_model_input),
793
+ size=latent_model_input.shape[-2:])
794
+ else:
795
+ vq_embeds = vq_embeds.to(latent_model_input)
796
+ vq_embeds_input = torch.cat([torch.zeros_like(vq_embeds),
797
+ vq_embeds]) if self.do_classifier_free_guidance else vq_embeds
798
+
799
+ # predict the noise residual
800
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
801
+
802
+ latent_model_input = torch.cat([latent_model_input, vq_embeds_input], dim=1)
803
+ noise_pred = self.unet(
804
+ latent_model_input,
805
+ t,
806
+ encoder_hidden_states=prompt_embeds,
807
+ timestep_cond=timestep_cond,
808
+ cross_attention_kwargs=self.cross_attention_kwargs,
809
+ added_cond_kwargs=added_cond_kwargs,
810
+ return_dict=False,
811
+ )[0]
812
+
813
+ # perform guidance
814
+ if self.do_classifier_free_guidance:
815
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
816
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
817
+
818
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
819
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
820
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=self.guidance_rescale)
821
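+ # e.g. guidance_scale=2.0 moves twice the distance from noise_pred_uncond toward
+ # noise_pred_cond (one full step beyond the conditional prediction); rescale_noise_cfg
+ # then pulls the result's standard deviation back toward that of noise_pred_cond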
+
822
+ # compute the previous noisy sample x_t -> x_t-1
823
+ latents_dtype = latents.dtype
824
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
825
+ if latents.dtype != latents_dtype:
826
+ if torch.backends.mps.is_available():
827
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
828
+ latents = latents.to(latents_dtype)
829
+
830
+ if callback_on_step_end is not None:
831
+ callback_kwargs = {}
832
+ for k in callback_on_step_end_tensor_inputs:
833
+ callback_kwargs[k] = locals()[k]
834
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
835
+
836
+ latents = callback_outputs.pop("latents", latents)
837
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
838
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
839
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
840
+
841
+ # call the callback, if provided
842
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
843
+ # progress_bar.update()
844
+ if callback is not None and i % callback_steps == 0:
845
+ step_idx = i // getattr(self.scheduler, "order", 1)
846
+ callback(step_idx, t, latents)
847
+
848
+ if XLA_AVAILABLE:
849
+ xm.mark_step()
850
+
851
+ if not output_type == "latent":
852
+ # make sure the VAE is in float32 mode, as it overflows in float16
853
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
854
+
855
+ if needs_upcasting:
856
+ self.upcast_vae()
857
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
858
+ elif latents.dtype != self.vae.dtype:
859
+ if torch.backends.mps.is_available():
860
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
861
+ self.vae = self.vae.to(latents.dtype)
862
+
863
+ # unscale/denormalize the latents
864
+ # denormalize with the mean and std if available and not None
865
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
866
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
867
+ if has_latents_mean and has_latents_std:
868
+ latents_mean = (
869
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
870
+ )
871
+ latents_std = (
872
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
873
+ )
874
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
875
+ else:
876
+ latents = latents / self.vae.config.scaling_factor
877
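+ # the branch above inverts the encoder-side normalization
+ # latents = (x - latents_mean) * scaling_factor / latents_std (or x * scaling_factor)
+ # so the VAE decoder receives latents in its native range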
+
878
+ image = self.vae.decode(latents, return_dict=False)[0]
879
+
880
+ # cast back to fp16 if needed
881
+ if needs_upcasting:
882
+ self.vae.to(dtype=torch.float16)
883
+ else:
884
+ image = latents
885
+
886
+ if not output_type == "latent":
887
+ # apply watermark if available
888
+ if self.watermark is not None:
889
+ image = self.watermark.apply_watermark(image)
890
+
891
+ image = self.image_processor.postprocess(image, output_type=output_type)
892
+
893
+ # Offload all models
894
+ self.maybe_free_model_hooks()
895
+
896
+ if not return_dict:
897
+ return (image,)
898
+
899
+ return StableDiffusionXLDecoderPipelineOutput(images=image,
900
+ indices_semantic=indices_semantic,
901
+ indices_pixel=indices_pixel)
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
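
The special tokens above are registered automatically when the tokenizer is loaded; special_tokens_map.json sets the eos token to <|im_end|> and the pad token to <|endoftext|>. A minimal sketch, assuming the standard `transformers` AutoTokenizer API and a placeholder path for this repository:

from transformers import AutoTokenizer

# load the tokenizer shipped in this folder; the special tokens map supplies the
# additional markers (<|vision_start|>, <|image_pad|>, ...) plus the eos/pad tokens
tokenizer = AutoTokenizer.from_pretrained("path/to/this/repo")
print(tokenizer.eos_token, tokenizer.pad_token)             # <|im_end|> <|endoftext|>
print(tokenizer.convert_tokens_to_ids("<|vision_start|>"))  # integer id of a vision marker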
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efb35e4dde47fc87a13f21a10fc3a0ac50340bc53ad9d88999616403d8498216
3
+ size 33100531
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db9da21d556f5a362c7d22bfe8bf4d1bd734a303db356521eeb4c96a11881970
3
+ size 25814172
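
tokenizer.json and tokenizer_config.json are stored with Git LFS, so the diffs above show only the pointer files (version, oid, size) rather than the JSON content. A minimal sketch for fetching the real files with huggingface_hub, using a placeholder repo id:

from huggingface_hub import hf_hub_download

# resolve the LFS-backed files to local cache paths instead of the pointer stubs
tokenizer_json_path = hf_hub_download(repo_id="user/repo", filename="tokenizer.json")
tokenizer_config_path = hf_hub_download(repo_id="user/repo", filename="tokenizer_config.json")
print(tokenizer_json_path, tokenizer_config_path)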
vocab.json ADDED
The diff for this file is too large to render. See raw diff