Upload folder using huggingface_hub
- .gitattributes +2 -0
- added_tokens.json +0 -0
- aspect_ratio_utils.py +196 -0
- config.json +191 -0
- configuration_dualvitok.py +154 -0
- configuration_illume.py +140 -0
- configuration_movqgan.py +92 -0
- configuration_qwen2vit.py +249 -0
- generation_config.json +6 -0
- image_processing_dualvitok.py +24 -0
- image_processing_movqgan.py +429 -0
- image_utils.py +812 -0
- inference_utils.py +412 -0
- merges.txt +0 -0
- modeling_dualvitok.py +653 -0
- modeling_illume.py +883 -0
- modeling_movqgan.py +828 -0
- modeling_qwen2vit.py +841 -0
- modeling_rope_utils.py +561 -0
- processing_illume.py +329 -0
- pytorch_model-00001-of-00004.bin +3 -0
- pytorch_model-00002-of-00004.bin +3 -0
- pytorch_model-00003-of-00004.bin +3 -0
- pytorch_model-00004-of-00004.bin +3 -0
- pytorch_model.bin.index.json +889 -0
- sdxl_decoder_pipe.py +901 -0
- special_tokens_map.json +31 -0
- tokenizer.json +3 -0
- tokenizer_config.json +3 -0
- vocab.json +0 -0
.gitattributes
CHANGED

@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
+tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json
ADDED

The diff for this file is too large to render. See raw diff.
aspect_ratio_utils.py
ADDED

@@ -0,0 +1,196 @@
from torchvision import transforms
import numpy as np
import math
import torch
from PIL import Image, ImageOps


RATIOS = [
    (512, 512),
    (384, 512),
    (512, 384),
    (384, 768),
    (768, 384),
    (384, 576),
    (576, 384),
    (320, 960),
    (960, 320),
    (256, 1024),
    (1024, 256),
]

RATIO_TYPES = [
    'ratio_h512_w512',
    'ratio_h384_w512',
    'ratio_h512_w384',
    'ratio_h384_w768',
    'ratio_h768_w384',
    'ratio_h384_w576',
    'ratio_h576_w384',
    'ratio_h320_w960',
    'ratio_h960_w320',
    'ratio_h256_w1024',
    'ratio_h1024_w256',
]


def center_crop_and_resize(img, output_size=(256, 256)):
    target_h, target_w = output_size
    img_w, img_h = img.size

    scale_w, scale_h = img_w / target_w, img_h / target_h
    if scale_h > scale_w:
        new_w, new_h = target_w, int(target_w / img_w * img_h)
    else:
        new_w, new_h = int(target_h / img_h * img_w), target_h

    # Resize the image, keeping the aspect ratio
    img = img.resize((new_w, new_h), Image.LANCZOS)

    # Calculate the center cropping area
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h

    # Crop the extra part
    img = img.crop((left, top, right, bottom))

    return img


def resize_with_padding(img, output_size=(256, 256), fill_color=(0, 0, 0)):
    target_height, target_width = output_size

    # Step 1: Resize with aspect ratio preserved
    original_width, original_height = img.size
    ratio = min(target_width / original_width, target_height / original_height)
    new_size = (int(original_width * ratio), int(original_height * ratio))
    resized_image = img.resize(new_size, Image.LANCZOS)

    # Step 2: Add padding to reach target size
    delta_w = target_width - new_size[0]
    delta_h = target_height - new_size[1]
    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
    padded_image = ImageOps.expand(resized_image, padding, fill=fill_color)

    return padded_image


def unpad_and_resize_back(padded_image, original_width, original_height):
    """
    Revert the padded+resized image back to original size.

    Args:
        padded_image (PIL.Image): Image after padding.
        original_width (int): Original image width before resize & pad.
        original_height (int): Original image height before resize & pad.

    Returns:
        PIL.Image: Image resized back to original resolution.
    """
    # Compute the scale factor used during the first resize
    target_width, target_height = padded_image.size
    ratio = min(target_width / original_width, target_height / original_height)
    resized_w = int(original_width * ratio)
    resized_h = int(original_height * ratio)

    # Compute cropping box on padded image
    left = (target_width - resized_w) // 2
    upper = (target_height - resized_h) // 2
    right = left + resized_w
    lower = upper + resized_h

    # Crop out the resized region (before padding)
    cropped_image = padded_image.crop((left, upper, right, lower))

    # Resize back to original resolution
    recovered_image = cropped_image.resize((original_width, original_height), Image.LANCZOS)
    return recovered_image


def calculate_ratio():
    max_area = 512 * 512
    ratios = [(2, 2), (3, 4), (4, 3), (2, 4), (4, 2), (1, 4), (4, 1), (2, 3), (3, 2), (1, 3), (3, 1)]
    ratio_candidates = []
    for ratio in ratios:
        x = math.sqrt(max_area / ratio[0] / ratio[1])
        x = round(x / 64) * 64
        tmp = (x * ratio[0], x * ratio[1])
        # print(ratio, x, tmp)
        ratio_candidates.append(tmp)

    print("ratio_candidates", ratio_candidates)
    return ratio_candidates


class AspectRatioCrop(object):
    """
    Aspect Ratio Crop transform.
    For a given image, find the closest aspect ratio and
    resize / resize + crop to the corresponding base size.

    Args:
        base_sizes: list[tuple], the base sizes of the final output.
            For example, [(512, 512), (512, 768), (768, 512)].
        crop_percent_thresh: float. If the fraction of the image that would be cropped
            exceeds this threshold, the sample is flagged as unmatched (used to filter data).
    """

    def __init__(self, base_sizes, crop_percent_thresh=0.2):
        self.base_sizes = [(math.floor(h), math.floor(w)) for (h, w) in base_sizes]
        self.aspect_ratios = [x[1] / x[0] for x in self.base_sizes]  # w / h
        self.crop_percent_thresh = crop_percent_thresh

    def _find_size(self, w, h):
        base_size_indexes = list(range(len(self.base_sizes)))
        aspect_ratios = [self.aspect_ratios[i] for i in base_size_indexes]
        aspect_ratio = w / h
        ratio_diff = [abs(ratio - aspect_ratio) for ratio in aspect_ratios]
        min_diff = np.min(ratio_diff)
        match_diff_indexes = [j for j in range(len(ratio_diff)) if ratio_diff[j] == min_diff]
        match_diff_indexes = sorted(match_diff_indexes, key=lambda x: (h - self.base_sizes[base_size_indexes[x]][0]) ** 2
                                    + (w - self.base_sizes[base_size_indexes[x]][1]) ** 2)  # pick the size whose area matches best
        corr_index = base_size_indexes[match_diff_indexes[0]]
        return corr_index

    def get_pred_target_w_h(self, w, h):
        aspect_ratio = w / h
        aspect_index = self._find_size(w, h)
        pred_h, pred_w = self.base_sizes[aspect_index]

        solutions = [
            (pred_w, int(pred_w / aspect_ratio)),
            (int(pred_h * aspect_ratio), pred_h),
        ]
        w_tar = None
        h_tar = None
        for solution in solutions:
            w_s, h_s = solution
            if w_s >= pred_w and h_s >= pred_h:
                w_tar = w_s
                h_tar = h_s

        return pred_w, pred_h, w_tar, h_tar, aspect_index

    def __call__(self, image, is_inference=False):
        # step 1: find the closest aspect ratio
        flag_matched = True
        w, h = image.size
        pred_w, pred_h, w_tar, h_tar, aspect_index = self.get_pred_target_w_h(w, h)

        crop_percent = 1 - pred_w * pred_h / (w_tar * h_tar)
        if self.crop_percent_thresh > 0 and crop_percent > self.crop_percent_thresh:
            flag_matched = False  # filter data

        if not is_inference:
            # step 2 (training): crop and resize
            image = center_crop_and_resize(image, output_size=(pred_h, pred_w))
        else:
            # step 2 (inference): resize and pad
            image = resize_with_padding(image, output_size=(pred_h, pred_w))

        original_size = [h, w]
        target_size = [pred_h, pred_w]

        return image, original_size, target_size, flag_matched
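A minimal usage sketch of the transform above, assuming `aspect_ratio_utils.py` is on the Python path and its dependencies (torch, torchvision, numpy, Pillow) are installed; the input image is a synthetic placeholder:

from PIL import Image

from aspect_ratio_utils import RATIOS, AspectRatioCrop, unpad_and_resize_back

# Placeholder 1280x720 RGB image (aspect ratio ~1.78, closest base size is (384, 768)).
img = Image.new('RGB', (1280, 720))

crop = AspectRatioCrop(base_sizes=RATIOS)
# Inference mode: resize with aspect ratio preserved, then pad to the base size.
out, original_size, target_size, matched = crop(img, is_inference=True)
print(target_size, matched)  # [384, 768] True

# Undo the padding and return to the original resolution.
restored = unpad_and_resize_back(out, original_width=1280, original_height=720)
print(restored.size)  # (1280, 720)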
config.json
ADDED

@@ -0,0 +1,191 @@
{
  "architectures": [
    "ILLUMEForConditionalGeneration"
  ],
  "auto_map": {
    "AutoConfig": "configuration_illume.ILLUMEConfig",
    "AutoModel": "modeling_illume.ILLUMEForConditionalGeneration",
    "AutoModelForCausalLM": "modeling_illume.ILLUMEForConditionalGeneration"
  },
  "ignore_index": -100,
  "image_out_token_index": 282777,
  "image_token_index": 282776,
  "mm_projector_config": {
    "hidden_size": 3584,
    "mm_hidden_size": [
      3584,
      32
    ],
    "projector_cfg1": {
      "mlp_depth": 2,
      "type": "MLPProjector"
    },
    "projector_cfg2": {
      "mlp_depth": 2,
      "type": "MLPProjector"
    },
    "trainable": true,
    "type": "MixedProjector"
  },
  "model_type": "illume",
  "special_tokens_ids": {
    "<end_of_image>": 151666,
    "<end_of_level0>": 151669,
    "<end_of_level1>": 151671,
    "<end_of_line>": 151667,
    "<start_of_image>": 151665,
    "<start_of_level0>": 151668,
    "<start_of_level1>": 151670
  },
  "text_config": {
    "_name_or_path": "./Qwen2.5-7B-Instruct-with-vision-tokenizer-32k-96k-level2",
    "architectures": [
      "Qwen2ForCausalLM"
    ],
    "bos_token_id": 151643,
    "eos_token_id": 151645,
    "hidden_size": 3584,
    "intermediate_size": 18944,
    "model_type": "qwen2",
    "num_attention_heads": 28,
    "num_hidden_layers": 28,
    "num_key_value_heads": 4,
    "rope_theta": 1000000.0,
    "torch_dtype": "bfloat16",
    "vocab_size": 283175
  },
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.2",
  "vision_config": {
    "_name_or_path": "./dualvitok/",
    "architectures": [
      "DualViTok"
    ],
    "auto_map": {
      "AutoConfig": "configuration_dualvitok.DualViTokConfig",
      "AutoModel": "modeling_dualvitok.DualViTok"
    },
    "model_type": "DualViTok",
    "pixel_decoder": {
      "attn_resolutions": [
        4
      ],
      "ch": 384,
      "ch_mult": [
        1,
        1,
        2,
        2,
        4
      ],
      "codebook_size": 98304,
      "embed_dim": 64,
      "use_dc_up_down_blocks": true,
      "z_channels": 64
    },
    "pixel_encoder": {
      "attn_resolutions": [
        4
      ],
      "ch": 128,
      "ch_mult": [
        1,
        1,
        2,
        2,
        4
      ],
      "codebook_size": 98304,
      "embed_dim": 32,
      "use_dc_up_down_blocks": true,
      "z_channels": 32
    },
    "semantic_decoder": {
      "pretrained_semantic_encoder": "Emova-ollm/qwen2vit600m",
      "target_mlp": "norm"
    },
    "semantic_encoder": {
      "pretrained_semantic_encoder": {
        "_name_or_path": "Emova-ollm/qwen2vit600m",
        "add_cross_attention": false,
        "architectures": [
          "Qwen2VisionTransformerPretrainedModel"
        ],
        "bad_words_ids": null,
        "begin_suppress_tokens": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": null,
        "decoder_start_token_id": null,
        "depth": 32,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "early_stopping": false,
        "embed_dim": 1280,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "exponential_decay_length_penalty": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "hidden_act": "quick_gelu",
        "hidden_size": 3584,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "in_channels": 3,
        "in_chans": 3,
        "initializer_range": 0.02,
        "is_decoder": false,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "length_penalty": 1.0,
        "max_length": 20,
        "min_length": 0,
        "mlp_ratio": 4,
        "model_type": "qwen2_vl",
        "no_repeat_ngram_size": 0,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_heads": 16,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "output_scores": false,
        "pad_token_id": null,
        "patch_size": 14,
        "prefix": null,
        "problem_type": null,
        "pruned_heads": {},
        "remove_invalid_values": false,
        "repetition_penalty": 1.0,
        "return_dict": true,
        "return_dict_in_generate": false,
        "sep_token_id": null,
        "spatial_merge_size": 2,
        "spatial_patch_size": 14,
        "suppress_tokens": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "temporal_patch_size": 2,
        "tf_legacy_loss": false,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": "float32",
        "torchscript": false,
        "transformers_version": "4.44.2",
        "typical_p": 1.0,
        "use_bfloat16": false
      }
    },
    "torch_dtype": "float16"
  }
}
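Because `auto_map` routes `AutoConfig` and `AutoModelForCausalLM` to the remote-code classes in this repository, the checkpoint should be loadable with `trust_remote_code=True`. A minimal sketch; the local path is a placeholder and device placement is left to the caller:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

path = "./ILLUME-checkpoint"  # placeholder: local clone of this repository or its Hub id

config = AutoConfig.from_pretrained(path, trust_remote_code=True)
print(config.model_type, config.text_config.vocab_size)  # illume 283175

model = AutoModelForCausalLM.from_pretrained(
    path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # matches the "torch_dtype" stored in config.json
)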
configuration_dualvitok.py
ADDED

@@ -0,0 +1,154 @@
""" DualViTok model configuration """
import os
from typing import List, Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from .configuration_movqgan import MoVQConfig


logger = logging.get_logger(__name__)


class SemanticEncoderConfig(PretrainedConfig):
    model_type = "DualViTokSemanticEncoder"

    def __init__(
        self,
        pretrained_semantic_encoder='Emova-ollm/qwen2vit600m',
        z_channels=32,
        num_blocks=4,
        embed_dim=1280,
        out_layer='linear',
        target_mlp='norm',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.pretrained_semantic_encoder = pretrained_semantic_encoder
        self.z_channels = z_channels
        self.num_blocks = num_blocks
        self.out_layer = out_layer
        self.embed_dim = embed_dim
        self.target_mlp = target_mlp


class SemanticDecoderConfig(PretrainedConfig):
    model_type = "DualViTokSemanticDecoder"

    def __init__(
        self,
        z_channels=32,
        num_blocks=4,
        embed_dim=1280,
        out_layer='linear_norm',
        out_channels=3584,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.z_channels = z_channels
        self.num_blocks = num_blocks
        self.embed_dim = embed_dim
        self.out_layer = out_layer
        self.out_channels = out_channels


class DualViTokConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DualViTok`]. It is used to instantiate a DualViTok
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration similar to the VQ model presented in the paper.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        semantic_encoder (`dict`, *optional*):
            Configuration dictionary for the semantic encoder. If `None`, defaults to `SemanticEncoderConfig()`.
            The provided dictionary will be unpacked to initialize a `SemanticEncoderConfig` instance.
        semantic_decoder (`dict`, *optional*):
            Configuration dictionary for the semantic decoder. If `None`, defaults to `SemanticDecoderConfig()`.
            The provided dictionary will be unpacked to initialize a `SemanticDecoderConfig` instance.
        pixel_encoder (`dict`, *optional*):
            Configuration dictionary for the pixel pathway's VQ-VAE model (e.g., `MoVQConfig`). If `None`, defaults to `MoVQConfig()`.
            The provided dictionary will be unpacked to initialize a `MoVQConfig` instance, which defines both encoder and decoder for pixel-level features.
        semantic_quantizer_type (`str`, *optional*, defaults to `'simvq'`):
            Type of the quantizer for semantic tokens (e.g., `'simvq'`, `'ema_simvq'`).
        pixel_quantizer_type (`str`, *optional*, defaults to `'simvq'`):
            Type of the quantizer for pixel tokens (e.g., `'simvq'`, `'ema_simvq'`).
        semantic_quantizer_codebook_size (`int`, *optional*, defaults to 32768):
            Number of entries in the codebook for the semantic quantizer.
        pixel_quantizer_codebook_size (`int`, *optional*, defaults to 98304):
            Number of entries in the codebook for the pixel quantizer.
        attn_implementation (`str`, *optional*, defaults to `'sdpa'`):
            The attention implementation to use. Can be `'sdpa'` (scaled dot product attention),
            `'flash_attention_2'` (if available and installed), or `'eager'` (the default PyTorch attention implementation).

    ```python
    >>> from transformers import DualViTok, DualViTokConfig

    >>> # Initializing a DualViTok configuration
    >>> configuration = DualViTokConfig()

    >>> # Initializing a model from the configuration
    >>> model = DualViTok(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "DualViTok"

    def __init__(
        self,
        semantic_encoder=None,
        semantic_decoder=None,
        pixel_encoder=None,
        pixel_decoder=None,
        semantic_quantizer_type='simvq',
        pixel_quantizer_type='simvq',
        semantic_quantizer_codebook_size=32768,
        pixel_quantizer_codebook_size=98304,
        attn_implementation='sdpa',
        **kwargs,
    ):
        super().__init__(**kwargs)
        if semantic_encoder is None:
            self.semantic_encoder = SemanticEncoderConfig()
        else:
            self.semantic_encoder = SemanticEncoderConfig(**semantic_encoder)

        if semantic_decoder is None:
            self.semantic_decoder = SemanticDecoderConfig()
        else:
            self.semantic_decoder = SemanticDecoderConfig(**semantic_decoder)

        self.semantic_quantizer_type = semantic_quantizer_type
        self.pixel_quantizer_type = pixel_quantizer_type
        self.semantic_quantizer_codebook_size = semantic_quantizer_codebook_size
        self.pixel_quantizer_codebook_size = pixel_quantizer_codebook_size

        if pixel_encoder is None:
            self.pixel_encoder = MoVQConfig()
        else:
            self.pixel_encoder = MoVQConfig(**pixel_encoder)

        self.pixel_decoder = self.pixel_encoder if pixel_decoder is None else MoVQConfig(**pixel_decoder)

        self.attn_implementation = attn_implementation

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], attn_implementation='sdpa', **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, attn_implementation=attn_implementation, **kwargs)
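A short sketch of how the nested dictionaries from `config.json` are consumed by this class. It assumes a local copy of the repository is importable as a package (the module uses relative imports); `illume_repo` is a hypothetical package name:

from illume_repo.configuration_dualvitok import DualViTokConfig

# Nested dicts are unpacked into the sub-configs; omitted fields keep their defaults.
config = DualViTokConfig(
    semantic_encoder={"pretrained_semantic_encoder": "Emova-ollm/qwen2vit600m", "z_channels": 32},
    pixel_encoder={"codebook_size": 98304, "embed_dim": 32, "z_channels": 32, "ch": 128},
)
print(type(config.semantic_encoder).__name__)        # SemanticEncoderConfig
print(config.pixel_decoder is config.pixel_encoder)  # True: pixel_decoder falls back to pixel_encoder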
configuration_illume.py
ADDED

@@ -0,0 +1,140 @@
"""ILLUME model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING

# Import the first three so that the last one can resolve them.
from .modeling_rope_utils import rope_config_validation
from .configuration_movqgan import MoVQConfig
from .configuration_qwen2vit import Qwen2VLVisionConfig
from .configuration_dualvitok import DualViTokConfig

logger = logging.get_logger(__name__)


class ILLUMEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ILLUMEForConditionalGeneration`]. It is used to instantiate an
    ILLUME model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `DualViTokConfig`):
            The config object or dictionary of the vision backbone.
        mm_projector_config (`dict`, *optional*, defaults to `None`):
            Configuration for the multimodal projector.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
            The config object or dictionary of the text backbone.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.

    Example:

    ```python
    >>> from transformers import ILLUMEForConditionalGeneration, ILLUMEConfig, CLIPVisionConfig, LlamaConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a ILLUME style configuration
    >>> configuration = ILLUMEConfig(vision_config, text_config)

    >>> # Initializing a model from the style configuration
    >>> model = ILLUMEForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "illume"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        mm_projector_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index

        if isinstance(vision_config, dict):
            vision_config = DualViTokConfig(**vision_config)
        elif vision_config is None:
            vision_config = DualViTokConfig(**{
                "semantic_encoder": {
                    "pretrained_semantic_encoder": "Emova-ollm/qwen2vit600m",
                    "z_channels": 32,
                    "num_blocks": 4,
                    "out_layer": "linear",
                    "embed_dim": 1280,
                    "target_mlp": "norm"
                },
                "semantic_decoder": {
                    "z_channels": 32,
                    "num_blocks": 4,
                    "embed_dim": 1280,
                    "out_layer": "linear_norm",
                    "out_channels": 3584
                },
                "semantic_quantizer_type": "simvq",
                "pixel_quantizer_type": "simvq",
                "semantic_quantizer_codebook_size": 32768,
                "pixel_quantizer_codebook_size": 98304,
                "attn_implementation": "sdpa",
                "pixel_encoder": {
                    "codebook_size": 98304,
                    "embed_dim": 32,
                    "z_channels": 32,
                    "double_z": False,
                    "in_channels": 3,
                    "out_channels": 3,
                    "ch": 128,
                    "ch_mult": [1, 1, 2, 2, 4],
                    "num_res_blocks": 2,
                    "attn_resolutions": [4],
                    "dropout": 0.0,
                    "use_dc_up_down_blocks": True
                },
                "pixel_decoder": {
                    "codebook_size": 98304,
                    "embed_dim": 64,
                    "z_channels": 64,
                    "double_z": False,
                    "in_channels": 3,
                    "out_channels": 3,
                    "ch": 384,
                    "ch_mult": [1, 1, 2, 2, 4],
                    "num_res_blocks": 2,
                    "attn_resolutions": [4],
                    "dropout": 0.0,
                    "use_dc_up_down_blocks": True
                },
            })

        self.vision_config = vision_config
        self.mm_projector_config = mm_projector_config
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "qwen2"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["qwen2"]()

        self.text_config = text_config
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
configuration_movqgan.py
ADDED

@@ -0,0 +1,92 @@
""" MoVQ model configuration """

from typing import List

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class MoVQConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MoVQ`]. It is used to instantiate a MoVQ
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration similar to the VQ model presented in the paper.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        codebook_size (`int`, *optional*, defaults to 32768):
            Codebook size of the VQ model.
        embed_dim (`int`, *optional*, defaults to 4):
            Dimension of the quantized vectors in the codebook.
        z_channels (`int`, *optional*, defaults to 4):
            Dimension of the output channel of the encoder and the input channel of the decoder.
        double_z (`bool`, *optional*, defaults to `False`):
            Whether to double the output dim of the encoder.
        in_channels (`int`, *optional*, defaults to 3):
            Input channels of the encoder.
        out_channels (`int`, *optional*, defaults to 3):
            Output channels of the decoder.
        ch (`int`, *optional*, defaults to 256):
            Base channel count of the intermediate blocks.
        ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
            Channel scaling factors of the intermediate blocks.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Number of residual blocks in each stage.
        attn_resolutions (`List[int]`, *optional*, defaults to `[3]`):
            Stage indices at which to apply attention.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability.
        use_dc_up_down_blocks (`bool`, *optional*, defaults to `False`):
            Whether to use the DC up/down blocks.

    ```python
    >>> from transformers import MoVQ, MoVQConfig

    >>> # Initializing a VQ model configuration
    >>> configuration = MoVQConfig()

    >>> # Initializing a model from the VQ model style configuration
    >>> model = MoVQModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "MoVQ"

    def __init__(
        self,
        codebook_size: int = 32768,
        embed_dim: int = 4,
        z_channels: int = 4,
        double_z: bool = False,
        in_channels: int = 3,
        out_channels: int = 3,
        ch: int = 256,
        ch_mult: List[int] = [1, 2, 2, 4],
        num_res_blocks: int = 2,
        attn_resolutions: List[int] = [3],
        dropout: float = 0.0,
        use_dc_up_down_blocks=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.codebook_size = codebook_size
        self.embed_dim = embed_dim
        self.z_channels = z_channels
        self.double_z = double_z
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ch = ch
        self.ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.dropout = dropout
        self.use_dc_up_down_blocks = use_dc_up_down_blocks
configuration_qwen2vit.py
ADDED

@@ -0,0 +1,249 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""

import os
from typing import Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from .modeling_rope_utils import rope_config_validation

logger = logging.get_logger(__name__)


class Qwen2VLVisionConfig(PretrainedConfig):
    model_type = "qwen2_vl"

    def __init__(
        self,
        depth=32,
        embed_dim=1280,
        hidden_size=3584,
        hidden_act="quick_gelu",
        mlp_ratio=4,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        temporal_patch_size=2,
        attn_implementation='eager',
        init_weights=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.attn_implementation = attn_implementation if attn_implementation else 'eager'

        self.init_weights = init_weights

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # if config_dict.get("model_type") == "qwen2_vl":
        #     config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class Qwen2VLConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
    Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 152064):
            Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2VLModel`]
        hidden_size (`int`, *optional*, defaults to 8192):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 29568):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 80):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 80):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        vision_config (`Dict`, *optional*):
            The config for the visual encoder initialization.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE

    ```python
    >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig

    >>> # Initializing a Qwen2VL style configuration
    >>> configuration = Qwen2VLConfig()

    >>> # Initializing a model from the Qwen2-VL-7B style configuration
    >>> model = Qwen2VLForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_vl"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = Qwen2VLVisionConfig(**vision_config)
        elif vision_config is None:
            self.vision_config = Qwen2VLVisionConfig()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'
        # and change type from 'mrope' to 'default'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
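For illustration, a small sketch of the `rope_scaling` backward-compatibility step in the constructor above, reusing the import shown in the class docstring (it requires a transformers release that ships Qwen2-VL); the dict value is hypothetical:

from transformers import Qwen2VLConfig

# A legacy-style dict with a 'type' key: 'mrope' is rewritten to 'default' and the
# value is mirrored into 'rope_type' before rope_config_validation runs.
config = Qwen2VLConfig(rope_scaling={"type": "mrope"})
print(config.rope_scaling["rope_type"])  # default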
generation_config.json
ADDED

@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "transformers_version": "4.44.2"
}
image_processing_dualvitok.py
ADDED

@@ -0,0 +1,24 @@
"""Image processor class for DualViTok."""

from transformers.utils import TensorType, is_vision_available, logging

from .image_processing_movqgan import MoVQImageProcessor

logger = logging.get_logger(__name__)


class DualViTokImageProcessor(MoVQImageProcessor):
    r"""
    Constructs a DualViTok image processor that dynamically resizes images based on the original images.
    This image processor is based on MoVQImageProcessor with spatial_factor of 16.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        *args,
        spatial_factor: int = 16,
        **kwargs,
    ) -> None:
        super().__init__(*args, spatial_factor=spatial_factor, **kwargs)
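A usage sketch under two assumptions: the repository is importable as a package (the relative import above requires it; `illume_repo` is a hypothetical package name), and the preprocessing entry point inherited from `MoVQImageProcessor` follows the standard `BaseImageProcessor` calling convention:

from PIL import Image

from illume_repo.image_processing_dualvitok import DualViTokImageProcessor

processor = DualViTokImageProcessor()  # spatial_factor defaults to 16
image = Image.new("RGB", (640, 480))   # placeholder input image

# BaseImageProcessor.__call__ dispatches to preprocess(); the resized height and
# width are expected to come out as multiples of spatial_factor.
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)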
image_processing_movqgan.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image processor class for MoVQ."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from typing import Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 10 |
+
from transformers.image_transforms import (
|
| 11 |
+
convert_to_rgb,
|
| 12 |
+
resize,
|
| 13 |
+
to_channel_dimension_format,
|
| 14 |
+
)
|
| 15 |
+
from transformers.image_utils import (
|
| 16 |
+
IMAGENET_STANDARD_MEAN,
|
| 17 |
+
IMAGENET_STANDARD_STD,
|
| 18 |
+
ChannelDimension,
|
| 19 |
+
ImageInput,
|
| 20 |
+
PILImageResampling,
|
| 21 |
+
get_image_size,
|
| 22 |
+
infer_channel_dimension_format,
|
| 23 |
+
is_scaled_image,
|
| 24 |
+
make_list_of_images,
|
| 25 |
+
to_numpy_array,
|
| 26 |
+
valid_images,
|
| 27 |
+
validate_preprocess_arguments,
|
| 28 |
+
)
|
| 29 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if is_vision_available():
|
| 36 |
+
from PIL import Image
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def smart_resize(
|
| 40 |
+
height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024
|
| 41 |
+
):
|
| 42 |
+
"""Rescales the image so that the following conditions are met:
|
| 43 |
+
|
| 44 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 45 |
+
|
| 46 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 47 |
+
|
| 48 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
| 49 |
+
|
| 50 |
+
"""
|
| 51 |
+
# if height < factor or width < factor:
|
| 52 |
+
# raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
|
| 53 |
+
# elif max(height, width) / min(height, width) > 5:
|
| 54 |
+
# raise ValueError(
|
| 55 |
+
# f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}"
|
| 56 |
+
# )
|
| 57 |
+
|
| 58 |
+
h_bar = round(height / factor) * factor
|
| 59 |
+
w_bar = round(width / factor) * factor
|
| 60 |
+
if h_bar * w_bar > max_pixels:
|
| 61 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 62 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 63 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 64 |
+
elif h_bar * w_bar < min_pixels:
|
| 65 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 66 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 67 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 68 |
+
|
| 69 |
+
return max(h_bar, factor), max(w_bar, factor)
|
| 70 |
+
|
| 71 |
+
|
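# Editor's note (not part of the uploaded file): two hand-worked calls against the defaults
# above (factor=8, min_pixels=512*512, max_pixels=1024*1024), to make the rounding and the
# beta rescaling concrete.
#
#   smart_resize(700, 500)   -> (704, 496)    # pixel count is in range; each side just rounds
#                                             # to the nearest multiple of 8
#   smart_resize(2048, 2048) -> (1024, 1024)  # 4x over max_pixels, so beta = 2 and both sides
#                                             # are halved (still multiples of 8)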
| 72 |
+
class MoVQImageProcessor(BaseImageProcessor):
|
| 73 |
+
r"""
|
| 74 |
+
Constructs a MoVQ image processor that dynamically resizes images based on the original images.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 78 |
+
Whether to resize the image's (height, width) dimensions.
|
| 79 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 80 |
+
Resampling filter to use when resizing the image.
|
| 81 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 82 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
| 83 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 84 |
+
Scale factor to use if rescaling the image.
|
| 85 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 86 |
+
Whether to normalize the image.
|
| 87 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 88 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 89 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
|
| 90 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 91 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 92 |
+
Whether to convert the image to RGB.
|
| 93 |
+
min_pixels (`int`, *optional*, defaults to `32 * 32`):
|
| 94 |
+
The minimum total number of pixels allowed in the resized image.
|
| 95 |
+
max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
|
| 96 |
+
The maximum total number of pixels allowed in the resized image.
|
| 97 |
+
spatial_factor (`int`, *optional*, defaults to 8):
|
| 98 |
+
The spatial downsampling factor by which the image is reduced during the feature-extraction phase.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
model_input_names = ["pixel_values"]
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
do_resize: bool = True,
|
| 106 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 107 |
+
do_rescale: bool = True,
|
| 108 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 109 |
+
do_normalize: bool = True,
|
| 110 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 111 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 112 |
+
do_convert_rgb: bool = True,
|
| 113 |
+
min_pixels: int = 32 * 32,
|
| 114 |
+
max_pixels: int = 1024 * 1024,
|
| 115 |
+
spatial_factor: int = 8,
|
| 116 |
+
**kwargs,
|
| 117 |
+
) -> None:
|
| 118 |
+
super().__init__(**kwargs)
|
| 119 |
+
self.do_resize = do_resize
|
| 120 |
+
self.resample = resample
|
| 121 |
+
self.do_rescale = do_rescale
|
| 122 |
+
self.rescale_factor = rescale_factor
|
| 123 |
+
self.do_normalize = do_normalize
|
| 124 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 125 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 126 |
+
self.min_pixels = min_pixels
|
| 127 |
+
self.max_pixels = max_pixels
|
| 128 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
| 129 |
+
self.do_convert_rgb = do_convert_rgb
|
| 130 |
+
self.spatial_factor = spatial_factor
|
| 131 |
+
|
| 132 |
+
def _preprocess(
|
| 133 |
+
self,
|
| 134 |
+
images: ImageInput,
|
| 135 |
+
do_resize: Optional[bool] = None,
|
| 136 |
+
resample: PILImageResampling = None,
|
| 137 |
+
do_rescale: Optional[bool] = None,
|
| 138 |
+
rescale_factor: Optional[float] = None,
|
| 139 |
+
do_normalize: Optional[bool] = None,
|
| 140 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 141 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 142 |
+
do_convert_rgb: Optional[bool] = None,
|
| 143 |
+
spatial_factor: Optional[int] = None,
|
| 144 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 145 |
+
output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
|
| 146 |
+
):
|
| 147 |
+
"""
|
| 148 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
images (`ImageInput`):
|
| 152 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
| 153 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 154 |
+
Whether to resize the image.
|
| 155 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 156 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
| 157 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 158 |
+
Whether to rescale the image.
|
| 159 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 160 |
+
Scale factor to use if rescaling the image.
|
| 161 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 162 |
+
Whether to normalize the image.
|
| 163 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 164 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 165 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 166 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 167 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 168 |
+
Whether to convert the image to RGB.
|
| 169 |
+
spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
|
| 170 |
+
The spatial downsampling factor by which the image is reduced during the feature-extraction phase.
|
| 171 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 172 |
+
The channel dimension format for the input image. Can be one of:
|
| 173 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 174 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 175 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 176 |
+
output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 177 |
+
The channel dimension format for the output image. Can be one of:
|
| 178 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 179 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 180 |
+
- Unset: Use the channel dimension format of the input image.
|
| 181 |
+
"""
|
| 182 |
+
spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
|
| 183 |
+
|
| 184 |
+
images = make_list_of_images(images)
|
| 185 |
+
if do_convert_rgb:
|
| 186 |
+
images = [convert_to_rgb(image) for image in images]
|
| 187 |
+
|
| 188 |
+
# All transformations expect numpy arrays.
|
| 189 |
+
images = [to_numpy_array(image) for image in images]
|
| 190 |
+
|
| 191 |
+
if is_scaled_image(images[0]) and do_rescale:
|
| 192 |
+
logger.warning_once(
|
| 193 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 194 |
+
"pixel_values.append()images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
if input_data_format is None:
|
| 198 |
+
# We assume that all images have the same channel dimension format.
|
| 199 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 200 |
+
|
| 201 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
| 202 |
+
resized_height, resized_width = height, width
|
| 203 |
+
processed_images = []
|
| 204 |
+
for image in images:
|
| 205 |
+
if do_resize:
|
| 206 |
+
resized_height, resized_width = smart_resize(
|
| 207 |
+
height,
|
| 208 |
+
width,
|
| 209 |
+
factor=spatial_factor,
|
| 210 |
+
min_pixels=self.min_pixels,
|
| 211 |
+
max_pixels=self.max_pixels,
|
| 212 |
+
)
|
| 213 |
+
image = resize(
|
| 214 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if do_rescale:
|
| 218 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
| 219 |
+
|
| 220 |
+
if do_normalize:
|
| 221 |
+
image = self.normalize(
|
| 222 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format)
|
| 226 |
+
processed_images.append(image)
|
| 227 |
+
|
| 228 |
+
image = np.array(processed_images)
|
| 229 |
+
return image
|
| 230 |
+
|
| 231 |
+
def preprocess(
|
| 232 |
+
self,
|
| 233 |
+
images: ImageInput,
|
| 234 |
+
do_resize: Optional[bool] = None,
|
| 235 |
+
resample: PILImageResampling = None,
|
| 236 |
+
do_rescale: Optional[bool] = None,
|
| 237 |
+
rescale_factor: Optional[float] = None,
|
| 238 |
+
do_normalize: Optional[bool] = None,
|
| 239 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 240 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 241 |
+
do_convert_rgb: Optional[bool] = None,
|
| 242 |
+
spatial_factor: Optional[int] = None,
|
| 243 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 244 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 245 |
+
output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
|
| 246 |
+
):
|
| 247 |
+
"""
|
| 248 |
+
Args:
|
| 249 |
+
images (`ImageInput`):
|
| 250 |
+
Image to preprocess. Expects a single image or a batch of images with pixel values ranging from 0 to 255. If
|
| 251 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 252 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 253 |
+
Whether to resize the image.
|
| 254 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 255 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 256 |
+
has an effect if `do_resize` is set to `True`.
|
| 257 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 258 |
+
Whether to rescale the image.
|
| 259 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 260 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 261 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 262 |
+
Whether to normalize the image.
|
| 263 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 264 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 265 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 266 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 267 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 268 |
+
Whether to convert the image to RGB.
|
| 269 |
+
spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
|
| 270 |
+
The spatial downsampling factor by which the image is reduced during the feature-extraction phase.
|
| 271 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 272 |
+
The type of tensors to return. Can be one of:
|
| 273 |
+
- Unset: Return a list of `np.ndarray`.
|
| 274 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 275 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 276 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 277 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 278 |
+
from the input image. Can be one of:
|
| 279 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 280 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 281 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 282 |
+
output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 283 |
+
The channel dimension format for the output image. Can be one of:
|
| 284 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 285 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 286 |
+
- Unset: Use the channel dimension format of the input image.
|
| 287 |
+
"""
|
| 288 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 289 |
+
resample = resample if resample is not None else self.resample
|
| 290 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 291 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 292 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 293 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 294 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 295 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 296 |
+
spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
|
| 297 |
+
|
| 298 |
+
images = make_list_of_images(images)
|
| 299 |
+
if images is None or not valid_images(images):
|
| 300 |
+
raise ValueError(
|
| 301 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 302 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
validate_preprocess_arguments(
|
| 306 |
+
rescale_factor=rescale_factor,
|
| 307 |
+
do_normalize=do_normalize,
|
| 308 |
+
image_mean=image_mean,
|
| 309 |
+
image_std=image_std,
|
| 310 |
+
do_resize=do_resize,
|
| 311 |
+
size=self.size,
|
| 312 |
+
resample=resample,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
pixel_values = []
|
| 316 |
+
for image in images:
|
| 317 |
+
norm_image = self._preprocess(
|
| 318 |
+
image,
|
| 319 |
+
do_resize=do_resize,
|
| 320 |
+
resample=resample,
|
| 321 |
+
do_rescale=do_rescale,
|
| 322 |
+
rescale_factor=rescale_factor,
|
| 323 |
+
do_normalize=do_normalize,
|
| 324 |
+
image_mean=image_mean,
|
| 325 |
+
image_std=image_std,
|
| 326 |
+
do_convert_rgb=do_convert_rgb,
|
| 327 |
+
spatial_factor=spatial_factor,
|
| 328 |
+
input_data_format=input_data_format,
|
| 329 |
+
output_data_format=output_data_format,
|
| 330 |
+
)
|
| 331 |
+
pixel_values.extend(norm_image)
|
| 332 |
+
|
| 333 |
+
pixel_values = np.array(pixel_values)
|
| 334 |
+
data = {"pixel_values": pixel_values}
|
| 335 |
+
|
| 336 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 337 |
+
|
| 338 |
+
def postprocess(
|
| 339 |
+
self,
|
| 340 |
+
images: ImageInput,
|
| 341 |
+
do_rescale: Optional[bool] = None,
|
| 342 |
+
rescale_factor: Optional[float] = None,
|
| 343 |
+
do_normalize: Optional[bool] = None,
|
| 344 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 345 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 346 |
+
return_tensors: Optional[Union[str, TensorType]] = "PIL.Image.Image",
|
| 347 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 348 |
+
):
|
| 349 |
+
"""
|
| 350 |
+
Postprocess an image tensor or a batch of image tensors; this is the inverse of `preprocess`.
|
| 351 |
+
The parameters should be the same as in `preprocess`.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
images (`ImageInput`):
|
| 355 |
+
Image to postprocess. Expects a single image or a batch of images with pixel values ranging from -1 to 1.
|
| 356 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 357 |
+
Whether to rescale the image.
|
| 358 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 359 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 360 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 361 |
+
Whether to normalize the image.
|
| 362 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 363 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 364 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 365 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 366 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 367 |
+
The type of tensors to return. Can be one of:
|
| 368 |
+
- Unset: Return a list of `np.ndarray`.
|
| 369 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 370 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 371 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 372 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 373 |
+
from the input image. Can be one of:
|
| 374 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 375 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 376 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 377 |
+
"""
|
| 378 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 379 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 380 |
+
rescale_factor = 1 / rescale_factor
|
| 381 |
+
|
| 382 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 383 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 384 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 385 |
+
image_mean, image_std = self.inverse_meanstd(image_mean, image_std)
|
| 386 |
+
|
| 387 |
+
images = make_list_of_images(images)
|
| 388 |
+
if isinstance(images[0], Image.Image):
|
| 389 |
+
return images if len(images) > 1 else images[0]
|
| 390 |
+
|
| 391 |
+
if input_data_format is None:
|
| 392 |
+
# We assume that all images have the same channel dimension format.
|
| 393 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 394 |
+
|
| 395 |
+
pixel_values = []
|
| 396 |
+
for image in images:
|
| 397 |
+
image = to_numpy_array(image)
|
| 398 |
+
if do_normalize:
|
| 399 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 400 |
+
|
| 401 |
+
if do_rescale:
|
| 402 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
| 403 |
+
image = image.clip(0, 255).astype(np.uint8)
|
| 404 |
+
|
| 405 |
+
if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
|
| 406 |
+
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
|
| 407 |
+
pixel_values.append(Image.fromarray(image))
|
| 408 |
+
else:
|
| 409 |
+
pixel_values.extend(image)
|
| 410 |
+
|
| 411 |
+
data = {"pixel_values": pixel_values}
|
| 412 |
+
return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
|
| 413 |
+
|
| 414 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 415 |
+
|
| 416 |
+
def inverse_meanstd(self, image_mean, image_std):
|
| 417 |
+
image_mean = self.to_tuple(image_mean)
|
| 418 |
+
image_std = self.to_tuple(image_std)
|
| 419 |
+
|
| 420 |
+
rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std))
|
| 421 |
+
rev_image_std = tuple(1 / s for s in image_std)
|
| 422 |
+
|
| 423 |
+
return rev_image_mean, rev_image_std
|
| 424 |
+
|
| 425 |
+
def to_tuple(self, value, dim=3):
|
| 426 |
+
if isinstance(value, (int, float)):
|
| 427 |
+
return (value,) * dim
|
| 428 |
+
|
| 429 |
+
return tuple(value)
|
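A short round-trip sketch for MoVQImageProcessor (an editorial aside, not part of the uploaded files). The direct module import is an assumption; in practice the processor would be loaded through the repository's processing classes.

from PIL import Image
import numpy as np

from image_processing_movqgan import MoVQImageProcessor  # assumed import path

processor = MoVQImageProcessor()                         # spatial_factor defaults to 8
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

inputs = processor(images=image, return_tensors="pt")    # resized, rescaled and normalized to [-1, 1]
print(inputs["pixel_values"].shape)                      # (1, 3, 480, 640): sides are already multiples of 8

# postprocess inverts the normalization and rescaling and can hand back PIL images
restored = processor.postprocess(inputs["pixel_values"], return_tensors="PIL.Image.Image")
print(restored["pixel_values"][0].size)                  # (640, 480)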
image_utils.py
ADDED
|
@@ -0,0 +1,812 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import base64
|
| 17 |
+
import os
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import requests
|
| 23 |
+
from packaging import version
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
from transformers.utils import (
|
| 27 |
+
ExplicitEnum,
|
| 28 |
+
is_jax_tensor,
|
| 29 |
+
is_numpy_array,
|
| 30 |
+
is_tf_tensor,
|
| 31 |
+
is_torch_available,
|
| 32 |
+
is_torch_tensor,
|
| 33 |
+
is_torchvision_available,
|
| 34 |
+
is_vision_available,
|
| 35 |
+
logging,
|
| 36 |
+
requires_backends,
|
| 37 |
+
to_numpy,
|
| 38 |
+
)
|
| 39 |
+
from transformers.utils.constants import ( # noqa: F401
|
| 40 |
+
IMAGENET_DEFAULT_MEAN,
|
| 41 |
+
IMAGENET_DEFAULT_STD,
|
| 42 |
+
IMAGENET_STANDARD_MEAN,
|
| 43 |
+
IMAGENET_STANDARD_STD,
|
| 44 |
+
OPENAI_CLIP_MEAN,
|
| 45 |
+
OPENAI_CLIP_STD,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if is_vision_available():
|
| 50 |
+
import PIL.Image
|
| 51 |
+
import PIL.ImageOps
|
| 52 |
+
|
| 53 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
| 54 |
+
PILImageResampling = PIL.Image.Resampling
|
| 55 |
+
else:
|
| 56 |
+
PILImageResampling = PIL.Image
|
| 57 |
+
|
| 58 |
+
if is_torchvision_available():
|
| 59 |
+
from torchvision.transforms import InterpolationMode
|
| 60 |
+
|
| 61 |
+
pil_torch_interpolation_mapping = {
|
| 62 |
+
PILImageResampling.NEAREST: InterpolationMode.NEAREST,
|
| 63 |
+
PILImageResampling.BOX: InterpolationMode.BOX,
|
| 64 |
+
PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
|
| 65 |
+
PILImageResampling.HAMMING: InterpolationMode.HAMMING,
|
| 66 |
+
PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
|
| 67 |
+
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if TYPE_CHECKING:
|
| 72 |
+
if is_torch_available():
|
| 73 |
+
import torch
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
logger = logging.get_logger(__name__)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
ImageInput = Union[
|
| 80 |
+
"PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
|
| 81 |
+
] # noqa
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
VideoInput = Union[
|
| 85 |
+
List["PIL.Image.Image"],
|
| 86 |
+
"np.ndarray",
|
| 87 |
+
"torch.Tensor",
|
| 88 |
+
List["np.ndarray"],
|
| 89 |
+
List["torch.Tensor"],
|
| 90 |
+
List[List["PIL.Image.Image"]],
|
| 91 |
+
List[List["np.ndarrray"]],
|
| 92 |
+
List[List["torch.Tensor"]],
|
| 93 |
+
] # noqa
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ChannelDimension(ExplicitEnum):
|
| 97 |
+
FIRST = "channels_first"
|
| 98 |
+
LAST = "channels_last"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class AnnotationFormat(ExplicitEnum):
|
| 102 |
+
COCO_DETECTION = "coco_detection"
|
| 103 |
+
COCO_PANOPTIC = "coco_panoptic"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class AnnotionFormat(ExplicitEnum):
|
| 107 |
+
COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
|
| 108 |
+
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def is_pil_image(img):
|
| 115 |
+
return is_vision_available() and isinstance(img, PIL.Image.Image)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ImageType(ExplicitEnum):
|
| 119 |
+
PIL = "pillow"
|
| 120 |
+
TORCH = "torch"
|
| 121 |
+
NUMPY = "numpy"
|
| 122 |
+
TENSORFLOW = "tensorflow"
|
| 123 |
+
JAX = "jax"
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_image_type(image):
|
| 127 |
+
if is_pil_image(image):
|
| 128 |
+
return ImageType.PIL
|
| 129 |
+
if is_torch_tensor(image):
|
| 130 |
+
return ImageType.TORCH
|
| 131 |
+
if is_numpy_array(image):
|
| 132 |
+
return ImageType.NUMPY
|
| 133 |
+
if is_tf_tensor(image):
|
| 134 |
+
return ImageType.TENSORFLOW
|
| 135 |
+
if is_jax_tensor(image):
|
| 136 |
+
return ImageType.JAX
|
| 137 |
+
raise ValueError(f"Unrecognised image type {type(image)}")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def is_valid_image(img):
|
| 141 |
+
return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def valid_images(imgs):
|
| 145 |
+
# If we have a list of images, make sure every image is valid
|
| 146 |
+
if isinstance(imgs, (list, tuple)):
|
| 147 |
+
for img in imgs:
|
| 148 |
+
if not valid_images(img):
|
| 149 |
+
return False
|
| 150 |
+
# If not a list or tuple, we have been given a single image or a batched tensor of images
|
| 151 |
+
elif not is_valid_image(imgs):
|
| 152 |
+
return False
|
| 153 |
+
return True
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def is_batched(img):
|
| 157 |
+
if isinstance(img, (list, tuple)):
|
| 158 |
+
return is_valid_image(img[0])
|
| 159 |
+
return False
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def is_scaled_image(image: np.ndarray) -> bool:
|
| 163 |
+
"""
|
| 164 |
+
Checks to see whether the pixel values have already been rescaled to [0, 1].
|
| 165 |
+
"""
|
| 166 |
+
if image.dtype == np.uint8:
|
| 167 |
+
return False
|
| 168 |
+
|
| 169 |
+
# It's possible the image has pixel values in [0, 255] but is of floating type
|
| 170 |
+
return np.min(image) >= 0 and np.max(image) <= 1
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
|
| 174 |
+
"""
|
| 175 |
+
Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
|
| 176 |
+
If the input is a batch of images, it is converted to a list of images.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
images (`ImageInput`):
|
| 180 |
+
Image or batch of images to turn into a list of images.
|
| 181 |
+
expected_ndims (`int`, *optional*, defaults to 3):
|
| 182 |
+
Expected number of dimensions for a single input image. If the input image has a different number of
|
| 183 |
+
dimensions, an error is raised.
|
| 184 |
+
"""
|
| 185 |
+
if is_batched(images):
|
| 186 |
+
return images
|
| 187 |
+
|
| 188 |
+
# Either the input is a single image, in which case we create a list of length 1
|
| 189 |
+
if isinstance(images, PIL.Image.Image):
|
| 190 |
+
# PIL images are never batched
|
| 191 |
+
return [images]
|
| 192 |
+
|
| 193 |
+
if is_valid_image(images):
|
| 194 |
+
if images.ndim == expected_ndims + 1:
|
| 195 |
+
# Batch of images
|
| 196 |
+
images = list(images)
|
| 197 |
+
elif images.ndim == expected_ndims:
|
| 198 |
+
# Single image
|
| 199 |
+
images = [images]
|
| 200 |
+
else:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
|
| 203 |
+
f" {images.ndim} dimensions."
|
| 204 |
+
)
|
| 205 |
+
return images
|
| 206 |
+
raise ValueError(
|
| 207 |
+
"Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
|
| 208 |
+
f"jax.ndarray, but got {type(images)}."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
|
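# Editor's note (not part of the uploaded file): what make_list_of_images returns for the
# common input shapes, assuming expected_ndims=3.
#
#   a single PIL image             -> [image]                 (list of length 1)
#   a (3, H, W) array or tensor    -> [array]                 (treated as one image)
#   a (B, 3, H, W) array or tensor -> [img_0, ..., img_B-1]   (unbatched into a list)
#   a list/tuple of valid images   -> returned unchanged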
| 212 |
+
def to_numpy_array(img) -> np.ndarray:
|
| 213 |
+
if not is_valid_image(img):
|
| 214 |
+
raise ValueError(f"Invalid image type: {type(img)}")
|
| 215 |
+
|
| 216 |
+
if is_vision_available() and isinstance(img, PIL.Image.Image):
|
| 217 |
+
return np.array(img)
|
| 218 |
+
return to_numpy(img)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def infer_channel_dimension_format(
|
| 222 |
+
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
|
| 223 |
+
) -> ChannelDimension:
|
| 224 |
+
"""
|
| 225 |
+
Infers the channel dimension format of `image`.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
image (`np.ndarray`):
|
| 229 |
+
The image to infer the channel dimension of.
|
| 230 |
+
num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
|
| 231 |
+
The number of channels of the image.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
The channel dimension of the image.
|
| 235 |
+
"""
|
| 236 |
+
num_channels = num_channels if num_channels is not None else (1, 3)
|
| 237 |
+
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
|
| 238 |
+
|
| 239 |
+
if image.ndim == 3:
|
| 240 |
+
first_dim, last_dim = 0, 2
|
| 241 |
+
elif image.ndim == 4:
|
| 242 |
+
first_dim, last_dim = 1, 3
|
| 243 |
+
else:
|
| 244 |
+
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
| 245 |
+
|
| 246 |
+
if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
|
| 247 |
+
logger.warning(
|
| 248 |
+
f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
|
| 249 |
+
)
|
| 250 |
+
return ChannelDimension.FIRST
|
| 251 |
+
elif image.shape[first_dim] in num_channels:
|
| 252 |
+
return ChannelDimension.FIRST
|
| 253 |
+
elif image.shape[last_dim] in num_channels:
|
| 254 |
+
return ChannelDimension.LAST
|
| 255 |
+
raise ValueError("Unable to infer channel dimension format")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def get_channel_dimension_axis(
|
| 259 |
+
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
|
| 260 |
+
) -> int:
|
| 261 |
+
"""
|
| 262 |
+
Returns the channel dimension axis of the image.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
image (`np.ndarray`):
|
| 266 |
+
The image to get the channel dimension axis of.
|
| 267 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 268 |
+
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
The channel dimension axis of the image.
|
| 272 |
+
"""
|
| 273 |
+
if input_data_format is None:
|
| 274 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 275 |
+
if input_data_format == ChannelDimension.FIRST:
|
| 276 |
+
return image.ndim - 3
|
| 277 |
+
elif input_data_format == ChannelDimension.LAST:
|
| 278 |
+
return image.ndim - 1
|
| 279 |
+
raise ValueError(f"Unsupported data format: {input_data_format}")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
|
| 283 |
+
"""
|
| 284 |
+
Returns the (height, width) dimensions of the image.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
image (`np.ndarray`):
|
| 288 |
+
The image to get the dimensions of.
|
| 289 |
+
channel_dim (`ChannelDimension`, *optional*):
|
| 290 |
+
Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
|
| 291 |
+
|
| 292 |
+
Returns:
|
| 293 |
+
A tuple of the image's height and width.
|
| 294 |
+
"""
|
| 295 |
+
if channel_dim is None:
|
| 296 |
+
channel_dim = infer_channel_dimension_format(image)
|
| 297 |
+
|
| 298 |
+
if channel_dim == ChannelDimension.FIRST:
|
| 299 |
+
return image.shape[-2], image.shape[-1]
|
| 300 |
+
elif channel_dim == ChannelDimension.LAST:
|
| 301 |
+
return image.shape[-3], image.shape[-2]
|
| 302 |
+
else:
|
| 303 |
+
raise ValueError(f"Unsupported data format: {channel_dim}")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
|
| 307 |
+
if (
|
| 308 |
+
isinstance(annotation, dict)
|
| 309 |
+
and "image_id" in annotation
|
| 310 |
+
and "annotations" in annotation
|
| 311 |
+
and isinstance(annotation["annotations"], (list, tuple))
|
| 312 |
+
and (
|
| 313 |
+
# an image can have no annotations
|
| 314 |
+
len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
|
| 315 |
+
)
|
| 316 |
+
):
|
| 317 |
+
return True
|
| 318 |
+
return False
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
|
| 322 |
+
if (
|
| 323 |
+
isinstance(annotation, dict)
|
| 324 |
+
and "image_id" in annotation
|
| 325 |
+
and "segments_info" in annotation
|
| 326 |
+
and "file_name" in annotation
|
| 327 |
+
and isinstance(annotation["segments_info"], (list, tuple))
|
| 328 |
+
and (
|
| 329 |
+
# an image can have no segments
|
| 330 |
+
len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
|
| 331 |
+
)
|
| 332 |
+
):
|
| 333 |
+
return True
|
| 334 |
+
return False
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
|
| 338 |
+
return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
|
| 342 |
+
return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
|
| 346 |
+
"""
|
| 347 |
+
Loads `image` to a PIL Image.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
image (`str` or `PIL.Image.Image`):
|
| 351 |
+
The image to convert to the PIL Image format.
|
| 352 |
+
timeout (`float`, *optional*):
|
| 353 |
+
The timeout value in seconds for the URL request.
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
`PIL.Image.Image`: A PIL Image.
|
| 357 |
+
"""
|
| 358 |
+
requires_backends(load_image, ["vision"])
|
| 359 |
+
if isinstance(image, str):
|
| 360 |
+
if image.startswith("http://") or image.startswith("https://"):
|
| 361 |
+
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
| 362 |
+
# like http_huggingface_co.png
|
| 363 |
+
image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
|
| 364 |
+
elif os.path.isfile(image):
|
| 365 |
+
image = PIL.Image.open(image)
|
| 366 |
+
else:
|
| 367 |
+
if image.startswith("data:image/"):
|
| 368 |
+
image = image.split(",")[1]
|
| 369 |
+
|
| 370 |
+
# Try to load as base64
|
| 371 |
+
try:
|
| 372 |
+
b64 = base64.decodebytes(image.encode())
|
| 373 |
+
image = PIL.Image.open(BytesIO(b64))
|
| 374 |
+
except Exception as e:
|
| 375 |
+
raise ValueError(
|
| 376 |
+
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
| 377 |
+
)
|
| 378 |
+
elif isinstance(image, PIL.Image.Image):
|
| 379 |
+
image = image
|
| 380 |
+
else:
|
| 381 |
+
raise TypeError(
|
| 382 |
+
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
|
| 383 |
+
)
|
| 384 |
+
image = PIL.ImageOps.exif_transpose(image)
|
| 385 |
+
image = image.convert("RGB")
|
| 386 |
+
return image
|
| 387 |
+
|
| 388 |
+
|
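# Editor's sketch (not part of the uploaded file): the input forms accepted by load_image;
# the URL and path below are placeholders only.
#
#   img = load_image("https://example.com/cat.png")            # fetched over HTTP(S)
#   img = load_image("/path/to/local/cat.png")                 # read from disk
#   img = load_image("data:image/png;base64," + b64_string)    # decoded from base64
#
# In every case the result is an RGB PIL.Image with EXIF orientation already applied.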
| 389 |
+
def validate_preprocess_arguments(
|
| 390 |
+
do_rescale: Optional[bool] = None,
|
| 391 |
+
rescale_factor: Optional[float] = None,
|
| 392 |
+
do_normalize: Optional[bool] = None,
|
| 393 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 394 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 395 |
+
do_pad: Optional[bool] = None,
|
| 396 |
+
size_divisibility: Optional[int] = None,
|
| 397 |
+
do_center_crop: Optional[bool] = None,
|
| 398 |
+
crop_size: Optional[Dict[str, int]] = None,
|
| 399 |
+
do_resize: Optional[bool] = None,
|
| 400 |
+
size: Optional[Dict[str, int]] = None,
|
| 401 |
+
resample: Optional["PILImageResampling"] = None,
|
| 402 |
+
):
|
| 403 |
+
"""
|
| 404 |
+
Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
|
| 405 |
+
Raises `ValueError` if arguments incompatibility is caught.
|
| 406 |
+
Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
|
| 407 |
+
sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
|
| 408 |
+
existing arguments when possible.
|
| 409 |
+
|
| 410 |
+
"""
|
| 411 |
+
if do_rescale and rescale_factor is None:
|
| 412 |
+
raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
|
| 413 |
+
|
| 414 |
+
if do_pad and size_divisibility is None:
|
| 415 |
+
# Here, size_divisor might be passed as the value of size
|
| 416 |
+
raise ValueError(
|
| 417 |
+
"Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
if do_normalize and (image_mean is None or image_std is None):
|
| 421 |
+
raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
|
| 422 |
+
|
| 423 |
+
if do_center_crop and crop_size is None:
|
| 424 |
+
raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
|
| 425 |
+
|
| 426 |
+
if do_resize and (size is None or resample is None):
|
| 427 |
+
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# In the future we can add a TF implementation here when we have TF models.
|
| 431 |
+
class ImageFeatureExtractionMixin:
|
| 432 |
+
"""
|
| 433 |
+
Mixin that contain utilities for preparing image features.
|
| 434 |
+
"""
|
| 435 |
+
|
| 436 |
+
def _ensure_format_supported(self, image):
|
| 437 |
+
if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
|
| 438 |
+
raise ValueError(
|
| 439 |
+
f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
|
| 440 |
+
"`torch.Tensor` are."
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
def to_pil_image(self, image, rescale=None):
|
| 444 |
+
"""
|
| 445 |
+
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
| 446 |
+
needed.
|
| 447 |
+
|
| 448 |
+
Args:
|
| 449 |
+
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
|
| 450 |
+
The image to convert to the PIL Image format.
|
| 451 |
+
rescale (`bool`, *optional*):
|
| 452 |
+
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
|
| 453 |
+
default to `True` if the image type is a floating type, `False` otherwise.
|
| 454 |
+
"""
|
| 455 |
+
self._ensure_format_supported(image)
|
| 456 |
+
|
| 457 |
+
if is_torch_tensor(image):
|
| 458 |
+
image = image.numpy()
|
| 459 |
+
|
| 460 |
+
if isinstance(image, np.ndarray):
|
| 461 |
+
if rescale is None:
|
| 462 |
+
# rescale defaults to whether the array is of a floating type.
|
| 463 |
+
rescale = isinstance(image.flat[0], np.floating)
|
| 464 |
+
# If the channel has been moved to the first dim, we put it back at the end.
|
| 465 |
+
if image.ndim == 3 and image.shape[0] in [1, 3]:
|
| 466 |
+
image = image.transpose(1, 2, 0)
|
| 467 |
+
if rescale:
|
| 468 |
+
image = image * 255
|
| 469 |
+
image = image.astype(np.uint8)
|
| 470 |
+
return PIL.Image.fromarray(image)
|
| 471 |
+
return image
|
| 472 |
+
|
| 473 |
+
def convert_rgb(self, image):
|
| 474 |
+
"""
|
| 475 |
+
Converts `PIL.Image.Image` to RGB format.
|
| 476 |
+
|
| 477 |
+
Args:
|
| 478 |
+
image (`PIL.Image.Image`):
|
| 479 |
+
The image to convert.
|
| 480 |
+
"""
|
| 481 |
+
self._ensure_format_supported(image)
|
| 482 |
+
if not isinstance(image, PIL.Image.Image):
|
| 483 |
+
return image
|
| 484 |
+
|
| 485 |
+
return image.convert("RGB")
|
| 486 |
+
|
| 487 |
+
def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
|
| 488 |
+
"""
|
| 489 |
+
Rescale a numpy image by scale amount
|
| 490 |
+
"""
|
| 491 |
+
self._ensure_format_supported(image)
|
| 492 |
+
return image * scale
|
| 493 |
+
|
| 494 |
+
def to_numpy_array(self, image, rescale=None, channel_first=True):
|
| 495 |
+
"""
|
| 496 |
+
Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
|
| 497 |
+
dimension.
|
| 498 |
+
|
| 499 |
+
Args:
|
| 500 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 501 |
+
The image to convert to a NumPy array.
|
| 502 |
+
rescale (`bool`, *optional*):
|
| 503 |
+
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
|
| 504 |
+
default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
|
| 505 |
+
channel_first (`bool`, *optional*, defaults to `True`):
|
| 506 |
+
Whether or not to permute the dimensions of the image to put the channel dimension first.
|
| 507 |
+
"""
|
| 508 |
+
self._ensure_format_supported(image)
|
| 509 |
+
|
| 510 |
+
if isinstance(image, PIL.Image.Image):
|
| 511 |
+
image = np.array(image)
|
| 512 |
+
|
| 513 |
+
if is_torch_tensor(image):
|
| 514 |
+
image = image.numpy()
|
| 515 |
+
|
| 516 |
+
rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
|
| 517 |
+
|
| 518 |
+
if rescale:
|
| 519 |
+
image = self.rescale(image.astype(np.float32), 1 / 255.0)
|
| 520 |
+
|
| 521 |
+
if channel_first and image.ndim == 3:
|
| 522 |
+
image = image.transpose(2, 0, 1)
|
| 523 |
+
|
| 524 |
+
return image
|
| 525 |
+
|
| 526 |
+
def expand_dims(self, image):
|
| 527 |
+
"""
|
| 528 |
+
Expands 2-dimensional `image` to 3 dimensions.
|
| 529 |
+
|
| 530 |
+
Args:
|
| 531 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 532 |
+
The image to expand.
|
| 533 |
+
"""
|
| 534 |
+
self._ensure_format_supported(image)
|
| 535 |
+
|
| 536 |
+
# Do nothing if PIL image
|
| 537 |
+
if isinstance(image, PIL.Image.Image):
|
| 538 |
+
return image
|
| 539 |
+
|
| 540 |
+
if is_torch_tensor(image):
|
| 541 |
+
image = image.unsqueeze(0)
|
| 542 |
+
else:
|
| 543 |
+
image = np.expand_dims(image, axis=0)
|
| 544 |
+
return image
|
| 545 |
+
|
| 546 |
+
def normalize(self, image, mean, std, rescale=False):
|
| 547 |
+
"""
|
| 548 |
+
Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
|
| 549 |
+
if it's a PIL Image.
|
| 550 |
+
|
| 551 |
+
Args:
|
| 552 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 553 |
+
The image to normalize.
|
| 554 |
+
mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
|
| 555 |
+
The mean (per channel) to use for normalization.
|
| 556 |
+
std (`List[float]` or `np.ndarray` or `torch.Tensor`):
|
| 557 |
+
The standard deviation (per channel) to use for normalization.
|
| 558 |
+
rescale (`bool`, *optional*, defaults to `False`):
|
| 559 |
+
Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
|
| 560 |
+
happen automatically.
|
| 561 |
+
"""
|
| 562 |
+
self._ensure_format_supported(image)
|
| 563 |
+
|
| 564 |
+
if isinstance(image, PIL.Image.Image):
|
| 565 |
+
image = self.to_numpy_array(image, rescale=True)
|
| 566 |
+
# If the input image is a PIL image, it automatically gets rescaled. If it's another
|
| 567 |
+
# type it may need rescaling.
|
| 568 |
+
elif rescale:
|
| 569 |
+
if isinstance(image, np.ndarray):
|
| 570 |
+
image = self.rescale(image.astype(np.float32), 1 / 255.0)
|
| 571 |
+
elif is_torch_tensor(image):
|
| 572 |
+
image = self.rescale(image.float(), 1 / 255.0)
|
| 573 |
+
|
| 574 |
+
if isinstance(image, np.ndarray):
|
| 575 |
+
if not isinstance(mean, np.ndarray):
|
| 576 |
+
mean = np.array(mean).astype(image.dtype)
|
| 577 |
+
if not isinstance(std, np.ndarray):
|
| 578 |
+
std = np.array(std).astype(image.dtype)
|
| 579 |
+
elif is_torch_tensor(image):
|
| 580 |
+
import torch
|
| 581 |
+
|
| 582 |
+
if not isinstance(mean, torch.Tensor):
|
| 583 |
+
if isinstance(mean, np.ndarray):
|
| 584 |
+
mean = torch.from_numpy(mean)
|
| 585 |
+
else:
|
| 586 |
+
mean = torch.tensor(mean)
|
| 587 |
+
if not isinstance(std, torch.Tensor):
|
| 588 |
+
if isinstance(std, np.ndarray):
|
| 589 |
+
std = torch.from_numpy(std)
|
| 590 |
+
else:
|
| 591 |
+
std = torch.tensor(std)
|
| 592 |
+
|
| 593 |
+
if image.ndim == 3 and image.shape[0] in [1, 3]:
|
| 594 |
+
return (image - mean[:, None, None]) / std[:, None, None]
|
| 595 |
+
else:
|
| 596 |
+
return (image - mean) / std
|
| 597 |
+
|
| 598 |
+
def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
|
| 599 |
+
"""
|
| 600 |
+
Resizes `image`. Enforces conversion of input to PIL.Image.
|
| 601 |
+
|
| 602 |
+
Args:
|
| 603 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 604 |
+
The image to resize.
|
| 605 |
+
size (`int` or `Tuple[int, int]`):
|
| 606 |
+
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
|
| 607 |
+
matched to this.
|
| 608 |
+
|
| 609 |
+
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
|
| 610 |
+
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
|
| 611 |
+
this number, i.e., if height > width, then the image will be rescaled to (size * height / width, size).
|
| 612 |
+
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 613 |
+
The filter to use for resampling.
|
| 614 |
+
default_to_square (`bool`, *optional*, defaults to `True`):
|
| 615 |
+
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
|
| 616 |
+
square (`size`,`size`). If set to `False`, will replicate
|
| 617 |
+
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
|
| 618 |
+
with support for resizing only the smallest edge and providing an optional `max_size`.
|
| 619 |
+
max_size (`int`, *optional*, defaults to `None`):
|
| 620 |
+
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
|
| 621 |
+
greater than `max_size` after being resized according to `size`, then the image is resized again so
|
| 622 |
+
that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
|
| 623 |
+
edge may be shorter than `size`. Only used if `default_to_square` is `False`.
|
| 624 |
+
|
| 625 |
+
Returns:
|
| 626 |
+
image: A resized `PIL.Image.Image`.
|
| 627 |
+
"""
|
| 628 |
+
resample = resample if resample is not None else PILImageResampling.BILINEAR
|
| 629 |
+
|
| 630 |
+
self._ensure_format_supported(image)
|
| 631 |
+
|
| 632 |
+
if not isinstance(image, PIL.Image.Image):
|
| 633 |
+
image = self.to_pil_image(image)
|
| 634 |
+
|
| 635 |
+
if isinstance(size, list):
|
| 636 |
+
size = tuple(size)
|
| 637 |
+
|
| 638 |
+
if isinstance(size, int) or len(size) == 1:
|
| 639 |
+
if default_to_square:
|
| 640 |
+
size = (size, size) if isinstance(size, int) else (size[0], size[0])
|
| 641 |
+
else:
|
| 642 |
+
width, height = image.size
|
| 643 |
+
# specified size only for the smallest edge
|
| 644 |
+
short, long = (width, height) if width <= height else (height, width)
|
| 645 |
+
requested_new_short = size if isinstance(size, int) else size[0]
|
| 646 |
+
|
| 647 |
+
if short == requested_new_short:
|
| 648 |
+
return image
|
| 649 |
+
|
| 650 |
+
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
| 651 |
+
|
| 652 |
+
if max_size is not None:
|
| 653 |
+
if max_size <= requested_new_short:
|
| 654 |
+
raise ValueError(
|
| 655 |
+
f"max_size = {max_size} must be strictly greater than the requested "
|
| 656 |
+
f"size for the smaller edge size = {size}"
|
| 657 |
+
)
|
| 658 |
+
if new_long > max_size:
|
| 659 |
+
new_short, new_long = int(max_size * new_short / new_long), max_size
|
| 660 |
+
|
| 661 |
+
size = (new_short, new_long) if width <= height else (new_long, new_short)
|
| 662 |
+
|
| 663 |
+
return image.resize(size, resample=resample)
|
| 664 |
+
|
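# Editor's note (not part of the uploaded file): a hand-worked example of the smallest-edge
# branch above (default_to_square=False with a max_size cap).
#
#   input 400 x 600 (width x height), size=300, max_size=400:
#     smaller edge 400 -> 300, longer edge would become int(300 * 600 / 400) = 450
#     450 > max_size, so the pair is rescaled to (int(400 * 300 / 450), 400) = (266, 400)
#     final PIL size: 266 x 400 (width x height)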
| 665 |
+
def center_crop(self, image, size):
|
| 666 |
+
"""
|
| 667 |
+
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
|
| 668 |
+
size given, it will be padded (so the returned result has the requested size).
|
| 669 |
+
|
| 670 |
+
Args:
|
| 671 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
|
| 672 |
+
The image to crop.
|
| 673 |
+
size (`int` or `Tuple[int, int]`):
|
| 674 |
+
The size to which to crop the image.
|
| 675 |
+
|
| 676 |
+
Returns:
|
| 677 |
+
new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
|
| 678 |
+
height, width).
|
| 679 |
+
"""
|
| 680 |
+
self._ensure_format_supported(image)
|
| 681 |
+
|
| 682 |
+
if not isinstance(size, tuple):
|
| 683 |
+
size = (size, size)
|
| 684 |
+
|
| 685 |
+
# PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
|
| 686 |
+
if is_torch_tensor(image) or isinstance(image, np.ndarray):
|
| 687 |
+
if image.ndim == 2:
|
| 688 |
+
image = self.expand_dims(image)
|
| 689 |
+
image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
|
| 690 |
+
else:
|
| 691 |
+
image_shape = (image.size[1], image.size[0])
|
| 692 |
+
|
| 693 |
+
top = (image_shape[0] - size[0]) // 2
|
| 694 |
+
bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
|
| 695 |
+
left = (image_shape[1] - size[1]) // 2
|
| 696 |
+
right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
|
| 697 |
+
|
| 698 |
+
# For PIL Images we have a method to crop directly.
|
| 699 |
+
if isinstance(image, PIL.Image.Image):
|
| 700 |
+
return image.crop((left, top, right, bottom))
|
| 701 |
+
|
| 702 |
+
# Check if image is in (n_channels, height, width) or (height, width, n_channels) format
|
| 703 |
+
channel_first = True if image.shape[0] in [1, 3] else False
|
| 704 |
+
|
| 705 |
+
# Transpose (height, width, n_channels) format images
|
| 706 |
+
if not channel_first:
|
| 707 |
+
if isinstance(image, np.ndarray):
|
| 708 |
+
image = image.transpose(2, 0, 1)
|
| 709 |
+
if is_torch_tensor(image):
|
| 710 |
+
image = image.permute(2, 0, 1)
|
| 711 |
+
|
| 712 |
+
# Check if cropped area is within image boundaries
|
| 713 |
+
if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
|
| 714 |
+
return image[..., top:bottom, left:right]
|
| 715 |
+
|
| 716 |
+
# Otherwise, we may need to pad if the image is too small. Oh joy...
|
| 717 |
+
new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
|
| 718 |
+
if isinstance(image, np.ndarray):
|
| 719 |
+
new_image = np.zeros_like(image, shape=new_shape)
|
| 720 |
+
elif is_torch_tensor(image):
|
| 721 |
+
new_image = image.new_zeros(new_shape)
|
| 722 |
+
|
| 723 |
+
top_pad = (new_shape[-2] - image_shape[0]) // 2
|
| 724 |
+
bottom_pad = top_pad + image_shape[0]
|
| 725 |
+
left_pad = (new_shape[-1] - image_shape[1]) // 2
|
| 726 |
+
right_pad = left_pad + image_shape[1]
|
| 727 |
+
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
|
| 728 |
+
|
| 729 |
+
top += top_pad
|
| 730 |
+
bottom += top_pad
|
| 731 |
+
left += left_pad
|
| 732 |
+
right += left_pad
|
| 733 |
+
|
| 734 |
+
new_image = new_image[
|
| 735 |
+
..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
|
| 736 |
+
]
|
| 737 |
+
|
| 738 |
+
return new_image
|
| 739 |
+
|
| 740 |
+
def flip_channel_order(self, image):
|
| 741 |
+
"""
|
| 742 |
+
Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
|
| 743 |
+
`image` to a NumPy array if it's a PIL Image.
|
| 744 |
+
|
| 745 |
+
Args:
|
| 746 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 747 |
+
The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
|
| 748 |
+
be first.
|
| 749 |
+
"""
|
| 750 |
+
self._ensure_format_supported(image)
|
| 751 |
+
|
| 752 |
+
if isinstance(image, PIL.Image.Image):
|
| 753 |
+
image = self.to_numpy_array(image)
|
| 754 |
+
|
| 755 |
+
return image[::-1, :, :]
|
| 756 |
+
|
| 757 |
+
def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
|
| 758 |
+
"""
|
| 759 |
+
Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
|
| 760 |
+
counter clockwise around its centre.
|
| 761 |
+
|
| 762 |
+
Args:
|
| 763 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 764 |
+
The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
|
| 765 |
+
rotating.
|
| 766 |
+
|
| 767 |
+
Returns:
|
| 768 |
+
image: A rotated `PIL.Image.Image`.
|
| 769 |
+
"""
|
| 770 |
+
resample = resample if resample is not None else PIL.Image.NEAREST
|
| 771 |
+
|
| 772 |
+
self._ensure_format_supported(image)
|
| 773 |
+
|
| 774 |
+
if not isinstance(image, PIL.Image.Image):
|
| 775 |
+
image = self.to_pil_image(image)
|
| 776 |
+
|
| 777 |
+
return image.rotate(
|
| 778 |
+
angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def validate_annotations(
|
| 783 |
+
annotation_format: AnnotationFormat,
|
| 784 |
+
supported_annotation_formats: Tuple[AnnotationFormat, ...],
|
| 785 |
+
annotations: List[Dict],
|
| 786 |
+
) -> None:
|
| 787 |
+
if annotation_format not in supported_annotation_formats:
|
| 788 |
+
raise ValueError(f"Unsupported annotation format: {format} must be one of {supported_annotation_formats}")
|
| 789 |
+
|
| 790 |
+
if annotation_format is AnnotationFormat.COCO_DETECTION:
|
| 791 |
+
if not valid_coco_detection_annotations(annotations):
|
| 792 |
+
raise ValueError(
|
| 793 |
+
"Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
|
| 794 |
+
"(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
|
| 795 |
+
"being a list of annotations in the COCO format."
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
if annotation_format is AnnotationFormat.COCO_PANOPTIC:
|
| 799 |
+
if not valid_coco_panoptic_annotations(annotations):
|
| 800 |
+
raise ValueError(
|
| 801 |
+
"Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
|
| 802 |
+
"(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
|
| 803 |
+
"the latter being a list of annotations in the COCO format."
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
|
| 808 |
+
unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
|
| 809 |
+
if unused_keys:
|
| 810 |
+
unused_key_str = ", ".join(unused_keys)
|
| 811 |
+
# TODO raise a warning here instead of simply logging?
|
| 812 |
+
logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
|
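The resize and center_crop helpers above implement smallest-edge resizing with an optional max_size cap and pad-if-needed center cropping. As a quick illustration of the resize logic outside the class (a standalone sketch, not part of the uploaded file; the function name is made up), the same behaviour can be reproduced with plain PIL:

# Standalone sketch of the smallest-edge resize with a max_size cap (illustrative only).
from PIL import Image

def resize_smallest_edge(image: Image.Image, size: int, max_size: int = None) -> Image.Image:
    width, height = image.size
    short, long = (width, height) if width <= height else (height, width)
    if short == size:
        return image
    new_short, new_long = size, int(size * long / short)
    if max_size is not None:
        if max_size <= size:
            raise ValueError("max_size must be strictly greater than the requested size")
        if new_long > max_size:
            new_short, new_long = int(max_size * new_short / new_long), max_size
    new_size = (new_short, new_long) if width <= height else (new_long, new_short)
    return image.resize(new_size, resample=Image.BILINEAR)

# e.g. a 640x480 input with size=384 and max_size=512 comes out as 512x384.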
inference_utils.py
ADDED
|
@@ -0,0 +1,412 @@
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
from transformers import LogitsProcessor, LogitsProcessorList, logging
|
| 8 |
+
from typing import Optional, List, Dict, Tuple, Union, Set
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_interleaved_text_image(
|
| 14 |
+
full_output_text: str,
|
| 15 |
+
num_levels: int = 2,
|
| 16 |
+
image_placeholder: str = "<image>",
|
| 17 |
+
start_tag: str = "<start_of_image>",
|
| 18 |
+
end_tag: str = "<end_of_image>"
|
| 19 |
+
) -> Tuple[str, List[List[List[int]]], List[str]]:
|
| 20 |
+
"""
|
| 21 |
+
Parses text containing interleaved image token blocks.
|
| 22 |
+
|
| 23 |
+
Identifies blocks enclosed by start_tag and end_tag, extracts image tokens
|
| 24 |
+
(<|image_levelX_Y|>) within them, and replaces the blocks with a placeholder
|
| 25 |
+
in the output text.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
full_output_text: The raw input string containing text and image blocks.
|
| 29 |
+
num_levels: The expected number of levels for image tokens (e.g., 2).
|
| 30 |
+
image_placeholder: The string to replace image blocks with in the output text.
|
| 31 |
+
start_tag: The exact string marking the beginning of an image block.
|
| 32 |
+
end_tag: The exact string marking the end of an image block.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
A tuple containing:
|
| 37 |
+
- generated_text (str): The text with image blocks replaced by placeholders.
|
| 38 |
+
- all_image_indices (List[List[List[int]]]): A list where each element
|
| 39 |
+
represents one image. Each image element is a list containing lists
|
| 40 |
+
of token indices for each level (the raw token block string of each image is also returned as a third element).
|
| 41 |
+
Example for 2 images, 2 levels:
|
| 42 |
+
[
|
| 43 |
+
[[level0_indices_img1], [level1_indices_img1]], # Image 1
|
| 44 |
+
[[level0_indices_img2], [level1_indices_img2]] # Image 2
|
| 45 |
+
]
|
| 46 |
+
"""
|
| 47 |
+
all_image_indices: List[List[List[int]]] = []
|
| 48 |
+
processed_text_parts: List[str] = []
|
| 49 |
+
list_image_token_parts: List[str] = []
|
| 50 |
+
last_end: int = 0
|
| 51 |
+
|
| 52 |
+
# Escape start/end tags for regex safety if they contain special characters
|
| 53 |
+
escaped_start_tag = re.escape(start_tag)
|
| 54 |
+
escaped_end_tag = re.escape(end_tag)
|
| 55 |
+
|
| 56 |
+
# Pattern to find image blocks: start_tag ... end_tag (non-greedy)
|
| 57 |
+
image_block_pattern = rf'{escaped_start_tag}(.*?){escaped_end_tag}'
|
| 58 |
+
# Pattern to find individual image tokens within a block
|
| 59 |
+
token_pattern = r'<\|image_level(\d+)_(\d+)\|>'
|
| 60 |
+
|
| 61 |
+
# Find all image blocks (re.DOTALL allows '.' to match newlines)
|
| 62 |
+
for match in re.finditer(image_block_pattern, full_output_text, re.DOTALL):
|
| 63 |
+
# 1. Add text preceding this image block
|
| 64 |
+
processed_text_parts.append(full_output_text[last_end:match.start()])
|
| 65 |
+
|
| 66 |
+
# collect the image token ids.
|
| 67 |
+
list_image_token_parts.append(full_output_text[match.start(): match.end()])
|
| 68 |
+
|
| 69 |
+
# 2. Add the placeholder for the image
|
| 70 |
+
processed_text_parts.append(image_placeholder)
|
| 71 |
+
|
| 72 |
+
# 3. Process the content *within* the current image block
|
| 73 |
+
image_token_content = match.group(1) # Content between tags
|
| 74 |
+
parsed_level_indices = {} # {level: [indices]} for *this* image
|
| 75 |
+
|
| 76 |
+
# Find all image tokens within this block
|
| 77 |
+
for token_match in re.finditer(token_pattern, image_token_content):
|
| 78 |
+
try:
|
| 79 |
+
level = int(token_match.group(1))
|
| 80 |
+
index = int(token_match.group(2))
|
| 81 |
+
if level >= num_levels:
|
| 82 |
+
logger.warning(f"Parsed token level {level} >= num_levels {num_levels}. Ignoring token.")
|
| 83 |
+
continue
|
| 84 |
+
if level not in parsed_level_indices:
|
| 85 |
+
parsed_level_indices[level] = []
|
| 86 |
+
parsed_level_indices[level].append(index)
|
| 87 |
+
except (ValueError, IndexError):
|
| 88 |
+
logger.warning(f"Could not parse token: {token_match.group(0)}")
|
| 89 |
+
continue # Skip malformed tokens
|
| 90 |
+
|
| 91 |
+
# Structure the indices for the current image based on expected levels
|
| 92 |
+
current_image_indices = []
|
| 93 |
+
logger.debug(f"Processing Image Block. Found levels: {parsed_level_indices.keys()}")
|
| 94 |
+
for level in range(num_levels):
|
| 95 |
+
# Get indices for the level, default to empty list if level not found
|
| 96 |
+
indices = parsed_level_indices.get(level, [])
|
| 97 |
+
# Optional: Sort indices if order isn't guaranteed (usually is by finditer)
|
| 98 |
+
# indices.sort()
|
| 99 |
+
current_image_indices.append(indices)
|
| 100 |
+
logger.debug(f" Level {level} indices count: {len(indices)}")
|
| 101 |
+
|
| 102 |
+
all_image_indices.append(current_image_indices)
|
| 103 |
+
logger.info(f"Parsed Image {len(all_image_indices)}: Found indices for {len(current_image_indices)} levels.")
|
| 104 |
+
|
| 105 |
+
# 4. Update position for the next iteration
|
| 106 |
+
last_end = match.end()
|
| 107 |
+
|
| 108 |
+
# Add any remaining text after the last image block
|
| 109 |
+
processed_text_parts.append(full_output_text[last_end:])
|
| 110 |
+
|
| 111 |
+
# Join the text parts to form the final generated text
|
| 112 |
+
generated_text = "".join(processed_text_parts)
|
| 113 |
+
|
| 114 |
+
return generated_text, all_image_indices, list_image_token_parts
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def calculate_image_token_num(h, w, downsample_rate_per_level=[28, 16]):
|
| 118 |
+
# Assuming RESOLUTION_MAPPING is accessible or hardcoded if needed
|
| 119 |
+
# For simplicity, let's assume direct calculation based on downsampling
|
| 120 |
+
# Replace with actual RESOLUTION_MAPPING logic if necessary
|
| 121 |
+
# Example: w1, h1 = RESOLUTION_MAPPING.get((w, h), (w, h)) # Get from mapping
|
| 122 |
+
w1, h1 = w, h # Placeholder if mapping not available/needed here
|
| 123 |
+
w1, h1 = w1 // downsample_rate_per_level[0], h1 // downsample_rate_per_level[0]
|
| 124 |
+
semantic_token_num = w1 * h1
|
| 125 |
+
|
| 126 |
+
w2, h2 = w // downsample_rate_per_level[1], h // downsample_rate_per_level[1]
|
| 127 |
+
pixel_token_num = w2 * h2
|
| 128 |
+
logger.info(f"Calculated token nums: semantic={semantic_token_num}, pixel={pixel_token_num} for target ({h},{w})")
|
| 129 |
+
# Estimate max_token_length (adjust based on special tokens in your format)
|
| 130 |
+
max_token_length = (h1 * (w1 + 1) + 2) + (h2 * (w2 + 1) + 2) + 2 + 2 + 1 + 1 + 50 # Add buffer
|
| 131 |
+
return [semantic_token_num, pixel_token_num], max_token_length, h1, w1, h2, w2
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class InterleavedLogitsProcessor(LogitsProcessor):
|
| 135 |
+
"""
|
| 136 |
+
Combines CFG, Dual VQ Image Token Structure Enforcement, and Dynamic Sampling
|
| 137 |
+
for interleaved text and image generation.
|
| 138 |
+
|
| 139 |
+
Includes refined masking during text generation to only allow text,
|
| 140 |
+
a specific resolution tag, and the start_of_image token.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
def __init__(self,
|
| 144 |
+
# CFG parameters
|
| 145 |
+
guidance_scale=1.0,
|
| 146 |
+
uncond=None,
|
| 147 |
+
attention_mask=None,
|
| 148 |
+
model=None,
|
| 149 |
+
# DualVQ parameters
|
| 150 |
+
level0_range=None, level1_range=None,
|
| 151 |
+
num_level0_rows=None, num_level0_tokens=None,
|
| 152 |
+
num_level1_rows=None, num_level1_tokens=None,
|
| 153 |
+
special_tokens=None,
|
| 154 |
+
*,
|
| 155 |
+
# Dynamic Sampling parameters
|
| 156 |
+
default_temp=1.0, level0_temp=1.0, level1_temp=2.0,
|
| 157 |
+
default_top_k=2048, level0_top_k=2048, level1_top_k=2048 * 3,
|
| 158 |
+
default_top_p=0.8, level0_top_p=0.8, level1_top_p=1.0,
|
| 159 |
+
# General
|
| 160 |
+
images=None,
|
| 161 |
+
):
|
| 162 |
+
|
| 163 |
+
# --- CFG ---
|
| 164 |
+
self.guidance_scale = guidance_scale
|
| 165 |
+
self.uncond = uncond
|
| 166 |
+
self.attention_mask = attention_mask
|
| 167 |
+
self.images = images
|
| 168 |
+
self.model = model
|
| 169 |
+
self.out = None
|
| 170 |
+
|
| 171 |
+
# --- DualVQ ---
|
| 172 |
+
self.level0_range = level0_range
|
| 173 |
+
self.level1_range = level1_range
|
| 174 |
+
self.num_level0_rows = num_level0_rows
|
| 175 |
+
self.num_level0_tokens = num_level0_tokens
|
| 176 |
+
self.num_level1_rows = num_level1_rows
|
| 177 |
+
self.num_level1_tokens = num_level1_tokens
|
| 178 |
+
self.special_tokens = special_tokens
|
| 179 |
+
|
| 180 |
+
# DualVQ State
|
| 181 |
+
self.generating_image = False
|
| 182 |
+
self.current_level = None
|
| 183 |
+
self.tokens_in_row = 0
|
| 184 |
+
self.rows_in_level = 0
|
| 185 |
+
|
| 186 |
+
# --- Dynamic Sampling ---
|
| 187 |
+
self.start_of_level0_token_id = special_tokens["start_of_level0"]
|
| 188 |
+
self.end_of_level0_token_id = special_tokens["end_of_level0"]
|
| 189 |
+
self.start_of_level1_token_id = special_tokens["start_of_level1"]
|
| 190 |
+
self.end_of_level1_token_id = special_tokens["end_of_level1"]
|
| 191 |
+
self.start_of_image_token_id = special_tokens["start_of_image"]
|
| 192 |
+
self.end_of_image_token_id = special_tokens["end_of_image"]
|
| 193 |
+
|
| 194 |
+
self.default_temp = default_temp
|
| 195 |
+
self.default_top_k = default_top_k
|
| 196 |
+
self.default_top_p = default_top_p
|
| 197 |
+
self.level0_temp = level0_temp
|
| 198 |
+
self.level0_top_k = level0_top_k
|
| 199 |
+
self.level0_top_p = level0_top_p
|
| 200 |
+
self.level1_temp = level1_temp
|
| 201 |
+
self.level1_top_k = level1_top_k
|
| 202 |
+
self.level1_top_p = level1_top_p
|
| 203 |
+
|
| 204 |
+
# Dynamic Sampling State
|
| 205 |
+
self.in_level0_mode = False
|
| 206 |
+
self.in_level1_mode = False
|
| 207 |
+
|
| 208 |
+
# --- Validation ---
|
| 209 |
+
if not self.special_tokens:
|
| 210 |
+
raise ValueError("special_tokens dictionary cannot be empty.")
|
| 211 |
+
# *** Updated required keys ***
|
| 212 |
+
required_keys = ["start_of_image", "end_of_image", "start_of_level0",
|
| 213 |
+
"end_of_level0", "start_of_level1", "end_of_level1",
|
| 214 |
+
"end_of_line", "end_of_text"]
|
| 215 |
+
for key in required_keys:
|
| 216 |
+
if key not in self.special_tokens:
|
| 217 |
+
raise ValueError(f"Missing required key in special_tokens: {key}")
|
| 218 |
+
|
| 219 |
+
def _apply_cfg(self, input_ids, scores):
|
| 220 |
+
"""Applies Classifier-Free Guidance."""
|
| 221 |
+
scores = F.log_softmax(scores, dim=-1)
|
| 222 |
+
if self.guidance_scale == 1:
|
| 223 |
+
return scores
|
| 224 |
+
|
| 225 |
+
if self.out is None:
|
| 226 |
+
self.out = self.model(self.uncond,
|
| 227 |
+
attention_mask=self.attention_mask,
|
| 228 |
+
pixel_values=self.images)
|
| 229 |
+
else:
|
| 230 |
+
self.out = self.model(
|
| 231 |
+
input_ids[:, -1:],
|
| 232 |
+
use_cache=True,
|
| 233 |
+
past_key_values=self.out.past_key_values,
|
| 234 |
+
)
|
| 235 |
+
unconditional_logits = F.log_softmax(self.out.logits[:, -1, :], dim=-1)
|
| 236 |
+
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
|
| 237 |
+
return out
|
| 238 |
+
|
| 239 |
+
def _apply_sampling(self, scores, temp, top_k, top_p):
|
| 240 |
+
""" Apply top-k, top-p, and temperature """
|
| 241 |
+
if temp > 0.0:
|
| 242 |
+
scores = scores / temp # Adjust temperature
|
| 243 |
+
|
| 244 |
+
# Top-K filtering
|
| 245 |
+
if top_k > 0:
|
| 246 |
+
top_k_values, _ = torch.topk(scores, min(top_k, scores.size(-1)))
|
| 247 |
+
scores[scores < top_k_values[:, -1].unsqueeze(-1)] = -float("Inf")
|
| 248 |
+
|
| 249 |
+
# Top-P filtering
|
| 250 |
+
if top_p < 1.0:
|
| 251 |
+
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
| 252 |
+
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 253 |
+
|
| 254 |
+
# Only keep tokens with cumulative probabilities within top_p
|
| 255 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 256 |
+
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
| 257 |
+
sorted_indices_to_remove[:, 0] = False
|
| 258 |
+
|
| 259 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 260 |
+
scores[indices_to_remove] = -float("Inf")
|
| 261 |
+
|
| 262 |
+
return scores
|
| 263 |
+
|
| 264 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 265 |
+
|
| 266 |
+
# --- Step 0: Get last token and Vocab Size ---
|
| 267 |
+
last_token = None
|
| 268 |
+
if input_ids.shape[1] > 0:
|
| 269 |
+
last_token = input_ids[0, -1].item() # Assuming batch size 1
|
| 270 |
+
|
| 271 |
+
# --- Step 1: Update State & Apply Constraints ---
|
| 272 |
+
# State updates based on the *last generated* token
|
| 273 |
+
if last_token == self.start_of_image_token_id:
|
| 274 |
+
self.generating_image = True
|
| 275 |
+
self.current_level = None
|
| 276 |
+
self.tokens_in_row = 0
|
| 277 |
+
self.rows_in_level = 0
|
| 278 |
+
self.in_level0_mode = False
|
| 279 |
+
self.in_level1_mode = False
|
| 280 |
+
elif last_token == self.start_of_level0_token_id:
|
| 281 |
+
self.current_level = "level0"
|
| 282 |
+
self.tokens_in_row = 0
|
| 283 |
+
self.rows_in_level = 0
|
| 284 |
+
self.in_level0_mode = True
|
| 285 |
+
self.in_level1_mode = False
|
| 286 |
+
elif last_token == self.end_of_level0_token_id:
|
| 287 |
+
self.current_level = None
|
| 288 |
+
self.in_level0_mode = False
|
| 289 |
+
elif last_token == self.start_of_level1_token_id:
|
| 290 |
+
self.current_level = "level1"
|
| 291 |
+
self.tokens_in_row = 0
|
| 292 |
+
self.rows_in_level = 0
|
| 293 |
+
self.in_level0_mode = False
|
| 294 |
+
self.in_level1_mode = True
|
| 295 |
+
elif last_token == self.end_of_level1_token_id:
|
| 296 |
+
self.current_level = None
|
| 297 |
+
self.in_level1_mode = False
|
| 298 |
+
elif last_token == self.end_of_image_token_id:
|
| 299 |
+
self.generating_image = False
|
| 300 |
+
self.current_level = None
|
| 301 |
+
self.tokens_in_row = 0
|
| 302 |
+
self.rows_in_level = 0
|
| 303 |
+
self.in_level0_mode = False
|
| 304 |
+
self.in_level1_mode = False
|
| 305 |
+
elif last_token == self.special_tokens["end_of_line"] and self.generating_image:
|
| 306 |
+
self.tokens_in_row = 0
|
| 307 |
+
self.rows_in_level += 1
|
| 308 |
+
elif self.generating_image and self.current_level is not None:
|
| 309 |
+
if (self.current_level == "level0" and self.level0_range[0] <= last_token < self.level0_range[1]) or \
|
| 310 |
+
(self.current_level == "level1" and self.level1_range[0] <= last_token < self.level1_range[1]):
|
| 311 |
+
self.tokens_in_row += 1
|
| 312 |
+
|
| 313 |
+
# --- Step 2: Apply CFG ---
|
| 314 |
+
if self.generating_image:
|
| 315 |
+
scores = self._apply_cfg(input_ids, scores)
|
| 316 |
+
else:
|
| 317 |
+
if self.out:
|
| 318 |
+
self.out = None
|
| 319 |
+
|
| 320 |
+
# Apply constraints based on the *current* state (determining the *next* token)
|
| 321 |
+
mask = torch.zeros_like(scores, dtype=torch.bool) # True means ALLOWED
|
| 322 |
+
|
| 323 |
+
if self.generating_image:
|
| 324 |
+
# --- Image Generation Masking ---
|
| 325 |
+
if self.current_level == "level0":
|
| 326 |
+
if self.rows_in_level == self.num_level0_rows:
|
| 327 |
+
mask[:, self.special_tokens["end_of_level0"]] = True
|
| 328 |
+
elif self.tokens_in_row == self.num_level0_tokens:
|
| 329 |
+
mask[:, self.special_tokens["end_of_line"]] = True
|
| 330 |
+
else:
|
| 331 |
+
mask[:, self.level0_range[0]:self.level0_range[1]] = True
|
| 332 |
+
elif self.current_level == "level1":
|
| 333 |
+
if self.rows_in_level == self.num_level1_rows:
|
| 334 |
+
mask[:, self.special_tokens["end_of_level1"]] = True
|
| 335 |
+
elif self.tokens_in_row == self.num_level1_tokens:
|
| 336 |
+
mask[:, self.special_tokens["end_of_line"]] = True
|
| 337 |
+
else:
|
| 338 |
+
mask[:, self.level1_range[0]:self.level1_range[1]] = True
|
| 339 |
+
else: # Between structure tokens
|
| 340 |
+
if last_token == self.start_of_image_token_id:
|
| 341 |
+
mask[:, self.special_tokens["start_of_level0"]] = True
|
| 342 |
+
elif last_token == self.end_of_level0_token_id:
|
| 343 |
+
mask[:, self.special_tokens["start_of_level1"]] = True
|
| 344 |
+
elif last_token == self.end_of_level1_token_id:
|
| 345 |
+
mask[:, self.special_tokens["end_of_image"]] = True
|
| 346 |
+
elif last_token is None and input_ids.shape[1] == 0: # Very first token is image?
|
| 347 |
+
mask[:, self.start_of_image_token_id] = True
|
| 348 |
+
else: # Allow relevant structural tokens if needed
|
| 349 |
+
mask[:, self.special_tokens["start_of_level0"]] = True
|
| 350 |
+
mask[:, self.special_tokens["start_of_level1"]] = True
|
| 351 |
+
mask[:, self.special_tokens["end_of_image"]] = True
|
| 352 |
+
|
| 353 |
+
else:
|
| 354 |
+
# Allow *all* tokens by default...
|
| 355 |
+
mask[:, :] = True
|
| 356 |
+
# ...then specifically *disallow* image content and intermediate structure tokens
|
| 357 |
+
mask[:, self.level0_range[0]:self.level0_range[1]] = False
|
| 358 |
+
mask[:, self.level1_range[0]:self.level1_range[1]] = False
|
| 359 |
+
mask[:, self.special_tokens["start_of_level0"]] = False
|
| 360 |
+
mask[:, self.special_tokens["end_of_level0"]] = False
|
| 361 |
+
mask[:, self.special_tokens["start_of_level1"]] = False
|
| 362 |
+
mask[:, self.special_tokens["end_of_level1"]] = False
|
| 363 |
+
mask[:, self.special_tokens["end_of_line"]] = False # EOL only allowed within image context
|
| 364 |
+
|
| 365 |
+
# Ensure the specific allowed tokens for text phase are indeed allowed
|
| 366 |
+
# (This overrides any potential disallowing above if IDs overlap, e.g., if EOS was in image range)
|
| 367 |
+
mask[:, self.special_tokens["end_of_text"]] = True
|
| 368 |
+
mask[:, self.special_tokens["start_of_image"]] = True
|
| 369 |
+
|
| 370 |
+
# Apply the mask
|
| 371 |
+
scores[~mask] = -float("Inf")
|
| 372 |
+
|
| 373 |
+
# Handle edge case: If all tokens are masked
|
| 374 |
+
if not torch.any(scores > -float("Inf"), dim=-1).all():
|
| 375 |
+
print("WARN: All tokens masked, allowing EOS.")
|
| 376 |
+
# Allow EOS and potentially other safe tokens if needed
|
| 377 |
+
scores[:] = -float("Inf") # Reset all to -inf first
|
| 378 |
+
scores[:, self.special_tokens["end_of_text"]] = 0
|
| 379 |
+
|
| 380 |
+
# --- Step 3: Apply Dynamic Sampling ---
|
| 381 |
+
current_temp, current_top_k, current_top_p = self.default_temp, self.default_top_k, self.default_top_p
|
| 382 |
+
if self.in_level0_mode:
|
| 383 |
+
current_temp, current_top_k, current_top_p = self.level0_temp, self.level0_top_k, self.level0_top_p
|
| 384 |
+
elif self.in_level1_mode:
|
| 385 |
+
current_temp, current_top_k, current_top_p = self.level1_temp, self.level1_top_k, self.level1_top_p
|
| 386 |
+
|
| 387 |
+
scores = self._apply_sampling(scores, current_temp, current_top_k, current_top_p)
|
| 388 |
+
|
| 389 |
+
return scores
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def replace_placeholder_with_list(
|
| 393 |
+
tensor_a: torch.Tensor,
|
| 394 |
+
tensor_b: torch.Tensor,
|
| 395 |
+
placeholder_value: Union[int, float]
|
| 396 |
+
) -> torch.Tensor:
|
| 397 |
+
if tensor_a.dim() != 1:
|
| 398 |
+
raise ValueError("Input tensor_a must be 1-dimensional.")
|
| 399 |
+
|
| 400 |
+
indices = torch.where(tensor_a == placeholder_value)[0]
|
| 401 |
+
|
| 402 |
+
if len(indices) == 0:
|
| 403 |
+
# Placeholder not found, return the original tensor
|
| 404 |
+
print(
|
| 405 |
+
f"Warning: Placeholder value {placeholder_value} not found in the tensor. Returning original tensor.")
|
| 406 |
+
return tensor_a
|
| 407 |
+
|
| 408 |
+
# Get the index of the *first* occurrence
|
| 409 |
+
idx = indices[0].item()
|
| 410 |
+
|
| 411 |
+
result_tensor = torch.cat((tensor_a[:idx], tensor_b.to(tensor_a), tensor_a[idx + 1:]), dim=0)
|
| 412 |
+
return result_tensor
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_dualvitok.py
ADDED
|
@@ -0,0 +1,653 @@
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import math
|
| 6 |
+
from typing import Optional, Tuple, Union, List, Callable
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.nn import Module
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat, pack, unpack
|
| 14 |
+
from einx import get_at
|
| 15 |
+
|
| 16 |
+
from torch.utils.checkpoint import checkpoint
|
| 17 |
+
from transformers import AutoImageProcessor
|
| 18 |
+
from transformers.modeling_utils import PreTrainedModel, get_parameter_device, get_parameter_dtype
|
| 19 |
+
|
| 20 |
+
from .configuration_dualvitok import DualViTokConfig
|
| 21 |
+
from .modeling_movqgan import MoVQModel, MoVQEncoder, MoVQDecoder, Decoder
|
| 22 |
+
|
| 23 |
+
from .configuration_qwen2vit import Qwen2VLVisionConfig
|
| 24 |
+
from .modeling_qwen2vit import Qwen2VisionTransformerPretrainedModel, \
|
| 25 |
+
VisionRotaryEmbedding, Qwen2VLBatchVisionBlock
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
import xformers.ops as xops
|
| 29 |
+
|
| 30 |
+
is_xformers_available = True
|
| 31 |
+
except Exception as e:
|
| 32 |
+
is_xformers_available = False
|
| 33 |
+
|
| 34 |
+
if torch.__version__ > "2.1.2":
|
| 35 |
+
IS_SDPA_AVAILABLE = True
|
| 36 |
+
else:
|
| 37 |
+
IS_SDPA_AVAILABLE = False
|
| 38 |
+
|
| 39 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 40 |
+
sys.path.append(cur_dir)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# helper functions
|
| 44 |
+
|
| 45 |
+
def exists(v):
|
| 46 |
+
return v is not None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def identity(t):
|
| 50 |
+
return t
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def default(v, d):
|
| 54 |
+
return v if exists(v) else d
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def pack_one(t, pattern):
|
| 58 |
+
packed, packed_shape = pack([t], pattern)
|
| 59 |
+
|
| 60 |
+
def inverse(out, inv_pattern=None):
|
| 61 |
+
inv_pattern = default(inv_pattern, pattern)
|
| 62 |
+
out, = unpack(out, packed_shape, inv_pattern)
|
| 63 |
+
return out
|
| 64 |
+
|
| 65 |
+
return packed, inverse
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# class
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SimVQ(Module):
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
dim,
|
| 75 |
+
codebook_size,
|
| 76 |
+
codebook_transform: Module | None = None,
|
| 77 |
+
init_fn: Callable = identity,
|
| 78 |
+
channel_first=True,
|
| 79 |
+
input_to_quantize_commit_loss_weight=0.25,
|
| 80 |
+
commitment_weight=1.,
|
| 81 |
+
frozen_codebook_dim=None # frozen codebook dim could have different dimensions than projection
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.codebook_size = codebook_size
|
| 85 |
+
self.channel_first = channel_first
|
| 86 |
+
|
| 87 |
+
frozen_codebook_dim = default(frozen_codebook_dim, dim)
|
| 88 |
+
codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
|
| 89 |
+
codebook = init_fn(codebook)
|
| 90 |
+
|
| 91 |
+
# the codebook is actually implicit from a linear layer from frozen gaussian or uniform
|
| 92 |
+
|
| 93 |
+
if not exists(codebook_transform):
|
| 94 |
+
codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias=False)
|
| 95 |
+
|
| 96 |
+
self.code_transform = codebook_transform
|
| 97 |
+
|
| 98 |
+
self.register_buffer('frozen_codebook', codebook)
|
| 99 |
+
|
| 100 |
+
# commit loss weighting - weighing input to quantize a bit less is crucial for it to work
|
| 101 |
+
self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
|
| 102 |
+
|
| 103 |
+
# total commitment loss weight
|
| 104 |
+
self.commitment_weight = commitment_weight
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def codebook(self):
|
| 108 |
+
return self.code_transform(self.frozen_codebook)
|
| 109 |
+
|
| 110 |
+
def indices_to_codes(
|
| 111 |
+
self,
|
| 112 |
+
indices
|
| 113 |
+
):
|
| 114 |
+
implicit_codebook = self.codebook
|
| 115 |
+
|
| 116 |
+
frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
|
| 117 |
+
quantized = self.code_transform(frozen_codes)
|
| 118 |
+
|
| 119 |
+
if self.channel_first:
|
| 120 |
+
quantized = rearrange(quantized, 'b ... d -> b d ...')
|
| 121 |
+
|
| 122 |
+
return quantized
|
| 123 |
+
|
| 124 |
+
def forward(
|
| 125 |
+
self,
|
| 126 |
+
x
|
| 127 |
+
):
|
| 128 |
+
if self.channel_first:
|
| 129 |
+
x = rearrange(x, 'b d ... -> b ... d')
|
| 130 |
+
|
| 131 |
+
x, inverse_pack = pack_one(x, 'b * d')
|
| 132 |
+
|
| 133 |
+
implicit_codebook = self.codebook
|
| 134 |
+
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
dist = torch.cdist(x, implicit_codebook)
|
| 137 |
+
indices = dist.argmin(dim=-1)
|
| 138 |
+
|
| 139 |
+
# select codes
|
| 140 |
+
|
| 141 |
+
quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
|
| 142 |
+
|
| 143 |
+
# commit loss and straight through, as was done in the paper
|
| 144 |
+
|
| 145 |
+
commit_loss = (
|
| 146 |
+
F.mse_loss(x.detach(), quantized) +
|
| 147 |
+
F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
quantized = (quantized - x).detach() + x
|
| 151 |
+
|
| 152 |
+
quantized = inverse_pack(quantized)
|
| 153 |
+
indices = inverse_pack(indices, 'b *')
|
| 154 |
+
|
| 155 |
+
if self.channel_first:
|
| 156 |
+
quantized = rearrange(quantized, 'b ... d-> b d ...')
|
| 157 |
+
|
| 158 |
+
return quantized, commit_loss * self.commitment_weight, indices
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def init_weights(m):
|
| 162 |
+
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
|
| 163 |
+
if m.weight is not None:
|
| 164 |
+
nn.init.constant_(m.weight, 1)
|
| 165 |
+
if m.bias is not None:
|
| 166 |
+
nn.init.constant_(m.bias, 0)
|
| 167 |
+
elif isinstance(m, nn.Linear):
|
| 168 |
+
nn.init.xavier_uniform_(m.weight)
|
| 169 |
+
if m.bias is not None:
|
| 170 |
+
nn.init.constant_(m.bias, 0)
|
| 171 |
+
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) \
|
| 172 |
+
or isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
|
| 173 |
+
w = m.weight.data
|
| 174 |
+
nn.init.xavier_uniform_(w)
|
| 175 |
+
if m.bias is not None:
|
| 176 |
+
nn.init.constant_(m.bias, 0)
|
| 177 |
+
elif isinstance(m, nn.Embedding):
|
| 178 |
+
nn.init.normal_(m.weight, mean=0, std=1)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class ScalingLayerForQwen2ViT:
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
min_pixels: int = 56 * 56,
|
| 185 |
+
max_pixels: int = 28 * 28 * 1280,
|
| 186 |
+
patch_size: int = 14,
|
| 187 |
+
temporal_patch_size: int = 2,
|
| 188 |
+
merge_size: int = 2,
|
| 189 |
+
**kwargs,
|
| 190 |
+
) -> None:
|
| 191 |
+
super().__init__(**kwargs)
|
| 192 |
+
OPENAI_CLIP_MEAN = torch.as_tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]
|
| 193 |
+
OPENAI_CLIP_STD = torch.as_tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]
|
| 194 |
+
|
| 195 |
+
self.image_mean = OPENAI_CLIP_MEAN
|
| 196 |
+
self.image_std = OPENAI_CLIP_STD
|
| 197 |
+
self.min_pixels = min_pixels
|
| 198 |
+
self.max_pixels = max_pixels
|
| 199 |
+
self.patch_size = patch_size
|
| 200 |
+
self.temporal_patch_size = temporal_patch_size
|
| 201 |
+
self.merge_size = merge_size
|
| 202 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
| 203 |
+
|
| 204 |
+
def __call__(self, images):
|
| 205 |
+
if images.ndim == 4:
|
| 206 |
+
images = images.unsqueeze(1)
|
| 207 |
+
batch_size, temporal, channel, height, width = images.shape
|
| 208 |
+
|
| 209 |
+
factor = self.patch_size * self.merge_size
|
| 210 |
+
|
| 211 |
+
resized_height, resized_width = height // factor * factor, width // factor * factor
|
| 212 |
+
|
| 213 |
+
images = (images + 1) / 2 # rescale to [0, 1.]
|
| 214 |
+
|
| 215 |
+
images = torch.nn.functional.interpolate(
|
| 216 |
+
images.flatten(0, 1).float(),
|
| 217 |
+
size=(resized_height, resized_width),
|
| 218 |
+
mode='bicubic',
|
| 219 |
+
align_corners=False,
|
| 220 |
+
antialias=True
|
| 221 |
+
).to(images.dtype)
|
| 222 |
+
|
| 223 |
+
images = images.clamp(0, 1) # rescale to [0, 1.]
|
| 224 |
+
images = ((images - self.image_mean.to(images)) / self.image_std.to(images))
|
| 225 |
+
|
| 226 |
+
images = rearrange(images, '(b t) c h w -> b t c h w', b=batch_size, t=temporal)
|
| 227 |
+
if temporal == 1:
|
| 228 |
+
images = images.repeat_interleave(self.temporal_patch_size, dim=1)
|
| 229 |
+
temporal = self.temporal_patch_size
|
| 230 |
+
|
| 231 |
+
grid_t = temporal // self.temporal_patch_size
|
| 232 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
| 233 |
+
|
| 234 |
+
images = images.reshape(
|
| 235 |
+
batch_size * grid_t,
|
| 236 |
+
self.temporal_patch_size,
|
| 237 |
+
channel,
|
| 238 |
+
-1
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
images = rearrange(images, 'b p c n -> b n (c p)')
|
| 242 |
+
images = images.reshape(
|
| 243 |
+
batch_size * grid_t,
|
| 244 |
+
grid_h // self.merge_size,
|
| 245 |
+
self.merge_size,
|
| 246 |
+
self.patch_size,
|
| 247 |
+
grid_w // self.merge_size,
|
| 248 |
+
self.merge_size,
|
| 249 |
+
self.patch_size,
|
| 250 |
+
-1
|
| 251 |
+
)
|
| 252 |
+
images = rearrange(images, 'b h k s1 w l s2 n -> (b h w k l) (n s1 s2)')
|
| 253 |
+
|
| 254 |
+
return dict(image=images, image_grid_thw=torch.as_tensor([[grid_t, grid_h, grid_w] for _ in range(batch_size)]))
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class SemanticEncoder(nn.Module):
|
| 258 |
+
def __init__(self,
|
| 259 |
+
semantic_encoder,
|
| 260 |
+
z_channels=4,
|
| 261 |
+
num_blocks=2,
|
| 262 |
+
embed_dim=1280,
|
| 263 |
+
proj_layer='linear',
|
| 264 |
+
attn_implementation='xformers',
|
| 265 |
+
target_mlp='identity',
|
| 266 |
+
):
|
| 267 |
+
super().__init__()
|
| 268 |
+
self.embed_dim = embed_dim
|
| 269 |
+
|
| 270 |
+
if isinstance(semantic_encoder, str):
|
| 271 |
+
self.model = Qwen2VisionTransformerPretrainedModel.from_pretrained(
|
| 272 |
+
semantic_encoder,
|
| 273 |
+
attn_implementation=attn_implementation
|
| 274 |
+
)
|
| 275 |
+
elif isinstance(semantic_encoder, dict):
|
| 276 |
+
config = Qwen2VLVisionConfig(**semantic_encoder, attn_implementation=attn_implementation)
|
| 277 |
+
self.model = Qwen2VisionTransformerPretrainedModel(config)
|
| 278 |
+
else:
|
| 279 |
+
raise ValueError(f"Invalid semantic_encoder: {semantic_encoder}")
|
| 280 |
+
input_channels = self.model.config.hidden_size
|
| 281 |
+
|
| 282 |
+
for p in self.model.parameters():
|
| 283 |
+
p.requires_grad = False
|
| 284 |
+
|
| 285 |
+
self.proj_in = nn.Conv2d(input_channels, embed_dim, 1, 1) if input_channels != embed_dim else nn.Identity()
|
| 286 |
+
|
| 287 |
+
config = Qwen2VLVisionConfig(depth=num_blocks,
|
| 288 |
+
embed_dim=embed_dim, )
|
| 289 |
+
head_dim = config.embed_dim // config.num_heads
|
| 290 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 291 |
+
|
| 292 |
+
self.blocks = nn.ModuleList(
|
| 293 |
+
[Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
if proj_layer == 'norm_linear':
|
| 297 |
+
self.proj_out = nn.Sequential(
|
| 298 |
+
nn.LayerNorm(embed_dim),
|
| 299 |
+
nn.Linear(
|
| 300 |
+
embed_dim,
|
| 301 |
+
z_channels,
|
| 302 |
+
)
|
| 303 |
+
)
|
| 304 |
+
elif proj_layer == 'linear':
|
| 305 |
+
self.proj_out = nn.Sequential(
|
| 306 |
+
nn.Linear(
|
| 307 |
+
embed_dim,
|
| 308 |
+
z_channels,
|
| 309 |
+
)
|
| 310 |
+
)
|
| 311 |
+
elif proj_layer == 'mlp':
|
| 312 |
+
self.proj_out = nn.Sequential(
|
| 313 |
+
nn.Linear(embed_dim, embed_dim),
|
| 314 |
+
nn.Tanh(),
|
| 315 |
+
nn.Linear(embed_dim, z_channels),
|
| 316 |
+
)
|
| 317 |
+
else:
|
| 318 |
+
raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
|
| 319 |
+
|
| 320 |
+
if target_mlp == 'identity':
|
| 321 |
+
self.target_mlp = nn.Sequential(
|
| 322 |
+
nn.Identity(),
|
| 323 |
+
)
|
| 324 |
+
elif target_mlp == 'norm':
|
| 325 |
+
self.target_mlp = nn.Sequential(
|
| 326 |
+
nn.LayerNorm(input_channels, eps=1e-6, elementwise_affine=False),
|
| 327 |
+
)
|
| 328 |
+
self.init_weight()
|
| 329 |
+
|
| 330 |
+
def init_weight(self):
|
| 331 |
+
self.proj_in.apply(init_weights)
|
| 332 |
+
self.blocks.apply(init_weights)
|
| 333 |
+
self.proj_out.apply(init_weights)
|
| 334 |
+
self.target_mlp.apply(init_weights)
|
| 335 |
+
|
| 336 |
+
def rot_pos_emb(self, grid_thw, max_seq_len):
|
| 337 |
+
pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
|
| 338 |
+
for idx, (t, h, w) in enumerate(grid_thw):
|
| 339 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 340 |
+
hpos_ids = hpos_ids.flatten()
|
| 341 |
+
|
| 342 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 343 |
+
wpos_ids = wpos_ids.flatten()
|
| 344 |
+
|
| 345 |
+
current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
|
| 346 |
+
pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
|
| 347 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 348 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 349 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
|
| 350 |
+
return rotary_pos_emb
|
| 351 |
+
|
| 352 |
+
def forward(self, x, grid_thw):
|
| 353 |
+
x = self.model(x, grid_thw=grid_thw)
|
| 354 |
+
|
| 355 |
+
x = x_target = self.target_mlp(x)
|
| 356 |
+
|
| 357 |
+
x = F.linear(x,
|
| 358 |
+
self.proj_in.weight.view(self.proj_in.weight.shape[0], -1),
|
| 359 |
+
self.proj_in.bias)
|
| 360 |
+
|
| 361 |
+
new_grid_thw = torch.as_tensor([[t, h // 2, w // 2] for t, h, w in grid_thw])
|
| 362 |
+
|
| 363 |
+
seq_lens = [t_i * h_i * w_i for t_i, h_i, w_i in new_grid_thw]
|
| 364 |
+
max_seq_len = max(seq_lens)
|
| 365 |
+
|
| 366 |
+
x = rearrange(x, '(b h w) c -> b (h w) c', h=new_grid_thw[0, 1], w=new_grid_thw[0, 2])
|
| 367 |
+
|
| 368 |
+
rotary_pos_emb = self.rot_pos_emb(new_grid_thw, max_seq_len)
|
| 369 |
+
|
| 370 |
+
for blk in self.blocks:
|
| 371 |
+
x = blk(x, rotary_pos_emb=rotary_pos_emb)
|
| 372 |
+
|
| 373 |
+
x = self.proj_out(x) # [b, max_length, d]
|
| 374 |
+
|
| 375 |
+
t, h, w = new_grid_thw[0]
|
| 376 |
+
b = len(grid_thw)
|
| 377 |
+
x = rearrange(x, 'b (h w) c ->b c h w', b=b, h=h, w=w)
|
| 378 |
+
x_target = rearrange(x_target, '(b h w) c ->b c h w', b=b, h=h, w=w)
|
| 379 |
+
return x, x_target
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class SemanticDecoder(nn.Module):
|
| 383 |
+
def __init__(self,
|
| 384 |
+
z_channels=4,
|
| 385 |
+
embed_dim=1280,
|
| 386 |
+
num_blocks=2,
|
| 387 |
+
output_channels=1280,
|
| 388 |
+
attn_implementation='xformers',
|
| 389 |
+
proj_layer='linear_norm'):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.proj_in = nn.Linear(z_channels, embed_dim)
|
| 392 |
+
|
| 393 |
+
self.output_channels = output_channels
|
| 394 |
+
config = Qwen2VLVisionConfig(depth=num_blocks, embed_dim=embed_dim)
|
| 395 |
+
|
| 396 |
+
self.blocks = nn.ModuleList(
|
| 397 |
+
[Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
|
| 398 |
+
)
|
| 399 |
+
head_dim = config.embed_dim // config.num_heads
|
| 400 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 401 |
+
|
| 402 |
+
if proj_layer == 'norm_linear':
|
| 403 |
+
self.proj_out = nn.Sequential(
|
| 404 |
+
nn.LayerNorm(embed_dim),
|
| 405 |
+
nn.Linear(embed_dim, output_channels),
|
| 406 |
+
)
|
| 407 |
+
elif proj_layer == 'linear':
|
| 408 |
+
self.proj_out = nn.Sequential(
|
| 409 |
+
nn.Linear(embed_dim, output_channels)
|
| 410 |
+
)
|
| 411 |
+
elif proj_layer == 'mlp':
|
| 412 |
+
self.proj_out = nn.Sequential(
|
| 413 |
+
nn.Linear(embed_dim, embed_dim),
|
| 414 |
+
nn.Tanh(),
|
| 415 |
+
nn.Linear(embed_dim, output_channels),
|
| 416 |
+
)
|
| 417 |
+
elif proj_layer == 'linear_norm':
|
| 418 |
+
self.proj_out = nn.Sequential(
|
| 419 |
+
nn.Linear(embed_dim, output_channels),
|
| 420 |
+
nn.LayerNorm(output_channels),
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
self.apply(init_weights)
|
| 424 |
+
|
| 425 |
+
@property
|
| 426 |
+
def last_layer(self):
|
| 427 |
+
return self.proj_out[-1].weight
|
| 428 |
+
|
| 429 |
+
def rot_pos_emb(self, grid_thw, max_seq_len):
|
| 430 |
+
pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
|
| 431 |
+
for idx, (t, h, w) in enumerate(grid_thw):
|
| 432 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 433 |
+
hpos_ids = hpos_ids.flatten()
|
| 434 |
+
|
| 435 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 436 |
+
wpos_ids = wpos_ids.flatten()
|
| 437 |
+
|
| 438 |
+
current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
|
| 439 |
+
pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
|
| 440 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 441 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 442 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
|
| 443 |
+
return rotary_pos_emb
|
| 444 |
+
|
| 445 |
+
def forward(self, z: torch.Tensor):
|
| 446 |
+
x = z
|
| 447 |
+
|
| 448 |
+
b, c, h, w = x.shape
|
| 449 |
+
|
| 450 |
+
x = rearrange(x, 'b c h w -> b (h w) c')
|
| 451 |
+
|
| 452 |
+
grid_thw = torch.as_tensor([[1, h, w] for _ in range(b)])
|
| 453 |
+
seq_lens = [t * h * w for t, h, w in grid_thw]
|
| 454 |
+
max_seq_len = max(seq_lens)
|
| 455 |
+
|
| 456 |
+
x = self.proj_in(x)
|
| 457 |
+
|
| 458 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw, max_seq_len)
|
| 459 |
+
|
| 460 |
+
for blk in self.blocks:
|
| 461 |
+
x = blk(x, rotary_pos_emb=rotary_pos_emb)
|
| 462 |
+
|
| 463 |
+
x = self.proj_out(x)
|
| 464 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
| 465 |
+
return x
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class DualViTokPretrainModel(PreTrainedModel):
|
| 469 |
+
"""
|
| 470 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 471 |
+
models.
|
| 472 |
+
"""
|
| 473 |
+
|
| 474 |
+
config_class = DualViTokConfig
|
| 475 |
+
base_model_prefix = "dualvitok"
|
| 476 |
+
main_input_name = "pixel_values"
|
| 477 |
+
_no_split_modules = ["BatchQwen2VLVisionBlock", "MoVQResnetBlock", "MoVQAttnBlock", "MoVQResnetTemporalBlock"]
|
| 478 |
+
_supports_flash_attn_2 = True
|
| 479 |
+
_supports_sdpa = True
|
| 480 |
+
_supports_cache_class = True
|
| 481 |
+
_supports_static_cache = True
|
| 482 |
+
|
| 483 |
+
def _init_weights(self, module):
|
| 484 |
+
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
|
| 485 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 486 |
+
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
| 487 |
+
elif isinstance(module, nn.Linear):
|
| 488 |
+
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| 489 |
+
if module.bias is not None:
|
| 490 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| 491 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 492 |
+
nn.init.uniform_(module.bias, -bound, bound)
|
| 493 |
+
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
|
| 494 |
+
nn.init.constant_(module.weight, 1)
|
| 495 |
+
nn.init.constant_(module.bias, 0)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class DualViTok(DualViTokPretrainModel):
|
| 499 |
+
def __init__(self, config: DualViTokConfig):
|
| 500 |
+
super().__init__(config)
|
| 501 |
+
self.config = config
|
| 502 |
+
|
| 503 |
+
self._semantic_channel = config.semantic_encoder.z_channels
|
| 504 |
+
self._pixel_channel = config.pixel_encoder.z_channels
|
| 505 |
+
|
| 506 |
+
self.semantic_encoder = SemanticEncoder(
|
| 507 |
+
semantic_encoder=config.semantic_encoder.pretrained_semantic_encoder,
|
| 508 |
+
z_channels=config.semantic_encoder.z_channels,
|
| 509 |
+
num_blocks=config.semantic_encoder.num_blocks,
|
| 510 |
+
embed_dim=config.semantic_encoder.embed_dim,
|
| 511 |
+
proj_layer=config.semantic_encoder.out_layer,
|
| 512 |
+
attn_implementation=config.attn_implementation,
|
| 513 |
+
target_mlp=config.semantic_encoder.target_mlp, )
|
| 514 |
+
self.semantic_decoder = SemanticDecoder(
|
| 515 |
+
z_channels=config.semantic_decoder.z_channels,
|
| 516 |
+
embed_dim=config.semantic_decoder.embed_dim,
|
| 517 |
+
num_blocks=config.semantic_decoder.num_blocks,
|
| 518 |
+
output_channels=config.semantic_decoder.out_channels,
|
| 519 |
+
attn_implementation=config.attn_implementation,
|
| 520 |
+
proj_layer=config.semantic_decoder.out_layer,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
if config.semantic_quantizer_type.lower() == 'simvq':
|
| 524 |
+
self.semantic_quantizer = SimVQ(
|
| 525 |
+
dim=config.semantic_encoder.z_channels,
|
| 526 |
+
codebook_size=config.semantic_quantizer_codebook_size,
|
| 527 |
+
)
|
| 528 |
+
elif config.semantic_quantizer_type.lower() == 'vq':
|
| 529 |
+
raise NotImplementedError
|
| 530 |
+
self.semantic_quantizer = VQ(
|
| 531 |
+
dim=config.semantic_encoder.z_channels,
|
| 532 |
+
codebook_size=config.semantic_quantizer_codebook_size,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
self.pixel_encoder = MoVQEncoder(config.pixel_encoder)
|
| 536 |
+
self.pixel_quant_conv = nn.Conv2d(config.pixel_encoder.z_channels, config.pixel_encoder.embed_dim, 1)
|
| 537 |
+
|
| 538 |
+
if config.pixel_quantizer_type.lower() == 'simvq':
|
| 539 |
+
self.pixel_quantizer = SimVQ(
|
| 540 |
+
dim=config.pixel_encoder.z_channels,
|
| 541 |
+
codebook_size=config.pixel_quantizer_codebook_size,
|
| 542 |
+
)
|
| 543 |
+
elif config.pixel_quantizer_type.lower() == 'vq':
|
| 544 |
+
raise NotImplementedError
|
| 545 |
+
self.pixel_quantizer = VQ(
|
| 546 |
+
dim=config.pixel_encoder.z_channels,
|
| 547 |
+
codebook_size=config.pixel_quantizer_codebook_size,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
self.pixel_post_quant_conv = nn.Conv2d(config.pixel_decoder.embed_dim,
|
| 551 |
+
config.pixel_decoder.z_channels, 1)
|
| 552 |
+
|
| 553 |
+
self.pixel_decoder = MoVQDecoder(config.pixel_decoder)
|
| 554 |
+
|
| 555 |
+
self.scaling_layer = ScalingLayerForQwen2ViT()
|
| 556 |
+
|
| 557 |
+
@property
|
| 558 |
+
def device(self):
|
| 559 |
+
return get_parameter_device(self)
|
| 560 |
+
|
| 561 |
+
@property
|
| 562 |
+
def dtype(self):
|
| 563 |
+
return get_parameter_dtype(self)
|
| 564 |
+
|
| 565 |
+
@property
|
| 566 |
+
def pixel_channel(self):
|
| 567 |
+
return self._pixel_channel
|
| 568 |
+
|
| 569 |
+
@property
|
| 570 |
+
def semantic_channel(self):
|
| 571 |
+
return self._semantic_channel
|
| 572 |
+
|
| 573 |
+
def encode(self, image: torch.FloatTensor):
|
| 574 |
+
scale_output = self.scaling_layer(image)
|
| 575 |
+
image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
|
| 576 |
+
|
| 577 |
+
h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
|
| 578 |
+
quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
|
| 579 |
+
|
| 580 |
+
h_pixel = self.pixel_encoder(image_gen)
|
| 581 |
+
h_pixel = self.pixel_quant_conv(h_pixel)
|
| 582 |
+
|
| 583 |
+
quant_pixel, emb_loss_pixel, info_pixel = self.pixel_quantizer(h_pixel.float())
|
| 584 |
+
|
| 585 |
+
return (quant_semantic, emb_loss_semantic, info_semantic, target_semantic), \
|
| 586 |
+
(quant_pixel, emb_loss_pixel, info_pixel)
|
| 587 |
+
|
| 588 |
+
def encode_code(self, *args, **kwargs):
|
| 589 |
+
(_, _, semantic_indices, _), \
|
| 590 |
+
(_, _, pixel_indices) = self.encode(*args, **kwargs)
|
| 591 |
+
return semantic_indices, pixel_indices
|
| 592 |
+
|
| 593 |
+
def indices_to_codes(self, semantic_indices, pixel_indices):
|
| 594 |
+
quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
|
| 595 |
+
quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
|
| 596 |
+
return quant_semantic, quant_pixel
|
| 597 |
+
|
| 598 |
+
def encode_semantic(self, image: torch.FloatTensor):
|
| 599 |
+
scale_output = self.scaling_layer(image)
|
| 600 |
+
image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
|
| 601 |
+
|
| 602 |
+
h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
|
| 603 |
+
quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
|
| 604 |
+
return quant_semantic, emb_loss_semantic, info_semantic, target_semantic
|
| 605 |
+
|
| 606 |
+
def merge_quants(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
|
| 607 |
+
quant_semantic_resized = F.interpolate(
|
| 608 |
+
quant_semantic, quant_pixel.shape[-2:], mode='bicubic'
|
| 609 |
+
).to(quant_semantic.dtype)
|
| 610 |
+
quant_semantic = quant_semantic_resized
|
| 611 |
+
|
| 612 |
+
quant = torch.cat([quant_semantic, quant_pixel], dim=1)
|
| 613 |
+
|
| 614 |
+
return quant
|
| 615 |
+
|
| 616 |
+
def decode(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor, ):
|
| 617 |
+
quant = self.merge_quants(quant_semantic, quant_pixel)
|
| 618 |
+
quant2 = self.pixel_post_quant_conv(quant)
|
| 619 |
+
x = self.pixel_decoder(quant2, quant)
|
| 620 |
+
return x
|
| 621 |
+
|
| 622 |
+
def decode_code(self, semantic_indices, pixel_indices):
|
| 623 |
+
quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
|
| 624 |
+
quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
|
| 625 |
+
return self.decode(quant_semantic, quant_pixel)
|
| 626 |
+
|
| 627 |
+
def decode_semantic(self, x: List[torch.Tensor]):
|
| 628 |
+
return self.semantic_decoder(x)
|
| 629 |
+
|
| 630 |
+
def forward(self, pixel_values: torch.FloatTensor):
|
| 631 |
+
(quant_semantic, diff_semantic, _, target_semantic), \
|
| 632 |
+
(quant_pixel, diff_pixel, _) = self.encode(pixel_values)
|
| 633 |
+
dec = self.decode(quant_semantic, quant_pixel)
|
| 634 |
+
dec_semantic = self.decode_semantic(quant_semantic)
|
| 635 |
+
return (dec_semantic, diff_semantic, target_semantic), (dec, diff_pixel)
|
| 636 |
+
|
| 637 |
+
def build_sdxl_decoder(self, path='ILLUME-MLLM/dualvitok-sdxl-decoder',
|
| 638 |
+
image_processor=None,
|
| 639 |
+
torch_dtype=torch.float16,
|
| 640 |
+
add_watermarker=False,
|
| 641 |
+
device='cuda',
|
| 642 |
+
):
|
| 643 |
+
from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
|
| 644 |
+
|
| 645 |
+
if image_processor is None:
|
| 646 |
+
image_processor = AutoImageProcessor.from_pretrained('ILLUME-MLLM/dualvitok', trust_remote_code=True)
|
| 647 |
+
|
| 648 |
+
return StableDiffusionXLDecoderPipeline.from_pretrained(path,
|
| 649 |
+
torch_dtype=torch_dtype,
|
| 650 |
+
add_watermarker=add_watermarker,
|
| 651 |
+
vq_model=self,
|
| 652 |
+
vq_image_processor=image_processor).to(device)
|
| 653 |
+
|
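A minimal round-trip sketch for the DualViTok tokenizer defined above: it assumes the tokenizer and its image processor load from the 'ILLUME-MLLM/dualvitok' repo referenced in build_sdxl_decoder and that the processor returns a pixel_values tensor; the repo id, processor keys, and index shapes are assumptions, not part of this upload.

import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor

# Load the dual-branch tokenizer and its image processor (repo id assumed; see build_sdxl_decoder above).
vq_model = AutoModel.from_pretrained("ILLUME-MLLM/dualvitok", trust_remote_code=True).eval()
vq_processor = AutoImageProcessor.from_pretrained("ILLUME-MLLM/dualvitok", trust_remote_code=True)

image = Image.open("example.jpg").convert("RGB")
pixel_values = vq_processor(image, return_tensors="pt")["pixel_values"].to(vq_model.device, vq_model.dtype)

with torch.no_grad():
    # encode_code returns (semantic_indices, pixel_indices); decode_code maps them back to an image tensor.
    semantic_indices, pixel_indices = vq_model.encode_code(pixel_values)
    reconstruction = vq_model.decode_code(semantic_indices, pixel_indices)

The two index sets are what the language model below emits as level-0 (semantic) and level-1 (pixel) image tokens.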
modeling_illume.py
ADDED
|
@@ -0,0 +1,883 @@
| 1 |
+
"""PyTorch ILLUME model."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from functools import partial
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from transformers import PreTrainedModel
|
| 13 |
+
from transformers.activations import ACT2FN
|
| 14 |
+
from transformers.cache_utils import Cache
|
| 15 |
+
from transformers.modeling_outputs import ModelOutput
|
| 16 |
+
from transformers.utils import (
|
| 17 |
+
add_start_docstrings,
|
| 18 |
+
add_start_docstrings_to_model_forward,
|
| 19 |
+
logging,
|
| 20 |
+
replace_return_docstrings,
|
| 21 |
+
)
|
| 22 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
| 23 |
+
from transformers import LogitsProcessorList
|
| 24 |
+
|
| 25 |
+
from .configuration_illume import ILLUMEConfig
|
| 26 |
+
from .modeling_qwen2vit import Qwen2VisionTransformerPretrainedModel
|
| 27 |
+
from .modeling_dualvitok import ScalingLayerForQwen2ViT, SemanticEncoder
|
| 28 |
+
from .modeling_movqgan import MoVQEncoder
|
| 29 |
+
from .inference_utils import InterleavedLogitsProcessor, \
|
| 30 |
+
parse_interleaved_text_image, calculate_image_token_num
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
from einops import rearrange
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
_CONFIG_FOR_DOC = "ILLUMEConfig"
|
| 38 |
+
|
| 39 |
+
# Define common resolutions
|
| 40 |
+
DEFAULT_RESOLUTIONS = [
|
| 41 |
+
(256, 256), (512, 512), (384, 640), (640, 384), (512, 384),
|
| 42 |
+
(384, 512), (256, 384), (384, 256), (256, 512), (512, 256)
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
# qwen2.5
|
| 46 |
+
special_tokens_ids = [151665, 151666, 151667, 151668, 151669, 151670, 151671]
|
| 47 |
+
start_token = 151672 + 32
|
| 48 |
+
level0_range = (start_token, start_token + 32768)  # Level 0 token ID range
|
| 49 |
+
level1_range = (start_token + 32768, start_token + 32768 * 4)  # Level 1 token ID range
|
| 50 |
+
|
| 51 |
+
special_tokens_dict = {
|
| 52 |
+
"start_of_image": 151665,
|
| 53 |
+
"end_of_image": 151666,
|
| 54 |
+
"start_of_level0": 151668,
|
| 55 |
+
"end_of_level0": 151669,
|
| 56 |
+
"start_of_level1": 151670,
|
| 57 |
+
"end_of_level1": 151671,
|
| 58 |
+
"end_of_line": 151667,
|
| 59 |
+
"end_of_text": 151645,
|
| 60 |
+
#
|
| 61 |
+
"level0_range": level0_range,
|
| 62 |
+
"level1_range": level1_range,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class ILLUMECausalLMOutputWithPast(ModelOutput):
|
| 68 |
+
"""
|
| 69 |
+
Base class for ILLUME causal language model (or autoregressive) outputs.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 73 |
+
Language modeling loss (for next-token prediction).
|
| 74 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 75 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 76 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 77 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 78 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 79 |
+
|
| 80 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 81 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 82 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 83 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 84 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 85 |
+
|
| 86 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 87 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 88 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 89 |
+
sequence_length)`.
|
| 90 |
+
|
| 91 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 92 |
+
heads.
|
| 93 |
+
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
|
| 94 |
+
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
|
| 95 |
+
sequence_length, hidden_size)`.
|
| 96 |
+
|
| 97 |
+
image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
loss: Optional[torch.FloatTensor] = None
|
| 101 |
+
logits: torch.FloatTensor = None
|
| 102 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 103 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 104 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 105 |
+
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class MLPProjector(nn.Sequential):
|
| 109 |
+
# CAbstractor
|
| 110 |
+
def __init__(self, mlp_depth, hidden_size, mm_hidden_size):
|
| 111 |
+
super(MLPProjector, self).__init__()
|
| 112 |
+
modules = [nn.Linear(mm_hidden_size, hidden_size)]
|
| 113 |
+
for _ in range(1, mlp_depth):
|
| 114 |
+
modules.append(nn.GELU())
|
| 115 |
+
modules.append(nn.Linear(hidden_size, hidden_size))
|
| 116 |
+
super(MLPProjector, self).__init__(*modules)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ILLUMEMultiModalProjector(nn.Sequential):
|
| 120 |
+
# CAbstractor
|
| 121 |
+
def __init__(self, config):
|
| 122 |
+
super(ILLUMEMultiModalProjector, self).__init__()
|
| 123 |
+
hidden_size = config.text_config.hidden_size
|
| 124 |
+
mm_hidden_size1, mm_hidden_size2 = config.mm_projector_config['mm_hidden_size']
|
| 125 |
+
self.projector_1 = MLPProjector(mlp_depth=config.mm_projector_config['projector_cfg1']['mlp_depth'],
|
| 126 |
+
mm_hidden_size=mm_hidden_size1, hidden_size=hidden_size)
|
| 127 |
+
self.projector_2 = MLPProjector(mlp_depth=config.mm_projector_config['projector_cfg2']['mlp_depth'],
|
| 128 |
+
mm_hidden_size=mm_hidden_size2, hidden_size=hidden_size)
|
| 129 |
+
|
| 130 |
+
def forward(self, image_features):
|
| 131 |
+
image_feature_1, image_feature_2 = image_features
|
| 132 |
+
image_feature_1 = self.projector_1(image_feature_1)
|
| 133 |
+
image_feature_2 = self.projector_2(image_feature_2)
|
| 134 |
+
image_features = torch.concat([image_feature_1, image_feature_2], dim=1)
|
| 135 |
+
return image_features
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class ILLUMEDualVisionTower(nn.Module):
|
| 139 |
+
def __init__(self,
|
| 140 |
+
vision_config,
|
| 141 |
+
attn_implementation='sdpa',
|
| 142 |
+
):
|
| 143 |
+
super().__init__()
|
| 144 |
+
self._config = vision_config
|
| 145 |
+
self.semantic_encoder = SemanticEncoder(
|
| 146 |
+
semantic_encoder=vision_config.semantic_encoder.pretrained_semantic_encoder,
|
| 147 |
+
z_channels=vision_config.semantic_encoder.z_channels,
|
| 148 |
+
num_blocks=vision_config.semantic_encoder.num_blocks,
|
| 149 |
+
embed_dim=vision_config.semantic_encoder.embed_dim,
|
| 150 |
+
proj_layer=vision_config.semantic_encoder.out_layer,
|
| 151 |
+
attn_implementation=attn_implementation,
|
| 152 |
+
target_mlp=vision_config.semantic_encoder.target_mlp, ).model
|
| 153 |
+
self.pixel_encoder = MoVQEncoder(vision_config.pixel_encoder)
|
| 154 |
+
self.scaling_layer = ScalingLayerForQwen2ViT()
|
| 155 |
+
|
| 156 |
+
def forward(self, images):
|
| 157 |
+
if isinstance(images, list) and all(x is not None and x.shape == images[0].shape for x in images):
|
| 158 |
+
images = torch.concat(images, dim=0)
|
| 159 |
+
images = images.to(device=self.device, dtype=self.dtype)
|
| 160 |
+
else:
|
| 161 |
+
images = [image.to(device=self.device, dtype=self.dtype) for image in images]
|
| 162 |
+
|
| 163 |
+
image_feature_shape_pixels, image_feature_shape_semantics = [], []
|
| 164 |
+
if isinstance(images, list): # anyres setting
|
| 165 |
+
h_pixels = []
|
| 166 |
+
for image in images:
|
| 167 |
+
if image.ndim == 3:
|
| 168 |
+
image = image.unsqueeze(0)
|
| 169 |
+
h_pixel = self.pixel_encoder(image)
|
| 170 |
+
b, c, h, w = h_pixel.shape
|
| 171 |
+
image_feature_shape_pixels.append((h, w))
|
| 172 |
+
h_pixel = rearrange(h_pixel, 'b c h w -> b (h w) c')
|
| 173 |
+
h_pixels.append(h_pixel)
|
| 174 |
+
h_pixels = torch.cat(h_pixels, dim=1)
|
| 175 |
+
|
| 176 |
+
h_semantics = []
|
| 177 |
+
for image in images:
|
| 178 |
+
if image.ndim == 3:
|
| 179 |
+
image = image.unsqueeze(0)
|
| 180 |
+
image = image.unsqueeze(dim=1)
|
| 181 |
+
scale_output = self.scaling_layer(image.clone())
|
| 182 |
+
image_2, image_grid_thw = scale_output['image'], scale_output['image_grid_thw']
|
| 183 |
+
image_feature_shape_semantics.append((int(image_grid_thw[0][1]) // 2, int(image_grid_thw[0][2] // 2)))
|
| 184 |
+
h_semantic = self.semantic_encoder(image_2, image_grid_thw)
|
| 185 |
+
h_semantics.append(h_semantic)
|
| 186 |
+
h_semantics = torch.cat(h_semantics, dim=0)
|
| 187 |
+
h_semantics = h_semantics.unsqueeze(dim=0)
|
| 188 |
+
|
| 189 |
+
image_feature_shapes = [[shape_semantic, shape_pixel] for shape_semantic, shape_pixel in
|
| 190 |
+
zip(image_feature_shape_semantics, image_feature_shape_pixels)]
|
| 191 |
+
|
| 192 |
+
else: # fixed res setting
|
| 193 |
+
assert images.ndim == 4
|
| 194 |
+
h_pixels = self.pixel_encoder(images)
|
| 195 |
+
b, c, h, w = h_pixels.shape
|
| 196 |
+
h_pixels = rearrange(h_pixels, 'b c h w -> (b h w) c')
|
| 197 |
+
h_pixels = h_pixels.unsqueeze(dim=0)
|
| 198 |
+
|
| 199 |
+
images = images.unsqueeze(dim=1)
|
| 200 |
+
scale_output = self.scaling_layer(images.clone())
|
| 201 |
+
images_2, images_grid_thw = scale_output['image'], scale_output['image_grid_thw']
|
| 202 |
+
|
| 203 |
+
h_semantics = self.semantic_encoder(images_2, images_grid_thw)
|
| 204 |
+
h_semantics = h_semantics.unsqueeze(dim=0)
|
| 205 |
+
|
| 206 |
+
shape_semantic = (int(images_grid_thw[0][1]) // 2, int(images_grid_thw[0][2] // 2))
|
| 207 |
+
shape_pixel = (h, w)
|
| 208 |
+
image_feature_shapes = [[shape_semantic, shape_pixel] for i in range(b)]
|
| 209 |
+
|
| 210 |
+
return [h_semantics, h_pixels], image_feature_shapes
|
| 211 |
+
|
| 212 |
+
@property
|
| 213 |
+
def dtype(self):
|
| 214 |
+
return self.semantic_encoder.dtype
|
| 215 |
+
|
| 216 |
+
@property
|
| 217 |
+
def device(self):
|
| 218 |
+
return self.semantic_encoder.device
|
| 219 |
+
|
| 220 |
+
@property
|
| 221 |
+
def config(self):
|
| 222 |
+
return self._config
|
| 223 |
+
|
| 224 |
+
@property
|
| 225 |
+
def hidden_size(self):
|
| 226 |
+
return self.config.hidden_size
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ILLUMEPreTrainedModel(PreTrainedModel):
|
| 230 |
+
config_class = ILLUMEConfig
|
| 231 |
+
base_model_prefix = "model"
|
| 232 |
+
supports_gradient_checkpointing = True
|
| 233 |
+
_no_split_modules = ["ILLUMEVisionAttention"]
|
| 234 |
+
_skip_keys_device_placement = "past_key_values"
|
| 235 |
+
_supports_flash_attn_2 = True
|
| 236 |
+
_supports_cache_class = True
|
| 237 |
+
|
| 238 |
+
def _init_weights(self, module):
|
| 239 |
+
std = (
|
| 240 |
+
self.config.initializer_range
|
| 241 |
+
if hasattr(self.config, "initializer_range")
|
| 242 |
+
else self.config.text_config.initializer_range
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
if hasattr(module, "class_embedding"):
|
| 246 |
+
module.class_embedding.data.normal_(mean=0.0, std=std)
|
| 247 |
+
|
| 248 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 249 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 250 |
+
if module.bias is not None:
|
| 251 |
+
module.bias.data.zero_()
|
| 252 |
+
elif isinstance(module, nn.Embedding):
|
| 253 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 254 |
+
if module.padding_idx is not None:
|
| 255 |
+
module.weight.data[module.padding_idx].zero_()
|
| 256 |
+
|
| 257 |
+
@property
|
| 258 |
+
def _supports_sdpa(self):
|
| 259 |
+
"""
|
| 260 |
+
Retrieve language_model's attribute to check whether the model supports
|
| 261 |
+
SDPA or not.
|
| 262 |
+
"""
|
| 263 |
+
return self.language_model._supports_sdpa
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class ILLUMEForConditionalGeneration(ILLUMEPreTrainedModel):
|
| 267 |
+
def __init__(self, config: ILLUMEConfig, **kwargs):
|
| 268 |
+
super().__init__(config)
|
| 269 |
+
self.vision_tower = ILLUMEDualVisionTower(config.vision_config)
|
| 270 |
+
self.mm_projector = ILLUMEMultiModalProjector(config)
|
| 271 |
+
|
| 272 |
+
self.vocab_size = config.text_config.vocab_size
|
| 273 |
+
self.language_model = AutoModelForCausalLM.from_config(
|
| 274 |
+
config.text_config, attn_implementation=config._attn_implementation
|
| 275 |
+
)
|
| 276 |
+
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
| 277 |
+
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
|
| 278 |
+
self.post_init()
|
| 279 |
+
|
| 280 |
+
@property
|
| 281 |
+
def padding_side(self):
|
| 282 |
+
return self._padding_side
|
| 283 |
+
|
| 284 |
+
@padding_side.setter
|
| 285 |
+
def padding_side(self, padding_side: str):
|
| 286 |
+
if padding_side not in ["left", "right"]:
|
| 287 |
+
raise ValueError(f"{padding_side} is not `left` or `right`.")
|
| 288 |
+
self._padding_side = padding_side
|
| 289 |
+
|
| 290 |
+
def get_input_embeddings(self):
|
| 291 |
+
return self.language_model.get_input_embeddings()
|
| 292 |
+
|
| 293 |
+
def set_input_embeddings(self, value):
|
| 294 |
+
self.language_model.set_input_embeddings(value)
|
| 295 |
+
|
| 296 |
+
def get_output_embeddings(self):
|
| 297 |
+
return self.language_model.get_output_embeddings()
|
| 298 |
+
|
| 299 |
+
def set_output_embeddings(self, new_embeddings):
|
| 300 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
| 301 |
+
|
| 302 |
+
def set_decoder(self, decoder):
|
| 303 |
+
self.language_model.set_decoder(decoder)
|
| 304 |
+
|
| 305 |
+
def get_decoder(self):
|
| 306 |
+
return self.language_model.get_decoder()
|
| 307 |
+
|
| 308 |
+
def tie_weights(self):
|
| 309 |
+
return self.language_model.tie_weights()
|
| 310 |
+
|
| 311 |
+
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
| 312 |
+
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
| 313 |
+
# update vocab size
|
| 314 |
+
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
| 315 |
+
self.vocab_size = model_embeds.num_embeddings
|
| 316 |
+
return model_embeds
|
| 317 |
+
|
| 318 |
+
def _add_eol(self, x_feat, eol_feature):
|
| 319 |
+
h, w, C = x_feat.shape
|
| 320 |
+
eol_feature = eol_feature.unsqueeze(0).unsqueeze(0).expand(h, 1, C)
|
| 321 |
+
x_feat = torch.cat([x_feat, eol_feature], dim=1)
|
| 322 |
+
x_feat = x_feat.view(-1, C)
|
| 323 |
+
return x_feat
|
| 324 |
+
|
| 325 |
+
def _reformat_image_sequence(self, x, special_tokens_features, level):
|
| 326 |
+
# add end_of_line
|
| 327 |
+
x = self._add_eol(x, special_tokens_features[2])
|
| 328 |
+
|
| 329 |
+
# wrap with <start_of_level{level}> and <end_of_level{level}> tokens
|
| 330 |
+
x = torch.cat([
|
| 331 |
+
special_tokens_features[3 + level * 2].unsqueeze(0),
|
| 332 |
+
x,
|
| 333 |
+
special_tokens_features[3 + level * 2 + 1].unsqueeze(0),
|
| 334 |
+
], dim=0)
|
| 335 |
+
return x
|
| 336 |
+
|
| 337 |
+
def _merge_input_ids_with_image_features(
|
| 338 |
+
self,
|
| 339 |
+
image_features,
|
| 340 |
+
feature_lens,
|
| 341 |
+
inputs_embeds,
|
| 342 |
+
input_ids,
|
| 343 |
+
attention_mask,
|
| 344 |
+
position_ids=None,
|
| 345 |
+
labels=None,
|
| 346 |
+
image_token_index=None,
|
| 347 |
+
ignore_index=-100,
|
| 348 |
+
):
|
| 349 |
+
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
|
| 350 |
+
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
|
| 351 |
+
|
| 352 |
+
with torch.no_grad():
|
| 353 |
+
num_images = feature_lens.size(0)
|
| 354 |
+
num_image_features, embed_dim = image_features.shape
|
| 355 |
+
if feature_lens.sum() != num_image_features:
|
| 356 |
+
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
|
| 357 |
+
batch_size = input_ids.shape[0]
|
| 358 |
+
_left_padding = torch.any(attention_mask[:, 0] == 0)
|
| 359 |
+
_right_padding = torch.any(attention_mask[:, -1] == 0)
|
| 360 |
+
|
| 361 |
+
left_padding = not self.training
|
| 362 |
+
if batch_size > 1 and not self.training:
|
| 363 |
+
if _left_padding and not _right_padding:
|
| 364 |
+
left_padding = True
|
| 365 |
+
elif not _left_padding and _right_padding:
|
| 366 |
+
left_padding = False
|
| 367 |
+
elif not _left_padding and not _right_padding:
|
| 368 |
+
# both sides are 1, so we cannot tell
|
| 369 |
+
left_padding = self.padding_side == "left"
|
| 370 |
+
else:
|
| 371 |
+
# invalid attention_mask
|
| 372 |
+
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
|
| 373 |
+
|
| 374 |
+
# Whether to turn off right padding
|
| 375 |
+
# 1. Create a mask to know where special image tokens are
|
| 376 |
+
special_image_token_mask = input_ids == image_token_index
|
| 377 |
+
# special_image_token_mask: [bsz, seqlen]
|
| 378 |
+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
| 379 |
+
# num_special_image_tokens: [bsz]
|
| 380 |
+
# Reserve for padding of num_images
|
| 381 |
+
total_num_special_image_tokens = torch.sum(special_image_token_mask)
|
| 382 |
+
if total_num_special_image_tokens != num_images:
|
| 383 |
+
raise ValueError(
|
| 384 |
+
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
|
| 385 |
+
)
|
| 386 |
+
# Compute the maximum embed dimension
|
| 387 |
+
# max_image_feature_lens is max_feature_lens per batch
|
| 388 |
+
feature_lens = feature_lens.to(input_ids.device)
|
| 389 |
+
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
|
| 390 |
+
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
|
| 391 |
+
embed_sequence_lengths = (
|
| 392 |
+
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
|
| 393 |
+
)
|
| 394 |
+
max_embed_dim = embed_sequence_lengths.max()
|
| 395 |
+
|
| 396 |
+
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
|
| 397 |
+
# 2. Compute the positions where text should be written
|
| 398 |
+
# Calculate new positions for text tokens in merged image-text sequence.
|
| 399 |
+
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
|
| 400 |
+
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
| 401 |
+
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
| 402 |
+
# ! instead of special_image_token_mask * (num_image_patches - 1)
|
| 403 |
+
# special_image_token_mask * (num_feature_len - 1)
|
| 404 |
+
special_image_token_mask = special_image_token_mask.long()
|
| 405 |
+
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
|
| 406 |
+
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
|
| 407 |
+
if left_padding:
|
| 408 |
+
# shift right token positions so that they are ending at the same number
|
| 409 |
+
# note: the earlier formula new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] was incorrect
|
| 410 |
+
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
|
| 411 |
+
|
| 412 |
+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
| 413 |
+
|
| 414 |
+
# 3. Create the full embedding, already padded to the maximum position
|
| 415 |
+
final_embedding = torch.zeros(
|
| 416 |
+
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
| 417 |
+
)
|
| 418 |
+
final_attention_mask = torch.zeros(
|
| 419 |
+
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
| 420 |
+
)
|
| 421 |
+
final_input_ids = torch.full(
|
| 422 |
+
(batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
|
| 423 |
+
)
|
| 424 |
+
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
| 425 |
+
# set the corresponding tensors into their correct target device.
|
| 426 |
+
target_device = inputs_embeds.device
|
| 427 |
+
batch_indices, non_image_indices, text_to_overwrite = (
|
| 428 |
+
batch_indices.to(target_device),
|
| 429 |
+
non_image_indices.to(target_device),
|
| 430 |
+
text_to_overwrite.to(target_device),
|
| 431 |
+
)
|
| 432 |
+
attention_mask = attention_mask.to(target_device)
|
| 433 |
+
input_ids = input_ids.to(target_device)
|
| 434 |
+
|
| 435 |
+
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
| 436 |
+
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
| 437 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
| 438 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
| 439 |
+
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
|
| 440 |
+
final_labels = None
|
| 441 |
+
if labels is not None:
|
| 442 |
+
labels = labels.to(target_device)
|
| 443 |
+
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
|
| 444 |
+
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
| 445 |
+
|
| 446 |
+
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
| 447 |
+
with torch.no_grad():
|
| 448 |
+
image_to_overwrite = torch.full(
|
| 449 |
+
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
| 450 |
+
)
|
| 451 |
+
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
| 452 |
+
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
|
| 453 |
+
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
|
| 454 |
+
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
|
| 455 |
+
|
| 456 |
+
if left_padding:
|
| 457 |
+
# exclude padding on the left
|
| 458 |
+
max_embed_dim = max_embed_dim.to(target_device)
|
| 459 |
+
val = (max_embed_dim - embed_indices) <= embed_seq_lens
|
| 460 |
+
else:
|
| 461 |
+
# exclude padding on the right
|
| 462 |
+
val = embed_indices < embed_seq_lens
|
| 463 |
+
image_to_overwrite &= val
|
| 464 |
+
|
| 465 |
+
if image_to_overwrite.sum() != num_image_features:
|
| 466 |
+
raise ValueError(
|
| 467 |
+
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
|
| 468 |
+
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
| 469 |
+
f" the number of image given to the model is {num_images}. "
|
| 470 |
+
f"This prevents correct indexing and breaks batch generation."
|
| 471 |
+
)
|
| 472 |
+
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
| 473 |
+
final_attention_mask |= image_to_overwrite
|
| 474 |
+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
| 475 |
+
|
| 476 |
+
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
| 477 |
+
|
| 478 |
+
def forward(
|
| 479 |
+
self,
|
| 480 |
+
input_ids: torch.LongTensor = None,
|
| 481 |
+
pixel_values: torch.FloatTensor = None,
|
| 482 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 483 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 484 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 485 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 486 |
+
labels: Optional[torch.LongTensor] = None,
|
| 487 |
+
use_cache: Optional[bool] = None,
|
| 488 |
+
output_attentions: Optional[bool] = None,
|
| 489 |
+
output_hidden_states: Optional[bool] = None,
|
| 490 |
+
return_dict: Optional[bool] = None,
|
| 491 |
+
) -> Union[Tuple, ILLUMECausalLMOutputWithPast]:
|
| 492 |
+
r"""
|
| 493 |
+
Args:
|
| 494 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 495 |
+
Indices of input sequence tokens in the vocabulary.
|
| 496 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 497 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 498 |
+
|
| 499 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 500 |
+
Pixel values of the image to be generated.
|
| 501 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 502 |
+
Mask to avoid performing attention on padding token indices.
|
| 503 |
+
Mask values selected in `[0, 1]`:
|
| 504 |
+
- 1 for tokens that are not masked,
|
| 505 |
+
- 0 for tokens that are masked.
|
| 506 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 507 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
| 508 |
+
Selected in the range `[0, config.max_position_embeddings - 1]`.
|
| 509 |
+
|
| 510 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 511 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 512 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 513 |
+
use_cache (`bool`, *optional*):
|
| 514 |
+
If `True`, past_key_values are returned and can be used to speed up decoding.
|
| 515 |
+
output_attentions (`bool`, *optional*):
|
| 516 |
+
Whether to return the attentions tensors of all attention layers.
|
| 517 |
+
output_hidden_states (`bool`, *optional*):
|
| 518 |
+
Whether to return the hidden states of all layers.
|
| 519 |
+
return_dict (`bool`, *optional*):
|
| 520 |
+
Whether to return a dictionary of outputs instead of a plain tuple.
|
| 521 |
+
|
| 522 |
+
Returns:
|
| 523 |
+
An instance of [`ILLUMECausalLMOutputWithPast`] if `return_dict=True`. Otherwise, it returns a tuple of tensors.
|
| 524 |
+
|
| 525 |
+
Example:
|
| 526 |
+
|
| 527 |
+
```python
|
| 528 |
+
>>> import torch
|
| 529 |
+
>>> from PIL import Image
|
| 530 |
+
>>> import requests
|
| 531 |
+
>>> from transformers import AutoProcessor, ILLUMEForConditionalGeneration
|
| 532 |
+
|
| 533 |
+
>>> # Load model and processor
|
| 534 |
+
>>> # Specify `torch_dtype` for mixed-precision inference, e.g., torch.bfloat16 or torch.float16
|
| 535 |
+
>>> # Specify `attn_implementation="flash_attention_2"` if Flash Attention 2 is installed and supported, or "sdpa" for PyTorch SDPA.
|
| 536 |
+
>>> # `low_cpu_mem_usage=True` can help reduce CPU memory for large models.
|
| 537 |
+
>>> model = ILLUMEForConditionalGeneration.from_pretrained(
|
| 538 |
+
... "illume-unified-mllm/illume_plus-qwen2_5-3b-hf",
|
| 539 |
+
... torch_dtype=torch.bfloat16, # Optional: Or torch.float16. Adjust based on your hardware.
|
| 540 |
+
... low_cpu_mem_usage=True, # Optional: Reduces CPU RAM during model loading.
|
| 541 |
+
... attn_implementation="sdpa", # Optional: Use "flash_attention_2" if available for better performance.
|
| 542 |
+
... trust_remote_code=True
|
| 543 |
+
... ).eval()
|
| 544 |
+
>>> # To use GPU: model = model.to("cuda") # Ensure the model is on the correct device
|
| 545 |
+
|
| 546 |
+
>>> processor = AutoProcessor.from_pretrained(
|
| 547 |
+
... "illume-unified-mllm/illume_plus-qwen2_5-3b-hf",
|
| 548 |
+
... trust_remote_code=True
|
| 549 |
+
... )
|
| 550 |
+
|
| 551 |
+
>>> # Prepare inputs: a text prompt and an image
|
| 552 |
+
>>> # The processor formats the input for the model, including applying the chat template.
|
| 553 |
+
>>> messages = [
|
| 554 |
+
... {"role": "user", "content": [
|
| 555 |
+
... {"type": "image"},
|
| 556 |
+
... {"type": "text", "text": "What is shown in this image?"}
|
| 557 |
+
... ]}
|
| 558 |
+
... ]
|
| 559 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" # An example image URL
|
| 560 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 561 |
+
|
| 562 |
+
>>> inputs = processor(text=messages, images=[image], return_tensors="pt")
|
| 563 |
+
>>> # To use GPU: inputs = {k: v.to("cuda") for k, v in inputs.items()} # Move inputs to the same device as the model
|
| 564 |
+
|
| 565 |
+
>>> # Generate text based on the input
|
| 566 |
+
>>> gen_kwargs = {"max_new_tokens": 100, "do_sample": False} # Generation parameters
|
| 567 |
+
>>> with torch.no_grad(): # Disable gradient calculations for inference
|
| 568 |
+
... outputs = model.generate(**inputs, **gen_kwargs)
|
| 569 |
+
|
| 570 |
+
>>> # Decode the generated tokens, removing the prompt
|
| 571 |
+
>>> input_token_len = inputs["input_ids"].shape[1]
|
| 572 |
+
>>> generated_ids = outputs[:, input_token_len:]
|
| 573 |
+
>>> response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 574 |
+
|
| 575 |
+
>>> print(response)
|
| 576 |
+
```"""
|
| 577 |
+
|
| 578 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 579 |
+
output_hidden_states = (
|
| 580 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 581 |
+
)
|
| 582 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 583 |
+
|
| 584 |
+
if inputs_embeds is None:
|
| 585 |
+
# 1. Extract the input embeddings
|
| 586 |
+
# In case image_token_index is not in the embeddings (extra token but embedding don't have it)
|
| 587 |
+
for_inputs_embeds_ids = input_ids.clone()
|
| 588 |
+
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
|
| 589 |
+
inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
|
| 590 |
+
|
| 591 |
+
# 2. Merge text and images
|
| 592 |
+
if pixel_values is not None and input_ids.shape[1] != 1 and len(pixel_values) > 0:
|
| 593 |
+
image_features, image_feature_shapes = self.vision_tower(pixel_values)
|
| 594 |
+
image_features = self.mm_projector(image_features)
|
| 595 |
+
|
| 596 |
+
# reformat image sequence
|
| 597 |
+
# special_tokens_ids: <start_of_image>, <end_of_image>, <end_of_line>, <start_of_level0>, <end_of_level0>, <start_of_level1>, <end_of_level1>
|
| 598 |
+
special_tokens_ids = torch.Tensor([self.config.special_tokens_ids[key] for key in
|
| 599 |
+
["<start_of_image>", "<end_of_image>", "<end_of_line>",
|
| 600 |
+
"<start_of_level0>",
|
| 601 |
+
"<end_of_level0>", "<start_of_level1>",
|
| 602 |
+
"<end_of_level1>"]]).long().to(image_features.device)
|
| 603 |
+
|
| 604 |
+
semantic_sizes = [h * w for (h, w), _ in image_feature_shapes]
|
| 605 |
+
pixel_sizes = [h * w for _, (h, w) in image_feature_shapes]
|
| 606 |
+
|
| 607 |
+
# Split the image features into semantic features and reshape them to (h, w, -1)
|
| 608 |
+
semantic_features = torch.split(image_features[:, :sum(semantic_sizes), :], semantic_sizes, dim=1)
|
| 609 |
+
h_semantics = [feat.view(h, w, -1) for feat, ((h, w), _) in
|
| 610 |
+
zip(semantic_features, image_feature_shapes)]
|
| 611 |
+
|
| 612 |
+
# English: Split the image features into pixel features and reshape them to (h, w, -1)
|
| 613 |
+
det_features = torch.split(
|
| 614 |
+
image_features[:, sum(semantic_sizes): sum(semantic_sizes) + sum(pixel_sizes), :], pixel_sizes,
|
| 615 |
+
dim=1)
|
| 616 |
+
h_pixels = [feat.view(h, w, -1) for feat, (_, (h, w)) in zip(det_features, image_feature_shapes)]
|
| 617 |
+
|
| 618 |
+
special_tokens_features = self.language_model.model.embed_tokens(special_tokens_ids)
|
| 619 |
+
|
| 620 |
+
image_features = []
|
| 621 |
+
feature_lens = []
|
| 622 |
+
for h_semantic, h_pixel in zip(h_semantics, h_pixels):
|
| 623 |
+
h_semantic = self._reformat_image_sequence(h_semantic, special_tokens_features.clone(), level=0)
|
| 624 |
+
h_pixel = self._reformat_image_sequence(h_pixel, special_tokens_features.clone(), level=1)
|
| 625 |
+
|
| 626 |
+
image_feature = torch.cat([special_tokens_features[0].unsqueeze(0), h_semantic, h_pixel,
|
| 627 |
+
special_tokens_features[1].unsqueeze(0)], dim=0)
|
| 628 |
+
image_features.append(image_feature)
|
| 629 |
+
feature_lens.append(image_feature.shape[0])
|
| 630 |
+
|
| 631 |
+
feature_lens = torch.as_tensor(feature_lens)
|
| 632 |
+
image_features = torch.cat(image_features, dim=0)
|
| 633 |
+
|
| 634 |
+
inputs_embeds = inputs_embeds.to(self.dtype)
|
| 635 |
+
inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
|
| 636 |
+
image_features,
|
| 637 |
+
feature_lens,
|
| 638 |
+
inputs_embeds,
|
| 639 |
+
input_ids,
|
| 640 |
+
attention_mask,
|
| 641 |
+
position_ids,
|
| 642 |
+
labels=labels,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
# pixel_values is not None but is empty ---> text only cases
|
| 646 |
+
elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
|
| 647 |
+
# there are no images
|
| 648 |
+
pass
|
| 649 |
+
|
| 650 |
+
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
| 651 |
+
# generation with cache
|
| 652 |
+
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
| 653 |
+
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
| 654 |
+
# that are set to 0
|
| 655 |
+
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
| 656 |
+
|
| 657 |
+
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
| 658 |
+
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
| 659 |
+
|
| 660 |
+
# Get the target length
|
| 661 |
+
target_length = input_ids.shape[1]
|
| 662 |
+
past_length = first_layer_past_key_value.shape[-1]
|
| 663 |
+
|
| 664 |
+
extended_attention_mask = torch.ones(
|
| 665 |
+
(attention_mask.shape[0], past_length),
|
| 666 |
+
dtype=attention_mask.dtype,
|
| 667 |
+
device=attention_mask.device,
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
# Filter out only the tokens that can be un-attended, this can happen
|
| 671 |
+
# if one uses ILLUME + Fused modules where the cache on the
|
| 672 |
+
# first iteration is already big enough, or if one passes custom cache
|
| 673 |
+
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
| 674 |
+
new_batch_index = batch_index[valid_indices]
|
| 675 |
+
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
| 676 |
+
|
| 677 |
+
# Zero-out the places where we don't need to attend
|
| 678 |
+
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
| 679 |
+
|
| 680 |
+
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
| 681 |
+
|
| 682 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
| 683 |
+
|
| 684 |
+
outputs = self.language_model(
|
| 685 |
+
attention_mask=attention_mask.to(inputs_embeds.device) if attention_mask is not None else attention_mask,
|
| 686 |
+
position_ids=position_ids,
|
| 687 |
+
past_key_values=past_key_values,
|
| 688 |
+
inputs_embeds=inputs_embeds,
|
| 689 |
+
use_cache=use_cache,
|
| 690 |
+
output_attentions=output_attentions,
|
| 691 |
+
output_hidden_states=output_hidden_states,
|
| 692 |
+
return_dict=return_dict,
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
logits = outputs[0]
|
| 696 |
+
|
| 697 |
+
loss = None
|
| 698 |
+
if labels is not None:
|
| 699 |
+
# Shift so that tokens < n predict n
|
| 700 |
+
if attention_mask is not None:
|
| 701 |
+
shift_attention_mask = attention_mask[..., 1:]
|
| 702 |
+
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
| 703 |
+
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
| 704 |
+
else:
|
| 705 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 706 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 707 |
+
# Flatten the tokens
|
| 708 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 709 |
+
loss = loss_fct(
|
| 710 |
+
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
if not return_dict:
|
| 714 |
+
output = (logits,) + outputs[1:]
|
| 715 |
+
return (loss,) + output if loss is not None else output
|
| 716 |
+
|
| 717 |
+
return ILLUMECausalLMOutputWithPast(
|
| 718 |
+
loss=loss,
|
| 719 |
+
logits=logits,
|
| 720 |
+
past_key_values=outputs.past_key_values,
|
| 721 |
+
hidden_states=outputs.hidden_states,
|
| 722 |
+
attentions=outputs.attentions,
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
def prepare_inputs_for_generation(
|
| 726 |
+
self,
|
| 727 |
+
input_ids,
|
| 728 |
+
past_key_values=None,
|
| 729 |
+
inputs_embeds=None,
|
| 730 |
+
pixel_values=None,
|
| 731 |
+
attention_mask=None,
|
| 732 |
+
**kwargs,
|
| 733 |
+
):
|
| 734 |
+
if past_key_values is not None:
|
| 735 |
+
if isinstance(past_key_values, Cache):
|
| 736 |
+
cache_length = past_key_values.get_seq_length()
|
| 737 |
+
past_length = past_key_values.seen_tokens
|
| 738 |
+
else:
|
| 739 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 740 |
+
|
| 741 |
+
# Keep only the unprocessed tokens:
|
| 742 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 743 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
| 744 |
+
# input)
|
| 745 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 746 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
|
| 747 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 748 |
+
# input_ids based on the past_length.
|
| 749 |
+
elif past_length < input_ids.shape[1]:
|
| 750 |
+
input_ids = input_ids[:, past_length:]
|
| 751 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 752 |
+
elif self.config.image_token_index in input_ids:
|
| 753 |
+
input_ids = input_ids[:, input_ids.shape[1] - 1:]
|
| 754 |
+
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
| 755 |
+
# older attention values, as their corresponding values are not part of the input.
|
| 756 |
+
if cache_length < past_length and attention_mask is not None:
|
| 757 |
+
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]):]
|
| 758 |
+
|
| 759 |
+
position_ids = kwargs.get("position_ids", None)
|
| 760 |
+
if attention_mask is not None and position_ids is None:
|
| 761 |
+
# create position_ids on the fly for batch generation
|
| 762 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 763 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 764 |
+
if past_key_values:
|
| 765 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
| 766 |
+
|
| 767 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 768 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 769 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 770 |
+
else:
|
| 771 |
+
model_inputs = {"input_ids": input_ids}
|
| 772 |
+
|
| 773 |
+
model_inputs.update(
|
| 774 |
+
{
|
| 775 |
+
"position_ids": position_ids,
|
| 776 |
+
"past_key_values": past_key_values,
|
| 777 |
+
"use_cache": kwargs.get("use_cache"),
|
| 778 |
+
"attention_mask": attention_mask,
|
| 779 |
+
"pixel_values": pixel_values,
|
| 780 |
+
}
|
| 781 |
+
)
|
| 782 |
+
return model_inputs
|
| 783 |
+
|
| 784 |
+
def _reorder_cache(self, *args, **kwargs):
|
| 785 |
+
return self.language_model._reorder_cache(*args, **kwargs)
|
| 786 |
+
|
| 787 |
+
def prepare_logit_processor(self,
|
| 788 |
+
guidance_scale=2.0,
|
| 789 |
+
negative_prompt_ids=None,
|
| 790 |
+
negative_prompt_attention_mask=None,
|
| 791 |
+
resolution=None,
|
| 792 |
+
temperature=1.0, image_semantic_temperature=1.0, image_pixel_temperature=None,
|
| 793 |
+
top_k=128, image_semantic_top_k=2048, image_pixel_top_k=None,
|
| 794 |
+
top_p=1.0, image_semantic_top_p=1.0, image_pixel_top_p=None,
|
| 795 |
+
images=None,
|
| 796 |
+
):
|
| 797 |
+
if resolution is not None:
|
| 798 |
+
token_nums, _, h1, w1, h2, w2 = calculate_image_token_num(*resolution)
|
| 799 |
+
else:
|
| 800 |
+
h1, w1, h2, w2 = 0, 0, 0, 0
|
| 801 |
+
|
| 802 |
+
if image_pixel_temperature is None:
|
| 803 |
+
image_pixel_temperature = image_semantic_temperature
|
| 804 |
+
if image_pixel_top_k is None:
|
| 805 |
+
image_pixel_top_k = image_semantic_top_k * 3
|
| 806 |
+
if image_pixel_top_p is None:
|
| 807 |
+
image_pixel_top_p = image_semantic_top_p
|
| 808 |
+
|
| 809 |
+
return InterleavedLogitsProcessor(
|
| 810 |
+
guidance_scale=guidance_scale,
|
| 811 |
+
uncond=negative_prompt_ids,
|
| 812 |
+
attention_mask=negative_prompt_attention_mask,
|
| 813 |
+
model=self,
|
| 814 |
+
# DualVQ parameters
|
| 815 |
+
level0_range=level0_range,
|
| 816 |
+
level1_range=level1_range,
|
| 817 |
+
num_level0_rows=h1, num_level0_tokens=w1,
|
| 818 |
+
num_level1_rows=h2, num_level1_tokens=w2,
|
| 819 |
+
special_tokens=special_tokens_dict,
|
| 820 |
+
# Dynamic Sampling parameters
|
| 821 |
+
default_temp=temperature, level0_temp=image_semantic_temperature, level1_temp=image_pixel_temperature,
|
| 822 |
+
default_top_k=top_k, level0_top_k=image_semantic_top_k, level1_top_k=image_pixel_top_k,
|
| 823 |
+
default_top_p=top_p, level0_top_p=image_semantic_top_p, level1_top_p=image_pixel_top_p,
|
| 824 |
+
images=images
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
def generate(
|
| 828 |
+
self,
|
| 829 |
+
*args,
|
| 830 |
+
temperature: float = 1.0, top_k: int = 128, top_p: float = 1.0,
|
| 831 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 832 |
+
# image generation or image editing hyperparameters.
|
| 833 |
+
guidance_scale=1.0, target_image_resolution=None,
|
| 834 |
+
image_semantic_temperature: float = 1.0,
|
| 835 |
+
image_semantic_top_k: int = 2048,
|
| 836 |
+
image_semantic_top_p: float = 1.0,
|
| 837 |
+
image_pixel_temperature: float = 1.0,
|
| 838 |
+
image_pixel_top_k: int = 2048 * 3,
|
| 839 |
+
image_pixel_top_p: float = 1.0,
|
| 840 |
+
negative_image_prompt_ids: Optional[torch.Tensor] = None,
|
| 841 |
+
negative_image_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 842 |
+
disable_logit_processor=False,
|
| 843 |
+
logits_processor=None,
|
| 844 |
+
**kwargs,
|
| 845 |
+
):
|
| 846 |
+
if target_image_resolution is not None:
|
| 847 |
+
# check whether target_image_resolution is valid.
|
| 848 |
+
if not isinstance(target_image_resolution, tuple) or len(target_image_resolution) != 2:
|
| 849 |
+
raise ValueError("target_image_resolution must be a tuple of two integers.")
|
| 850 |
+
if not all(isinstance(dim, int) and dim > 0 for dim in target_image_resolution):
|
| 851 |
+
raise ValueError("target_image_resolution must contain positive integers.")
|
| 852 |
+
|
| 853 |
+
if target_image_resolution not in DEFAULT_RESOLUTIONS:
|
| 854 |
+
raise ValueError(
|
| 855 |
+
"target_image_resolution must be in one of the following ratios: " + str(DEFAULT_RESOLUTIONS))
|
| 856 |
+
|
| 857 |
+
if logits_processor is None:
|
| 858 |
+
logits_processor = LogitsProcessorList([])
|
| 859 |
+
|
| 860 |
+
if not disable_logit_processor:
|
| 861 |
+
illume_logit_processor = self.prepare_logit_processor(
|
| 862 |
+
negative_prompt_ids=negative_image_prompt_ids,
|
| 863 |
+
negative_prompt_attention_mask=negative_image_prompt_attention_mask,
|
| 864 |
+
temperature=temperature,
|
| 865 |
+
top_k=top_k,
|
| 866 |
+
top_p=top_p,
|
| 867 |
+
guidance_scale=guidance_scale,
|
| 868 |
+
resolution=target_image_resolution,
|
| 869 |
+
image_semantic_temperature=image_semantic_temperature,
|
| 870 |
+
image_pixel_temperature=image_pixel_temperature,
|
| 871 |
+
image_semantic_top_k=image_semantic_top_k,
|
| 872 |
+
image_pixel_top_k=image_pixel_top_k,
|
| 873 |
+
image_semantic_top_p=image_semantic_top_p,
|
| 874 |
+
image_pixel_top_p=image_pixel_top_p,
|
| 875 |
+
images=pixel_values,
|
| 876 |
+
)
|
| 877 |
+
logits_processor.append(illume_logit_processor)
|
| 878 |
+
|
| 879 |
+
return super(ILLUMEForConditionalGeneration, self).generate(
|
| 880 |
+
*args,
|
| 881 |
+
pixel_values=pixel_values,
|
| 882 |
+
logits_processor=logits_processor,
|
| 883 |
+
**kwargs)
|
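A hedged text-to-image sketch for the generate override above, assuming model and processor are loaded as in the forward docstring example; the prompt wording and max_new_tokens budget are illustrative, and guidance_scale values above 1 would additionally require negative_image_prompt_ids / negative_image_prompt_attention_mask.

import torch

messages = [{"role": "user", "content": [{"type": "text", "text": "Generate an image of a red bicycle."}]}]
inputs = processor(text=messages, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=6144,                 # generous budget for one image; exact count via calculate_image_token_num
        guidance_scale=1.0,                  # > 1.0 enables classifier-free guidance (needs a negative prompt)
        target_image_resolution=(512, 512),  # must be one of DEFAULT_RESOLUTIONS above
        image_semantic_top_k=2048,           # level-0 (semantic) sampling
        image_pixel_top_k=2048 * 3,          # level-1 (pixel) sampling
    )

The output interleaves text with <start_of_image> ... <end_of_image> blocks; inference_utils.parse_interleaved_text_image (imported above) is the intended helper for splitting the result back into text and image-token segments, which the DualViTok decoder in modeling_dualvitok.py turns into pixels.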
modeling_movqgan.py
ADDED
|
@@ -0,0 +1,828 @@
| 1 |
+
""" MoVQ model """
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from einops import rearrange, repeat
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.utils.checkpoint import checkpoint
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
|
| 13 |
+
from .configuration_movqgan import MoVQConfig
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import xformers.ops as xops
|
| 17 |
+
|
| 18 |
+
is_xformers_available = True
|
| 19 |
+
except Exception as e:
|
| 20 |
+
is_xformers_available = False
|
| 21 |
+
|
| 22 |
+
if torch.__version__ > "2.1.2":
|
| 23 |
+
IS_SDPA_AVAILABLE = True
|
| 24 |
+
else:
|
| 25 |
+
IS_SDPA_AVAILABLE = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MoVQActivation(nn.Module):
|
| 29 |
+
|
| 30 |
+
def __init__(self):
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
def __call__(self, x: torch.Tensor):
|
| 34 |
+
return x * torch.sigmoid(x)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MoVQUpsample(nn.Module):
|
| 38 |
+
|
| 39 |
+
def __init__(self, in_channels: int):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.conv = nn.Conv2d(
|
| 42 |
+
in_channels,
|
| 43 |
+
in_channels,
|
| 44 |
+
kernel_size=3,
|
| 45 |
+
stride=1,
|
| 46 |
+
padding=1,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def forward(self, x: torch.Tensor):
|
| 50 |
+
x = F.interpolate(x.float(), scale_factor=2.0, mode="nearest").to(x.dtype)
|
| 51 |
+
x = self.conv(x)
|
| 52 |
+
return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class DCDownBlock2d(nn.Module):
|
| 56 |
+
def __init__(self, in_channels: int, out_channels: int = None, downsample: bool = True,
|
| 57 |
+
shortcut: bool = True) -> None:
|
| 58 |
+
super().__init__()
|
| 59 |
+
out_channels = out_channels if out_channels else in_channels
|
| 60 |
+
|
| 61 |
+
self.downsample = downsample
|
| 62 |
+
self.factor = 2
|
| 63 |
+
self.stride = 1 if downsample else 2
|
| 64 |
+
self.group_size = in_channels * self.factor ** 2 // out_channels
|
| 65 |
+
self.shortcut = shortcut
|
| 66 |
+
|
| 67 |
+
out_ratio = self.factor ** 2
|
| 68 |
+
if downsample:
|
| 69 |
+
assert out_channels % out_ratio == 0
|
| 70 |
+
out_channels = out_channels // out_ratio
|
| 71 |
+
|
| 72 |
+
self.conv = nn.Conv2d(
|
| 73 |
+
in_channels,
|
| 74 |
+
out_channels,
|
| 75 |
+
kernel_size=3,
|
| 76 |
+
stride=self.stride,
|
| 77 |
+
padding=1,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
x = self.conv(hidden_states)
|
| 82 |
+
if self.downsample:
|
| 83 |
+
x = F.pixel_unshuffle(x, self.factor)
|
| 84 |
+
|
| 85 |
+
if self.shortcut:
|
| 86 |
+
y = F.pixel_unshuffle(hidden_states, self.factor)
|
| 87 |
+
y = y.unflatten(1, (-1, self.group_size))
|
| 88 |
+
y = y.mean(dim=2)
|
| 89 |
+
hidden_states = x + y
|
| 90 |
+
else:
|
| 91 |
+
hidden_states = x
|
| 92 |
+
|
| 93 |
+
return hidden_states # x + y
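
As a quick illustration of the downsample path above (sizes here are invented, not taken from any config in this repo): the 3x3 conv keeps resolution but emits `out_channels // 4` maps, `pixel_unshuffle` folds each 2x2 spatial block into channels, and the shortcut averages channel groups of the unshuffled input so the two branches can be summed.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: in_channels=64, out_channels=128, factor=2 (downsample=True, shortcut=True).
x = torch.randn(1, 64, 32, 32)

conv = torch.nn.Conv2d(64, 128 // 4, kernel_size=3, stride=1, padding=1)
main = F.pixel_unshuffle(conv(x), 2)            # (1, 32, 32, 32) -> (1, 128, 16, 16)

group_size = 64 * 2 ** 2 // 128                 # = 2, mirrors self.group_size above
short = F.pixel_unshuffle(x, 2)                 # (1, 256, 16, 16)
short = short.unflatten(1, (-1, group_size))    # (1, 128, 2, 16, 16)
short = short.mean(dim=2)                       # (1, 128, 16, 16)

print((main + short).shape)                     # torch.Size([1, 128, 16, 16])
```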
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class DCUpBlock2d(nn.Module):
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
in_channels: int,
|
| 100 |
+
out_channels: int = None,
|
| 101 |
+
interpolate: bool = False,
|
| 102 |
+
shortcut: bool = True,
|
| 103 |
+
interpolation_mode: str = "nearest",
|
| 104 |
+
) -> None:
|
| 105 |
+
super().__init__()
|
| 106 |
+
out_channels = out_channels if out_channels else in_channels
|
| 107 |
+
|
| 108 |
+
self.interpolate = interpolate
|
| 109 |
+
self.interpolation_mode = interpolation_mode
|
| 110 |
+
self.shortcut = shortcut
|
| 111 |
+
self.factor = 2
|
| 112 |
+
self.repeats = out_channels * self.factor ** 2 // in_channels
|
| 113 |
+
|
| 114 |
+
out_ratio = self.factor ** 2
|
| 115 |
+
|
| 116 |
+
if not interpolate:
|
| 117 |
+
out_channels = out_channels * out_ratio
|
| 118 |
+
|
| 119 |
+
self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
| 120 |
+
|
| 121 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
if self.interpolate:
|
| 123 |
+
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
| 124 |
+
x = self.conv(x)
|
| 125 |
+
else:
|
| 126 |
+
x = self.conv(hidden_states)
|
| 127 |
+
x = F.pixel_shuffle(x, self.factor)
|
| 128 |
+
|
| 129 |
+
if self.shortcut:
|
| 130 |
+
y = hidden_states.repeat_interleave(self.repeats, dim=1)
|
| 131 |
+
y = F.pixel_shuffle(y, self.factor)
|
| 132 |
+
hidden_states = x + y
|
| 133 |
+
else:
|
| 134 |
+
hidden_states = x
|
| 135 |
+
|
| 136 |
+
return hidden_states
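
The upsample block is the mirror image; a minimal sketch with invented sizes: the conv emits `out_channels * 4` maps, `pixel_shuffle` unfolds them into a 2x larger grid, and the shortcut duplicates input channels before shuffling so it lands on the same shape.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: in_channels=128, out_channels=64, interpolate=False, shortcut=True.
x = torch.randn(1, 128, 16, 16)

conv = torch.nn.Conv2d(128, 64 * 4, kernel_size=3, stride=1, padding=1)
main = F.pixel_shuffle(conv(x), 2)                                 # (1, 256, 16, 16) -> (1, 64, 32, 32)

repeats = 64 * 2 ** 2 // 128                                       # = 2, mirrors self.repeats above
short = F.pixel_shuffle(x.repeat_interleave(repeats, dim=1), 2)    # (1, 64, 32, 32)

print((main + short).shape)                                        # torch.Size([1, 64, 32, 32])
```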
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class MoVQDownsample(nn.Module):
|
| 140 |
+
|
| 141 |
+
def __init__(self, in_channels: int):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.conv = nn.Conv2d(
|
| 144 |
+
in_channels,
|
| 145 |
+
in_channels,
|
| 146 |
+
kernel_size=3,
|
| 147 |
+
stride=2,
|
| 148 |
+
padding=0,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def forward(self, x: torch.Tensor):
|
| 152 |
+
pad = (0, 1, 0, 1)
|
| 153 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
| 154 |
+
x = self.conv(x)
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class MoVQSpatialNorm(nn.Module):
|
| 159 |
+
|
| 160 |
+
def __init__(
|
| 161 |
+
self,
|
| 162 |
+
f_channels: int,
|
| 163 |
+
zq_channels: int,
|
| 164 |
+
norm_layer: nn.Module = nn.GroupNorm,
|
| 165 |
+
add_conv: bool = False,
|
| 166 |
+
num_groups: int = 32,
|
| 167 |
+
eps: float = 1e-6,
|
| 168 |
+
affine: bool = True,
|
| 169 |
+
):
|
| 170 |
+
super().__init__()
|
| 171 |
+
self.norm_layer = norm_layer(
|
| 172 |
+
num_channels=f_channels,
|
| 173 |
+
num_groups=num_groups,
|
| 174 |
+
eps=eps,
|
| 175 |
+
affine=affine,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
self.add_conv = add_conv
|
| 179 |
+
if self.add_conv:
|
| 180 |
+
self.conv = nn.Conv2d(
|
| 181 |
+
zq_channels,
|
| 182 |
+
zq_channels,
|
| 183 |
+
kernel_size=3,
|
| 184 |
+
stride=1,
|
| 185 |
+
padding=1,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.conv_y = nn.Conv2d(
|
| 189 |
+
zq_channels,
|
| 190 |
+
f_channels,
|
| 191 |
+
kernel_size=1,
|
| 192 |
+
stride=1,
|
| 193 |
+
padding=0,
|
| 194 |
+
)
|
| 195 |
+
self.conv_b = nn.Conv2d(
|
| 196 |
+
zq_channels,
|
| 197 |
+
f_channels,
|
| 198 |
+
kernel_size=1,
|
| 199 |
+
stride=1,
|
| 200 |
+
padding=0,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
def forward(self, x: torch.Tensor, zq: torch.Tensor):
|
| 204 |
+
zq = F.interpolate(zq.float(), size=x.shape[-2:], mode="nearest").to(zq.dtype)
|
| 205 |
+
|
| 206 |
+
if self.add_conv:
|
| 207 |
+
zq = self.conv(zq)
|
| 208 |
+
|
| 209 |
+
x = self.norm_layer(x)
|
| 210 |
+
x = x * self.conv_y(zq) + self.conv_b(zq)
|
| 211 |
+
return x
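
In other words, the quantized latent `zq` is nearest-resized to the feature map and turned into a per-pixel scale and shift for the group-normalized features. A self-contained sketch with invented channel counts:

```python
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 64, 32, 32)   # decoder features (f_channels=64)
zq = torch.randn(1, 4, 8, 8)     # quantized latent (zq_channels=4) on a coarser grid

norm = nn.GroupNorm(num_channels=64, num_groups=32, eps=1e-6, affine=True)
conv_y = nn.Conv2d(4, 64, kernel_size=1)
conv_b = nn.Conv2d(4, 64, kernel_size=1)

zq_up = F.interpolate(zq, size=x.shape[-2:], mode="nearest")   # (1, 4, 32, 32)
out = norm(x) * conv_y(zq_up) + conv_b(zq_up)                  # spatially modulated normalization
print(out.shape)                                               # torch.Size([1, 64, 32, 32])
```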
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class MoVQResnetBlock(nn.Module):
|
| 215 |
+
|
| 216 |
+
def __init__(
|
| 217 |
+
self,
|
| 218 |
+
in_channels: int,
|
| 219 |
+
out_channels: Optional[int] = None,
|
| 220 |
+
conv_shortcut: bool = False,
|
| 221 |
+
dropout: float = 0.0,
|
| 222 |
+
zq_ch: Optional[int] = None,
|
| 223 |
+
add_conv: bool = False,
|
| 224 |
+
):
|
| 225 |
+
super().__init__()
|
| 226 |
+
self.in_channels = in_channels
|
| 227 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 228 |
+
self.out_channels = out_channels
|
| 229 |
+
self.use_conv_shortcut = conv_shortcut
|
| 230 |
+
self.zq_ch = zq_ch
|
| 231 |
+
|
| 232 |
+
if zq_ch is None:
|
| 233 |
+
norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
|
| 234 |
+
self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
|
| 235 |
+
self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
|
| 236 |
+
else:
|
| 237 |
+
self.norm1 = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
|
| 238 |
+
self.norm2 = MoVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv)
|
| 239 |
+
|
| 240 |
+
self.conv1 = nn.Conv2d(
|
| 241 |
+
in_channels,
|
| 242 |
+
out_channels,
|
| 243 |
+
kernel_size=3,
|
| 244 |
+
stride=1,
|
| 245 |
+
padding=1,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
self.dropout = nn.Dropout(dropout)
|
| 249 |
+
self.conv2 = nn.Conv2d(
|
| 250 |
+
out_channels,
|
| 251 |
+
out_channels,
|
| 252 |
+
kernel_size=3,
|
| 253 |
+
stride=1,
|
| 254 |
+
padding=1,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
self.act = MoVQActivation()
|
| 258 |
+
|
| 259 |
+
if self.in_channels != self.out_channels:
|
| 260 |
+
if self.use_conv_shortcut:
|
| 261 |
+
self.conv_shortcut = nn.Conv2d(
|
| 262 |
+
in_channels,
|
| 263 |
+
out_channels,
|
| 264 |
+
kernel_size=3,
|
| 265 |
+
stride=1,
|
| 266 |
+
padding=1,
|
| 267 |
+
)
|
| 268 |
+
else:
|
| 269 |
+
self.nin_shortcut = nn.Conv2d(
|
| 270 |
+
in_channels,
|
| 271 |
+
out_channels,
|
| 272 |
+
kernel_size=1,
|
| 273 |
+
stride=1,
|
| 274 |
+
padding=0,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
|
| 278 |
+
norm_args = tuple() if self.zq_ch is None else (zq,)
|
| 279 |
+
|
| 280 |
+
h = self.norm1(x, *norm_args)
|
| 281 |
+
h = self.act(h)
|
| 282 |
+
h = self.conv1(h)
|
| 283 |
+
|
| 284 |
+
h = self.norm2(h, *norm_args)
|
| 285 |
+
h = self.act(h)
|
| 286 |
+
h = self.dropout(h)
|
| 287 |
+
h = self.conv2(h)
|
| 288 |
+
|
| 289 |
+
if self.in_channels != self.out_channels:
|
| 290 |
+
if self.use_conv_shortcut:
|
| 291 |
+
x = self.conv_shortcut(x)
|
| 292 |
+
else:
|
| 293 |
+
x = self.nin_shortcut(x)
|
| 294 |
+
|
| 295 |
+
return x + h
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class MoVQAttnBlock(nn.Module):
|
| 299 |
+
|
| 300 |
+
def __init__(
|
| 301 |
+
self,
|
| 302 |
+
in_channels: int,
|
| 303 |
+
zq_ch: Optional[int] = None,
|
| 304 |
+
add_conv: bool = False,
|
| 305 |
+
num_heads=1,
|
| 306 |
+
):
|
| 307 |
+
super().__init__()
|
| 308 |
+
self.in_channels = in_channels
|
| 309 |
+
self.zq_ch = zq_ch
|
| 310 |
+
self.num_heads = num_heads
|
| 311 |
+
|
| 312 |
+
if zq_ch is None:
|
| 313 |
+
norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
|
| 314 |
+
self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
|
| 315 |
+
else:
|
| 316 |
+
self.norm = MoVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
|
| 317 |
+
|
| 318 |
+
self.q = nn.Conv2d(
|
| 319 |
+
in_channels,
|
| 320 |
+
in_channels,
|
| 321 |
+
kernel_size=1,
|
| 322 |
+
stride=1,
|
| 323 |
+
padding=0,
|
| 324 |
+
)
|
| 325 |
+
self.k = nn.Conv2d(
|
| 326 |
+
in_channels,
|
| 327 |
+
in_channels,
|
| 328 |
+
kernel_size=1,
|
| 329 |
+
stride=1,
|
| 330 |
+
padding=0,
|
| 331 |
+
)
|
| 332 |
+
self.v = nn.Conv2d(
|
| 333 |
+
in_channels,
|
| 334 |
+
in_channels,
|
| 335 |
+
kernel_size=1,
|
| 336 |
+
stride=1,
|
| 337 |
+
padding=0,
|
| 338 |
+
)
|
| 339 |
+
self.proj_out = nn.Conv2d(
|
| 340 |
+
in_channels,
|
| 341 |
+
in_channels,
|
| 342 |
+
kernel_size=1,
|
| 343 |
+
stride=1,
|
| 344 |
+
padding=0,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
|
| 348 |
+
# x: [b, c1, h1, w1]
|
| 349 |
+
# zq: [b, c2, h2, w2]
|
| 350 |
+
# attention_mask: [b, 1, h3, w3]
|
| 351 |
+
norm_args = tuple() if self.zq_ch is None else (zq,)
|
| 352 |
+
|
| 353 |
+
# if context is not None:
|
| 354 |
+
# context = F.interpolate(context.float(), size=x.shape[-2:], mode="nearest").to(context.dtype)
|
| 355 |
+
# x = x + self.conv_context(context)
|
| 356 |
+
|
| 357 |
+
nx = self.norm(x, *norm_args)
|
| 358 |
+
q = self.q(nx)
|
| 359 |
+
k = self.k(nx)
|
| 360 |
+
v = self.v(nx)
|
| 361 |
+
|
| 362 |
+
b, c, h, w = q.shape
|
| 363 |
+
if is_xformers_available:
|
| 364 |
+
# If xformers is available, create attn_bias for xops.memory_efficient_attention.
|
| 365 |
+
attn_bias = None
|
| 366 |
+
|
| 367 |
+
v = xops.memory_efficient_attention(
|
| 368 |
+
rearrange(q, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
|
| 369 |
+
rearrange(k, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
|
| 370 |
+
rearrange(v, 'b (n c) h w -> b (h w) n c', n=self.num_heads).contiguous(),
|
| 371 |
+
scale=1.0 / math.sqrt(c // self.num_heads),
|
| 372 |
+
attn_bias=attn_bias,
|
| 373 |
+
)
|
| 374 |
+
v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w).contiguous()
|
| 375 |
+
elif IS_SDPA_AVAILABLE:
|
| 376 |
+
# compute attention
|
| 377 |
+
q = rearrange(q, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
|
| 378 |
+
k = rearrange(k, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
|
| 379 |
+
v = rearrange(v, 'b (n c) h w -> b n (h w) c', n=self.num_heads).contiguous()
|
| 380 |
+
|
| 381 |
+
attn_bias = None
|
| 382 |
+
|
| 383 |
+
v = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
|
| 384 |
+
v = v.transpose(1, 2)
|
| 385 |
+
v = rearrange(v, 'b (h w) n c -> b (n c) h w', h=h, w=w)
|
| 386 |
+
else:
|
| 387 |
+
# compute attention
|
| 388 |
+
q = rearrange(q, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
|
| 389 |
+
k = rearrange(k, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
|
| 390 |
+
v = rearrange(v, 'b (n c) h w -> b n c (h w)', n=self.num_heads).contiguous()
|
| 391 |
+
|
| 392 |
+
# score = torch.bmm(q.permute(0, 2, 1), k)
|
| 393 |
+
score = torch.einsum('b n c k, b n c l -> b n k l', q, k)
|
| 394 |
+
score = score / math.sqrt(c // self.num_heads)
|
| 395 |
+
|
| 396 |
+
score = F.softmax(score, dim=-1)  # normalize over key positions (last dim), matching the xformers/SDPA branches
|
| 397 |
+
|
| 398 |
+
# attend to values
|
| 399 |
+
# v = v.reshape(b, c, h * w)
|
| 400 |
+
# v = torch.bmm(v, score.permute(0, 2, 1))
|
| 401 |
+
v = torch.einsum('b n c l, b n k l -> b n c k', v, score)
|
| 402 |
+
v = v.reshape(b, c, h, w)
|
| 403 |
+
|
| 404 |
+
v = self.proj_out(v)
|
| 405 |
+
|
| 406 |
+
return x + v
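
The three branches above (xformers, SDPA, plain einsum) are intended to compute the same self-attention over spatial positions. A tiny torch-only check with one head and invented sizes, comparing the einsum fallback against `F.scaled_dot_product_attention`:

```python
import math
import torch
import torch.nn.functional as F
from einops import rearrange

b, n, c, h, w = 2, 1, 8, 4, 4
q, k, v = (torch.randn(b, c, h, w) for _ in range(3))

# einsum fallback path: softmax over key positions, then weighted sum of values
qe = rearrange(q, 'b (n c) h w -> b n c (h w)', n=n)
ke = rearrange(k, 'b (n c) h w -> b n c (h w)', n=n)
ve = rearrange(v, 'b (n c) h w -> b n c (h w)', n=n)
score = torch.einsum('b n c k, b n c l -> b n k l', qe, ke) / math.sqrt(c // n)
score = F.softmax(score, dim=-1)
out_einsum = torch.einsum('b n c l, b n k l -> b n c k', ve, score).reshape(b, c, h, w)

# SDPA path, reshaped the same way as the IS_SDPA_AVAILABLE branch
qs = rearrange(q, 'b (n c) h w -> b n (h w) c', n=n)
ks = rearrange(k, 'b (n c) h w -> b n (h w) c', n=n)
vs = rearrange(v, 'b (n c) h w -> b n (h w) c', n=n)
out_sdpa = F.scaled_dot_product_attention(qs, ks, vs).transpose(1, 2)
out_sdpa = rearrange(out_sdpa, 'b (h w) n c -> b (n c) h w', h=h, w=w)

print(torch.allclose(out_einsum, out_sdpa, atol=1e-5))  # True
```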
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class MoVQVectorQuantizer(nn.Module):
|
| 410 |
+
|
| 411 |
+
def __init__(self, config: MoVQConfig):
|
| 412 |
+
super().__init__()
|
| 413 |
+
self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
|
| 414 |
+
self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
|
| 415 |
+
|
| 416 |
+
def forward(self, x: torch.Tensor):
|
| 417 |
+
# b t c h w -> b t h w c
|
| 418 |
+
b, t, c, h, w = x.shape
|
| 419 |
+
x = x.permute(0, 1, 3, 4, 2).contiguous()
|
| 420 |
+
x_flattened = x.view(-1, c)
|
| 421 |
+
|
| 422 |
+
codebook = self.embedding.weight
|
| 423 |
+
|
| 424 |
+
d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \
|
| 425 |
+
torch.sum(codebook ** 2, dim=1) - 2 * \
|
| 426 |
+
torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0))
|
| 427 |
+
|
| 428 |
+
indices = torch.argmin(d, dim=1)
|
| 429 |
+
indices = indices.view(b, t, h, w)
|
| 430 |
+
return indices
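
The nearest-code lookup above relies on the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e, so only the cross term needs a matrix product. A minimal torch-only check against `torch.cdist`, with invented sizes:

```python
import torch

x = torch.randn(5, 16)          # flattened latents, embed_dim = 16
codebook = torch.randn(32, 16)  # codebook_size = 32

# same expansion as in MoVQVectorQuantizer.forward
d = (x ** 2).sum(dim=1, keepdim=True) + (codebook ** 2).sum(dim=1) - 2 * x @ codebook.t()
indices = d.argmin(dim=1)

ref = torch.cdist(x, codebook).argmin(dim=1)
print(torch.equal(indices, ref))  # True (barring exact distance ties)
```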
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class MoVQPretrainedModel(PreTrainedModel):
|
| 434 |
+
"""
|
| 435 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 436 |
+
models.
|
| 437 |
+
"""
|
| 438 |
+
|
| 439 |
+
config_class = MoVQConfig
|
| 440 |
+
base_model_prefix = "movq"
|
| 441 |
+
main_input_name = "pixel_values"
|
| 442 |
+
_no_split_modules = ["MoVQResnetBlock", "MoVQAttnBlock"]
|
| 443 |
+
|
| 444 |
+
def _init_weights(self, module):
|
| 445 |
+
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
|
| 446 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 447 |
+
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
|
| 448 |
+
elif isinstance(module, nn.Linear):
|
| 449 |
+
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| 450 |
+
if module.bias is not None:
|
| 451 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| 452 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 453 |
+
nn.init.uniform_(module.bias, -bound, bound)
|
| 454 |
+
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
|
| 455 |
+
nn.init.constant_(module.weight, 1)
|
| 456 |
+
nn.init.constant_(module.bias, 0)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class MoVQEncoder(nn.Module):
|
| 460 |
+
def __init__(self, config: MoVQConfig):
|
| 461 |
+
super().__init__()
|
| 462 |
+
self.config = config
|
| 463 |
+
self.ch = config.ch
|
| 464 |
+
self.num_resolutions = len(config.ch_mult)
|
| 465 |
+
self.num_res_blocks = config.num_res_blocks
|
| 466 |
+
self.in_channels = config.in_channels
|
| 467 |
+
|
| 468 |
+
# downsampling
|
| 469 |
+
self.conv_in = nn.Conv2d(
|
| 470 |
+
self.in_channels,
|
| 471 |
+
self.ch,
|
| 472 |
+
kernel_size=3,
|
| 473 |
+
stride=1,
|
| 474 |
+
padding=1
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
in_ch_mult = (1,) + tuple(config.ch_mult)
|
| 478 |
+
self.down = nn.ModuleList()
|
| 479 |
+
for i_level in range(self.num_resolutions):
|
| 480 |
+
block = nn.ModuleList()
|
| 481 |
+
attn = nn.ModuleList()
|
| 482 |
+
block_in = config.ch * in_ch_mult[i_level]
|
| 483 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 484 |
+
for i_block in range(self.num_res_blocks):
|
| 485 |
+
block.append(
|
| 486 |
+
MoVQResnetBlock(
|
| 487 |
+
in_channels=block_in,
|
| 488 |
+
out_channels=block_out,
|
| 489 |
+
dropout=config.dropout,
|
| 490 |
+
)
|
| 491 |
+
)
|
| 492 |
+
block_in = block_out
|
| 493 |
+
if i_level in config.attn_resolutions:
|
| 494 |
+
attn.append(MoVQAttnBlock(block_in))
|
| 495 |
+
|
| 496 |
+
down = nn.Module()
|
| 497 |
+
down.block = block
|
| 498 |
+
down.attn = attn
|
| 499 |
+
if i_level != self.num_resolutions - 1:
|
| 500 |
+
if config.use_dc_up_down_blocks:
|
| 501 |
+
down.downsample = DCDownBlock2d(block_in)
|
| 502 |
+
else:
|
| 503 |
+
down.downsample = MoVQDownsample(block_in)
|
| 504 |
+
|
| 505 |
+
self.down.append(down)
|
| 506 |
+
|
| 507 |
+
# middle
|
| 508 |
+
self.mid = nn.Module()
|
| 509 |
+
self.mid.block_1 = MoVQResnetBlock(
|
| 510 |
+
in_channels=block_in,
|
| 511 |
+
out_channels=block_in,
|
| 512 |
+
dropout=config.dropout,
|
| 513 |
+
)
|
| 514 |
+
self.mid.attn_1 = MoVQAttnBlock(block_in)
|
| 515 |
+
self.mid.block_2 = MoVQResnetBlock(
|
| 516 |
+
in_channels=block_in,
|
| 517 |
+
out_channels=block_in,
|
| 518 |
+
dropout=config.dropout,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# end
|
| 522 |
+
|
| 523 |
+
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
|
| 524 |
+
|
| 525 |
+
self.act = MoVQActivation()
|
| 526 |
+
|
| 527 |
+
out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
| 528 |
+
self.conv_out = nn.Conv2d(
|
| 529 |
+
block_in,
|
| 530 |
+
out_z_channels,
|
| 531 |
+
kernel_size=3,
|
| 532 |
+
stride=1,
|
| 533 |
+
padding=1,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
self.out_shortcut_average_group_size = block_in // out_z_channels
|
| 537 |
+
|
| 538 |
+
def forward(self, x: torch.Tensor):
|
| 539 |
+
|
| 540 |
+
# downsampling
|
| 541 |
+
h = self.conv_in(x)
|
| 542 |
+
for i_level in range(self.num_resolutions):
|
| 543 |
+
for i_block in range(self.num_res_blocks):
|
| 544 |
+
h = self.down[i_level].block[i_block](h)
|
| 545 |
+
if len(self.down[i_level].attn) > 0:
|
| 546 |
+
h = self.down[i_level].attn[i_block](h)
|
| 547 |
+
|
| 548 |
+
if i_level != self.num_resolutions - 1:
|
| 549 |
+
h = self.down[i_level].downsample(h)
|
| 550 |
+
|
| 551 |
+
h = self.mid.block_1(h)
|
| 552 |
+
h = self.mid.attn_1(h)
|
| 553 |
+
h = self.mid.block_2(h)
|
| 554 |
+
|
| 555 |
+
# end
|
| 556 |
+
h = self.norm_out(h)
|
| 557 |
+
h = self.act(h)
|
| 558 |
+
|
| 559 |
+
if self.config.use_dc_up_down_blocks:
|
| 560 |
+
x = h.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
| 561 |
+
x = x.mean(dim=2)
|
| 562 |
+
h = self.conv_out(h) + x
|
| 563 |
+
else:
|
| 564 |
+
h = self.conv_out(h)
|
| 565 |
+
return h
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
class MoVQDecoder(nn.Module):
|
| 569 |
+
def __init__(self, config: MoVQConfig):
|
| 570 |
+
super().__init__()
|
| 571 |
+
self.config = config
|
| 572 |
+
self.ch = config.ch
|
| 573 |
+
self.num_resolutions = len(config.ch_mult)
|
| 574 |
+
self.num_res_blocks = config.num_res_blocks
|
| 575 |
+
|
| 576 |
+
in_ch_mult = (1,) + tuple(config.ch_mult)
|
| 577 |
+
zq_ch = config.embed_dim
|
| 578 |
+
|
| 579 |
+
block_in = config.ch * config.ch_mult[-1]
|
| 580 |
+
|
| 581 |
+
self.in_shortcut_repeats = block_in // config.embed_dim
|
| 582 |
+
|
| 583 |
+
self.conv_in = nn.Conv2d(
|
| 584 |
+
config.z_channels,
|
| 585 |
+
block_in,
|
| 586 |
+
kernel_size=3,
|
| 587 |
+
stride=1,
|
| 588 |
+
padding=1,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
# middle
|
| 592 |
+
self.mid = nn.Module()
|
| 593 |
+
self.mid.block_1 = MoVQResnetBlock(
|
| 594 |
+
in_channels=block_in,
|
| 595 |
+
out_channels=block_in,
|
| 596 |
+
dropout=config.dropout,
|
| 597 |
+
zq_ch=zq_ch,
|
| 598 |
+
)
|
| 599 |
+
self.mid.attn_1 = MoVQAttnBlock(block_in, zq_ch)
|
| 600 |
+
self.mid.block_2 = MoVQResnetBlock(
|
| 601 |
+
in_channels=block_in,
|
| 602 |
+
out_channels=block_in,
|
| 603 |
+
dropout=config.dropout,
|
| 604 |
+
zq_ch=zq_ch,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
# upsampling
|
| 608 |
+
self.up = nn.ModuleList()
|
| 609 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 610 |
+
block = nn.ModuleList()
|
| 611 |
+
attn = nn.ModuleList()
|
| 612 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 613 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 614 |
+
block.append(
|
| 615 |
+
MoVQResnetBlock(
|
| 616 |
+
in_channels=block_in,
|
| 617 |
+
out_channels=block_out,
|
| 618 |
+
dropout=config.dropout,
|
| 619 |
+
zq_ch=zq_ch,
|
| 620 |
+
)
|
| 621 |
+
)
|
| 622 |
+
block_in = block_out
|
| 623 |
+
if i_level in config.attn_resolutions:
|
| 624 |
+
attn.append(MoVQAttnBlock(block_in, zq_ch))
|
| 625 |
+
|
| 626 |
+
up = nn.Module()
|
| 627 |
+
up.block = block
|
| 628 |
+
up.attn = attn
|
| 629 |
+
if i_level != 0:
|
| 630 |
+
if config.use_dc_up_down_blocks:
|
| 631 |
+
up.upsample = DCUpBlock2d(block_in)
|
| 632 |
+
else:
|
| 633 |
+
up.upsample = MoVQUpsample(block_in)
|
| 634 |
+
|
| 635 |
+
self.up.insert(0, up)
|
| 636 |
+
|
| 637 |
+
self.act = MoVQActivation()
|
| 638 |
+
|
| 639 |
+
self.norm_out = MoVQSpatialNorm(block_in, zq_ch)
|
| 640 |
+
self.conv_out = nn.Conv2d(
|
| 641 |
+
block_in,
|
| 642 |
+
config.out_channels,
|
| 643 |
+
kernel_size=3,
|
| 644 |
+
stride=1,
|
| 645 |
+
padding=1,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
@property
|
| 649 |
+
def last_layer(self):
|
| 650 |
+
return self.conv_out.weight
|
| 651 |
+
|
| 652 |
+
def forward(self, z: torch.Tensor, zq: torch.Tensor):
|
| 653 |
+
h = z
|
| 654 |
+
|
| 655 |
+
if self.config.use_dc_up_down_blocks:
|
| 656 |
+
h = h.repeat_interleave(self.in_shortcut_repeats, dim=1)
|
| 657 |
+
h = self.conv_in(z) + h
|
| 658 |
+
else:
|
| 659 |
+
h = self.conv_in(h)
|
| 660 |
+
|
| 661 |
+
# middle
|
| 662 |
+
h = self.mid.block_1(h, zq)
|
| 663 |
+
h = self.mid.attn_1(h, zq)
|
| 664 |
+
h = self.mid.block_2(h, zq)
|
| 665 |
+
|
| 666 |
+
# upsampling
|
| 667 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 668 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 669 |
+
h = self.up[i_level].block[i_block](h, zq)
|
| 670 |
+
if len(self.up[i_level].attn) > 0:
|
| 671 |
+
h = self.up[i_level].attn[i_block](h, zq)
|
| 672 |
+
|
| 673 |
+
if i_level != 0:
|
| 674 |
+
h = self.up[i_level].upsample(h)
|
| 675 |
+
|
| 676 |
+
h = self.norm_out(h, zq)
|
| 677 |
+
h = self.act(h)
|
| 678 |
+
h = self.conv_out(h)
|
| 679 |
+
|
| 680 |
+
return h
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
class Decoder(nn.Module):
|
| 684 |
+
def __init__(self, config: MoVQConfig):
|
| 685 |
+
super().__init__()
|
| 686 |
+
self.config = config
|
| 687 |
+
self.ch = config.ch
|
| 688 |
+
self.num_resolutions = len(config.ch_mult)
|
| 689 |
+
self.num_res_blocks = config.num_res_blocks
|
| 690 |
+
|
| 691 |
+
in_ch_mult = (1,) + tuple(config.ch_mult)
|
| 692 |
+
|
| 693 |
+
block_in = config.ch * config.ch_mult[-1]
|
| 694 |
+
|
| 695 |
+
self.conv_in = nn.Conv2d(
|
| 696 |
+
config.z_channels,
|
| 697 |
+
block_in,
|
| 698 |
+
kernel_size=3,
|
| 699 |
+
stride=1,
|
| 700 |
+
padding=1,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
# middle
|
| 704 |
+
self.mid = nn.Module()
|
| 705 |
+
self.mid.block_1 = MoVQResnetBlock(
|
| 706 |
+
in_channels=block_in,
|
| 707 |
+
out_channels=block_in,
|
| 708 |
+
dropout=config.dropout,
|
| 709 |
+
)
|
| 710 |
+
self.mid.attn_1 = MoVQAttnBlock(block_in)
|
| 711 |
+
self.mid.block_2 = MoVQResnetBlock(
|
| 712 |
+
in_channels=block_in,
|
| 713 |
+
out_channels=block_in,
|
| 714 |
+
dropout=config.dropout,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
# upsampling
|
| 718 |
+
self.up = nn.ModuleList()
|
| 719 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 720 |
+
block = nn.ModuleList()
|
| 721 |
+
attn = nn.ModuleList()
|
| 722 |
+
block_out = config.ch * config.ch_mult[i_level]
|
| 723 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 724 |
+
block.append(
|
| 725 |
+
MoVQResnetBlock(
|
| 726 |
+
in_channels=block_in,
|
| 727 |
+
out_channels=block_out,
|
| 728 |
+
dropout=config.dropout,
|
| 729 |
+
)
|
| 730 |
+
)
|
| 731 |
+
block_in = block_out
|
| 732 |
+
if i_level in config.attn_resolutions:
|
| 733 |
+
attn.append(MoVQAttnBlock(block_in))
|
| 734 |
+
|
| 735 |
+
up = nn.Module()
|
| 736 |
+
up.block = block
|
| 737 |
+
up.attn = attn
|
| 738 |
+
if i_level != 0:
|
| 739 |
+
up.upsample = MoVQUpsample(block_in)
|
| 740 |
+
|
| 741 |
+
self.up.insert(0, up)
|
| 742 |
+
|
| 743 |
+
self.act = MoVQActivation()
|
| 744 |
+
|
| 745 |
+
norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
|
| 746 |
+
self.norm_out = nn.GroupNorm(num_channels=block_in, **norm_kwargs)
|
| 747 |
+
self.conv_out = nn.Conv2d(
|
| 748 |
+
block_in,
|
| 749 |
+
config.out_channels,
|
| 750 |
+
kernel_size=3,
|
| 751 |
+
stride=1,
|
| 752 |
+
padding=1,
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
@property
|
| 756 |
+
def last_layer(self):
|
| 757 |
+
return self.conv_out.weight
|
| 758 |
+
|
| 759 |
+
def forward(self, z: torch.Tensor, zq: torch.Tensor):
|
| 760 |
+
h = z
|
| 761 |
+
h = self.conv_in(h)
|
| 762 |
+
|
| 763 |
+
# middle
|
| 764 |
+
h = self.mid.block_1(h)
|
| 765 |
+
h = self.mid.attn_1(h)
|
| 766 |
+
h = self.mid.block_2(h)
|
| 767 |
+
|
| 768 |
+
# upsampling
|
| 769 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 770 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 771 |
+
h = self.up[i_level].block[i_block](h)
|
| 772 |
+
if len(self.up[i_level].attn) > 0:
|
| 773 |
+
h = self.up[i_level].attn[i_block](h)
|
| 774 |
+
|
| 775 |
+
if i_level != 0:
|
| 776 |
+
h = self.up[i_level].upsample(h)
|
| 777 |
+
|
| 778 |
+
h = self.norm_out(h)
|
| 779 |
+
h = self.act(h)
|
| 780 |
+
h = self.conv_out(h)
|
| 781 |
+
|
| 782 |
+
return h
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
class MoVQModel(MoVQPretrainedModel):
|
| 786 |
+
|
| 787 |
+
def __init__(self, config):
|
| 788 |
+
super().__init__(config)
|
| 789 |
+
self.config = config
|
| 790 |
+
|
| 791 |
+
self.encoder = MoVQEncoder(config)
|
| 792 |
+
self.decoder = MoVQDecoder(config)
|
| 793 |
+
self.quantize = MoVQVectorQuantizer(config)
|
| 794 |
+
|
| 795 |
+
self.quant_conv = nn.Conv2d(config.z_channels, config.embed_dim, 1)
|
| 796 |
+
self.post_quant_conv = nn.Conv2d(config.embed_dim, config.z_channels, 1)
|
| 797 |
+
|
| 798 |
+
self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1)
|
| 799 |
+
|
| 800 |
+
self.post_init()
|
| 801 |
+
|
| 802 |
+
def encode(self, x: torch.Tensor):
|
| 803 |
+
h = self.encoder(x)
|
| 804 |
+
h = self.quant_conv(h)
|
| 805 |
+
codes = self.quantize(h)
|
| 806 |
+
return codes
|
| 807 |
+
|
| 808 |
+
def decode(self, x: torch.Tensor):
|
| 809 |
+
quant = self.quantize.embedding(x.flatten())
|
| 810 |
+
b, h, w = x.shape  # token indices assumed to be (b, h, w); the embedded quant is (b*h*w, c)
|
| 811 |
+
quant = quant.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
|
| 812 |
+
quant2 = self.post_quant_conv(quant)
|
| 813 |
+
image = self.decoder(quant2, quant)
|
| 814 |
+
image = image.reshape(
|
| 815 |
+
b,
|
| 816 |
+
self.config.out_channels,
|
| 817 |
+
h * self.spatial_scale_factor,
|
| 818 |
+
w * self.spatial_scale_factor,
|
| 819 |
+
)
|
| 820 |
+
return image
|
| 821 |
+
|
| 822 |
+
@property
|
| 823 |
+
def device(self):
|
| 824 |
+
return next(self.parameters()).device
|
| 825 |
+
|
| 826 |
+
@property
|
| 827 |
+
def dtype(self):
|
| 828 |
+
return next(self.parameters()).dtype
|
modeling_qwen2vit.py
ADDED
|
@@ -0,0 +1,841 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""PyTorch Qwen2-VL model."""
|
| 21 |
+
|
| 22 |
+
import math
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
from torch import Tensor
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
import torch.utils.checkpoint
|
| 32 |
+
|
| 33 |
+
from transformers.activations import ACT2FN
|
| 34 |
+
from transformers.cache_utils import Cache, StaticCache
|
| 35 |
+
from transformers.modeling_attn_mask_utils import (
|
| 36 |
+
AttentionMaskConverter,
|
| 37 |
+
)
|
| 38 |
+
from transformers.modeling_outputs import (
|
| 39 |
+
BaseModelOutputWithPast,
|
| 40 |
+
ModelOutput,
|
| 41 |
+
)
|
| 42 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 43 |
+
from transformers.utils import (
|
| 44 |
+
add_start_docstrings,
|
| 45 |
+
add_start_docstrings_to_model_forward,
|
| 46 |
+
is_torch_npu_available,
|
| 47 |
+
is_flash_attn_2_available,
|
| 48 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 49 |
+
logging,
|
| 50 |
+
replace_return_docstrings,
|
| 51 |
+
)
|
| 52 |
+
from .configuration_qwen2vit import Qwen2VLConfig, Qwen2VLVisionConfig
|
| 53 |
+
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
| 54 |
+
|
| 55 |
+
from einops import rearrange
|
| 56 |
+
|
| 57 |
+
logger = logging.get_logger(__name__)
|
| 58 |
+
|
| 59 |
+
_CONFIG_FOR_DOC = "Qwen2VLConfig"
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
import xformers.ops as xops
|
| 63 |
+
|
| 64 |
+
is_xformers_available = True
|
| 65 |
+
except Exception as e:
|
| 66 |
+
is_xformers_available = False
|
| 67 |
+
|
| 68 |
+
if is_flash_attn_2_available():
|
| 69 |
+
from flash_attn import flash_attn_varlen_func
|
| 70 |
+
|
| 71 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 72 |
+
else:
|
| 73 |
+
flash_attn_varlen_func = None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def init_weights(m):
|
| 77 |
+
if isinstance(m, nn.Linear):
|
| 78 |
+
# we use xavier_uniform following official JAX ViT:
|
| 79 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 80 |
+
if m.bias is not None:
|
| 81 |
+
nn.init.constant_(m.bias, 0)
|
| 82 |
+
elif isinstance(m, nn.LayerNorm):
|
| 83 |
+
nn.init.constant_(m.bias, 0)
|
| 84 |
+
nn.init.constant_(m.weight, 1.0)
|
| 85 |
+
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
| 86 |
+
w = m.weight.data
|
| 87 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
|
| 91 |
+
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
|
| 92 |
+
"""
|
| 93 |
+
Base class for Qwen2VL causal language model (or autoregressive) outputs.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 97 |
+
Language modeling loss (for next-token prediction).
|
| 98 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 99 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 100 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 101 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 102 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 103 |
+
|
| 104 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 105 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 106 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 107 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 108 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 109 |
+
|
| 110 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 111 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 112 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 113 |
+
sequence_length)`.
|
| 114 |
+
|
| 115 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 116 |
+
heads.
|
| 117 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
| 118 |
+
The rope index difference between sequence length and multimodal rope.
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
loss: Optional[torch.FloatTensor] = None
|
| 122 |
+
logits: torch.FloatTensor = None
|
| 123 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 124 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 125 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 126 |
+
rope_deltas: Optional[torch.LongTensor] = None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Qwen2VLRotaryEmbedding(nn.Module):
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
dim=None,
|
| 133 |
+
max_position_embeddings=2048,
|
| 134 |
+
base=10000,
|
| 135 |
+
device=None,
|
| 136 |
+
scaling_factor=1.0,
|
| 137 |
+
rope_type="default",
|
| 138 |
+
config: Optional[Qwen2VLConfig] = None,
|
| 139 |
+
):
|
| 140 |
+
super().__init__()
|
| 141 |
+
# TODO (joao): remove the `if` below, only used for BC
|
| 142 |
+
self.rope_kwargs = {}
|
| 143 |
+
if config is None:
|
| 144 |
+
logger.warning_once(
|
| 145 |
+
"`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
| 146 |
+
"`config` argument. All other arguments will be removed in v4.46"
|
| 147 |
+
)
|
| 148 |
+
self.rope_kwargs = {
|
| 149 |
+
"rope_type": rope_type,
|
| 150 |
+
"factor": scaling_factor,
|
| 151 |
+
"dim": dim,
|
| 152 |
+
"base": base,
|
| 153 |
+
"max_position_embeddings": max_position_embeddings,
|
| 154 |
+
}
|
| 155 |
+
self.rope_type = rope_type
|
| 156 |
+
self.max_seq_len_cached = max_position_embeddings
|
| 157 |
+
self.original_max_seq_len = max_position_embeddings
|
| 158 |
+
else:
|
| 159 |
+
# BC: "rope_type" was originally "type"
|
| 160 |
+
if config.rope_scaling is not None:
|
| 161 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 162 |
+
else:
|
| 163 |
+
self.rope_type = "default"
|
| 164 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 165 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 166 |
+
|
| 167 |
+
self.config = config
|
| 168 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 169 |
+
|
| 170 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
| 171 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 172 |
+
self.original_inv_freq = self.inv_freq
|
| 173 |
+
|
| 174 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
| 175 |
+
"""
|
| 176 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
| 177 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
| 178 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
| 179 |
+
"""
|
| 180 |
+
seq_len = torch.max(position_ids) + 1
|
| 181 |
+
if seq_len > self.max_seq_len_cached: # growth
|
| 182 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 183 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
| 184 |
+
)
|
| 185 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
| 186 |
+
self.max_seq_len_cached = seq_len
|
| 187 |
+
|
| 188 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
| 189 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 190 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
| 191 |
+
|
| 192 |
+
@torch.no_grad()
|
| 193 |
+
def forward(self, x, position_ids):
|
| 194 |
+
if "dynamic" in self.rope_type:
|
| 195 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
| 196 |
+
|
| 197 |
+
# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
|
| 198 |
+
# So we expand the inv_freq to shape (3, ...)
|
| 199 |
+
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
| 200 |
+
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
| 201 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
| 202 |
+
device_type = x.device.type
|
| 203 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 204 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 205 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
| 206 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 207 |
+
cos = emb.cos()
|
| 208 |
+
sin = emb.sin()
|
| 209 |
+
|
| 210 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
| 211 |
+
cos = cos * self.attention_scaling
|
| 212 |
+
sin = sin * self.attention_scaling
|
| 213 |
+
|
| 214 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 218 |
+
def rotate_half(x):
|
| 219 |
+
"""Rotates half the hidden dims of the input."""
|
| 220 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 221 |
+
x2 = x[..., x.shape[-1] // 2:]
|
| 222 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
| 226 |
+
orig_dtype = tensor.dtype
|
| 227 |
+
tensor = tensor.float()
|
| 228 |
+
cos = freqs.cos()
|
| 229 |
+
sin = freqs.sin()
|
| 230 |
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 231 |
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 232 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
| 233 |
+
output = output.to(orig_dtype)
|
| 234 |
+
return output
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def apply_rotary_pos_emb_vision_batch(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
| 238 |
+
orig_dtype = tensor.dtype
|
| 239 |
+
tensor = tensor.float()
|
| 240 |
+
cos = freqs.cos()
|
| 241 |
+
sin = freqs.sin()
|
| 242 |
+
cos = cos.repeat(1, 1, 1, 2).float()
|
| 243 |
+
sin = sin.repeat(1, 1, 1, 2).float()
|
| 244 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
| 245 |
+
output = output.to(orig_dtype)
|
| 246 |
+
return output
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 250 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 251 |
+
super().__init__()
|
| 252 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 253 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 254 |
+
|
| 255 |
+
def forward(self, seqlen: int, scale_factor: float = 1.0) -> torch.Tensor:
|
| 256 |
+
# dynamically rescale inv_freq with scale_factor
|
| 257 |
+
scaled_inv_freq = self.inv_freq * scale_factor
|
| 258 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 259 |
+
freqs = torch.outer(seq, scaled_inv_freq)
|
| 260 |
+
return freqs
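
For intuition, a self-contained sketch (sizes invented, `rotate_half` re-stated from earlier in this file) of how a frequency table like the one returned above turns into the cos/sin used by `apply_rotary_pos_emb_vision`; the rotation mixes paired channels and leaves per-token norms unchanged:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

seq_len, num_heads, head_dim = 6, 2, 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim // 2, 2).float() / (head_dim // 2)))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)   # (seq_len, head_dim // 4)
freqs = torch.cat((freqs, freqs), dim=-1)                      # (seq_len, head_dim // 2); the model concatenates h- and w-frequencies here

q = torch.randn(1, seq_len, num_heads, head_dim)
cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)    # (1, seq_len, 1, head_dim)
sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
q_rot = q * cos + rotate_half(q) * sin

print(q_rot.shape)                                                     # torch.Size([1, 6, 2, 8])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))   # True: a pure rotation
```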
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class PatchEmbed(nn.Module):
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
patch_size: int = 14,
|
| 267 |
+
temporal_patch_size: int = 2,
|
| 268 |
+
in_channels: int = 3,
|
| 269 |
+
embed_dim: int = 1152,
|
| 270 |
+
) -> None:
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.patch_size = patch_size
|
| 273 |
+
self.temporal_patch_size = temporal_patch_size
|
| 274 |
+
self.in_channels = in_channels
|
| 275 |
+
self.embed_dim = embed_dim
|
| 276 |
+
|
| 277 |
+
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
| 278 |
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
| 279 |
+
|
| 280 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 281 |
+
target_dtype = self.proj.weight.dtype
|
| 282 |
+
if is_torch_npu_available():
|
| 283 |
+
# if True:
|
| 284 |
+
hidden_states = F.linear(hidden_states, self.proj.weight.view(self.embed_dim, -1))
|
| 285 |
+
else:
|
| 286 |
+
hidden_states = hidden_states.view(
|
| 287 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
| 288 |
+
)
|
| 289 |
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
| 290 |
+
return hidden_states
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class PatchMerger(nn.Module):
|
| 294 |
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
| 295 |
+
super().__init__()
|
| 296 |
+
self.hidden_size = context_dim * (spatial_merge_size ** 2)
|
| 297 |
+
self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
|
| 298 |
+
self.mlp = nn.Sequential(
|
| 299 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
| 300 |
+
nn.GELU(),
|
| 301 |
+
nn.Linear(self.hidden_size, dim),
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
def forward(self, x: torch.Tensor, grid_thw) -> torch.Tensor:
|
| 305 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
| 306 |
+
return x
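
Put differently, the MLP above collapses every `spatial_merge_size ** 2` consecutive patch embeddings into a single token of width `dim` (the caller is assumed to have ordered patches so each consecutive group of four forms one 2x2 window). A shape-only sketch with invented sizes:

```python
import torch
from torch import nn

context_dim, dim, merge = 32, 64, 2
hidden_size = context_dim * merge ** 2          # 128, as in PatchMerger.hidden_size

ln_q = nn.LayerNorm(context_dim, eps=1e-6)
mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.Linear(hidden_size, dim))

x = torch.randn(16, context_dim)                # 16 patch embeddings
merged = mlp(ln_q(x).view(-1, hidden_size))     # every 4 patches -> 1 merged token
print(merged.shape)                             # torch.Size([4, 64])
```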
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class VisionMlp(nn.Module):
|
| 310 |
+
def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.fc1 = nn.Linear(dim, hidden_dim)
|
| 313 |
+
self.act = ACT2FN[hidden_act]
|
| 314 |
+
self.fc2 = nn.Linear(hidden_dim, dim)
|
| 315 |
+
|
| 316 |
+
def forward(self, x) -> torch.Tensor:
|
| 317 |
+
return self.fc2(self.act(self.fc1(x)))
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class VisionAttention(nn.Module):
|
| 321 |
+
def __init__(self, dim: int, num_heads: int = 16, ) -> None:
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.num_heads = num_heads
|
| 324 |
+
self.head_dim = dim // num_heads
|
| 325 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 326 |
+
self.proj = nn.Linear(dim, dim)
|
| 327 |
+
|
| 328 |
+
def forward(
|
| 329 |
+
self,
|
| 330 |
+
hidden_states: torch.Tensor,
|
| 331 |
+
cu_seqlens: torch.Tensor,
|
| 332 |
+
rotary_pos_emb: torch.Tensor = None
|
| 333 |
+
) -> torch.Tensor:
|
| 334 |
+
seq_length = hidden_states.shape[0]
|
| 335 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 336 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 337 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 338 |
+
|
| 339 |
+
attention_mask = torch.full(
|
| 340 |
+
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
| 341 |
+
)
|
| 342 |
+
for i in range(1, len(cu_seqlens)):
|
| 343 |
+
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
|
| 344 |
+
|
| 345 |
+
q = q.transpose(0, 1)
|
| 346 |
+
k = k.transpose(0, 1)
|
| 347 |
+
v = v.transpose(0, 1)
|
| 348 |
+
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
| 349 |
+
attn_weights = attn_weights + attention_mask
|
| 350 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 351 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 352 |
+
attn_output = attn_output.transpose(0, 1)
|
| 353 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 354 |
+
attn_output = self.proj(attn_output)
|
| 355 |
+
return attn_output
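
The additive mask built above restricts attention to each image's own patch tokens inside the packed sequence. A small sketch (using `-inf` where the code uses `torch.finfo(dtype).min`, and assuming `cu_seqlens` are cumulative lengths with a leading zero, e.g. two images contributing 3 and 2 tokens):

```python
import torch

cu_seqlens = torch.tensor([0, 3, 5])   # assumed cumulative patch counts per image
seq_length = int(cu_seqlens[-1])

mask = torch.full([1, seq_length, seq_length], float("-inf"))
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0.0

# Block-diagonal result: zeros inside each image's block, -inf across images,
# so softmax(q @ k^T + mask) never mixes tokens from different images.
print(mask[0])
```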
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class BatchVisionAttention(nn.Module):
|
| 359 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 360 |
+
super().__init__()
|
| 361 |
+
self.num_heads = num_heads
|
| 362 |
+
self.head_dim = dim // num_heads
|
| 363 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 364 |
+
self.proj = nn.Linear(dim, dim)
|
| 365 |
+
|
| 366 |
+
def forward(
|
| 367 |
+
self,
|
| 368 |
+
hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
|
| 369 |
+
attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
|
| 370 |
+
rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
|
| 371 |
+
) -> torch.Tensor:
|
| 372 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 373 |
+
|
| 374 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0,
|
| 375 |
+
3, 1,
|
| 376 |
+
4).unbind(
|
| 377 |
+
0)
|
| 378 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
| 379 |
+
|
| 380 |
+
if rotary_pos_emb is not None:
|
| 381 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 382 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 383 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 384 |
+
|
| 385 |
+
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 386 |
+
if attention_mask is not None:
|
| 387 |
+
attn_weights = attn_weights + attention_mask
|
| 388 |
+
|
| 389 |
+
# Softmax
|
| 390 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 391 |
+
|
| 392 |
+
attn_output = torch.matmul(attn_weights, v) # [batch_size, num_heads, seq_len, head_dim]
|
| 393 |
+
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
|
| 394 |
+
return self.proj(attn_output)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class VisionXformerAttention(nn.Module):
|
| 398 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 399 |
+
super().__init__()
|
| 400 |
+
self.num_heads = num_heads
|
| 401 |
+
self.head_dim = dim // num_heads
|
| 402 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 403 |
+
self.proj = nn.Linear(dim, dim)
|
| 404 |
+
|
| 405 |
+
def forward(
|
| 406 |
+
self,
|
| 407 |
+
hidden_states: torch.Tensor,
|
| 408 |
+
cu_seqlens: torch.Tensor,
|
| 409 |
+
rotary_pos_emb: torch.Tensor = None
|
| 410 |
+
) -> torch.Tensor:
|
| 411 |
+
seq_length = hidden_states.shape[0]
|
| 412 |
+
|
| 413 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 414 |
+
|
| 415 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
|
| 416 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
|
| 417 |
+
|
| 418 |
+
seqlens = [cu_seqlens[0]] + [cu_seqlens[i] - cu_seqlens[i - 1] for i in range(1, len(cu_seqlens))]
|
| 419 |
+
attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 420 |
+
|
| 421 |
+
attn_output = xops.memory_efficient_attention(
|
| 422 |
+
q, k, v.unsqueeze(0),
|
| 423 |
+
attn_bias=attn_bias,
|
| 424 |
+
scale=1.0 / math.sqrt(self.head_dim)
|
| 425 |
+
)
|
| 426 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 427 |
+
attn_output = self.proj(attn_output)
|
| 428 |
+
return attn_output
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class BatchVisionXformerAttention(nn.Module):
|
| 432 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 433 |
+
super().__init__()
|
| 434 |
+
self.num_heads = num_heads
|
| 435 |
+
self.head_dim = dim // num_heads
|
| 436 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 437 |
+
self.proj = nn.Linear(dim, dim)
|
| 438 |
+
|
| 439 |
+
def forward(
|
| 440 |
+
self,
|
| 441 |
+
hidden_states: torch.Tensor,
|
| 442 |
+
attention_mask: torch.Tensor, # [batch_size, 1, 1, seq_len]
|
| 443 |
+
rotary_pos_emb: torch.Tensor = None
|
| 444 |
+
) -> torch.Tensor:
|
| 445 |
+
seq_length = hidden_states.shape[0]
|
| 446 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 447 |
+
|
| 448 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0,
|
| 449 |
+
3, 1,
|
| 450 |
+
4).unbind(
|
| 451 |
+
0)
|
| 452 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
| 453 |
+
|
| 454 |
+
if rotary_pos_emb is not None:
|
| 455 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 456 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 457 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 458 |
+
|
| 459 |
+
attn_output = xops.memory_efficient_attention(
|
| 460 |
+
q, k, v,
|
| 461 |
+
attn_bias=attention_mask,
|
| 462 |
+
scale=1.0 / math.sqrt(self.head_dim)
|
| 463 |
+
)
|
| 464 |
+
attn_output = attn_output.reshape(batch_size, seq_len, -1)
|
| 465 |
+
return self.proj(attn_output)
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class VisionFlashAttention2(nn.Module):
|
| 469 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 470 |
+
super().__init__()
|
| 471 |
+
self.num_heads = num_heads
|
| 472 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 473 |
+
self.proj = nn.Linear(dim, dim)
|
| 474 |
+
|
| 475 |
+
def forward(
|
| 476 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 477 |
+
) -> torch.Tensor:
|
| 478 |
+
seq_length = hidden_states.shape[0]
|
| 479 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 480 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 481 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 482 |
+
|
| 483 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 484 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 485 |
+
seq_length, -1
|
| 486 |
+
)
|
| 487 |
+
attn_output = self.proj(attn_output)
|
| 488 |
+
return attn_output
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class BatchVisionFlashAttention2(nn.Module):
|
| 492 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 493 |
+
super().__init__()
|
| 494 |
+
self.num_heads = num_heads
|
| 495 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 496 |
+
self.proj = nn.Linear(dim, dim)
|
| 497 |
+
|
| 498 |
+
def forward(
|
| 499 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 500 |
+
) -> torch.Tensor:
|
| 501 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 502 |
+
|
| 503 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
|
| 505 |
+
|
| 506 |
+
if rotary_pos_emb is not None:
|
| 507 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 508 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 509 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 510 |
+
|
| 511 |
+
q = rearrange(q, 'b h l d -> b l h d')
|
| 512 |
+
k = rearrange(k, 'b h l d -> b l h d')
|
| 513 |
+
v = rearrange(v, 'b h l d -> b l h d')
|
| 514 |
+
|
| 515 |
+
attn_output = _flash_attention_forward(q, k, v).reshape(batch_size, seq_len, -1)
|
| 516 |
+
attn_output = self.proj(attn_output)
|
| 517 |
+
return attn_output
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class VisionSdpaAttention(nn.Module):
|
| 521 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 522 |
+
super().__init__()
|
| 523 |
+
self.num_heads = num_heads
|
| 524 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 525 |
+
self.proj = nn.Linear(dim, dim)
|
| 526 |
+
|
| 527 |
+
def forward(
|
| 528 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 529 |
+
) -> torch.Tensor:
|
| 530 |
+
seq_length = hidden_states.shape[0]
|
| 531 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 532 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 533 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 534 |
+
|
| 535 |
+
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
| 536 |
+
for i in range(1, len(cu_seqlens)):
|
| 537 |
+
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
|
| 538 |
+
q = q.transpose(0, 1)
|
| 539 |
+
k = k.transpose(0, 1)
|
| 540 |
+
v = v.transpose(0, 1)
|
| 541 |
+
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
| 542 |
+
attn_output = attn_output.transpose(0, 1)
|
| 543 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 544 |
+
attn_output = self.proj(attn_output)
|
| 545 |
+
return attn_output
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class BatchVisionSdpaAttention(nn.Module):
|
| 549 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 550 |
+
super().__init__()
|
| 551 |
+
self.num_heads = num_heads
|
| 552 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 553 |
+
self.proj = nn.Linear(dim, dim)
|
| 554 |
+
|
| 555 |
+
def forward(
|
| 556 |
+
self,
|
| 557 |
+
hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
|
| 558 |
+
attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
|
| 559 |
+
rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
|
| 560 |
+
) -> torch.Tensor:
|
| 561 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 562 |
+
q, k, v = self.qkv(hidden_states).reshape(batch_size, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
|
| 564 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
| 565 |
+
|
| 566 |
+
if rotary_pos_emb is not None:
|
| 567 |
+
rotary_pos_emb = rotary_pos_emb.unsqueeze(1) # [batch_size, 1, seq_len, head_dim//2]
|
| 568 |
+
q = apply_rotary_pos_emb_vision_batch(q, rotary_pos_emb)
|
| 569 |
+
k = apply_rotary_pos_emb_vision_batch(k, rotary_pos_emb)
|
| 570 |
+
|
| 571 |
+
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
| 572 |
+
attn_output = attn_output.transpose(1, 2)
|
| 573 |
+
attn_output = attn_output.reshape(batch_size, seq_len, -1)
|
| 574 |
+
attn_output = self.proj(attn_output)
|
| 575 |
+
return attn_output
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
QWEN2_VL_VISION_ATTENTION_CLASSES = {
|
| 579 |
+
"eager": VisionAttention,
|
| 580 |
+
"flash_attention_2": VisionFlashAttention2,
|
| 581 |
+
"sdpa": VisionSdpaAttention,
|
| 582 |
+
"xformers": VisionXformerAttention,
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES = {
|
| 586 |
+
"eager": BatchVisionAttention,
|
| 587 |
+
"flash_attention_2": VisionFlashAttention2,
|
| 588 |
+
"sdpa": BatchVisionSdpaAttention,
|
| 589 |
+
}
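# Illustrative selection sketch (not part of the original file; the dim/num_heads sizes are
# placeholders): Qwen2VLVisionBlock below resolves its attention class from these mappings
# using the `attn_implementation` string, which defaults to "sdpa".
_example_attn_cls = QWEN2_VL_VISION_ATTENTION_CLASSES["sdpa"]  # -> VisionSdpaAttention
_example_attn = _example_attn_cls(1280, num_heads=16)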
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class Qwen2VLVisionBlock(nn.Module):
|
| 593 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
| 594 |
+
super().__init__()
|
| 595 |
+
|
| 596 |
+
self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 597 |
+
self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 598 |
+
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
|
| 599 |
+
|
| 600 |
+
self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
|
| 601 |
+
config.embed_dim, num_heads=config.num_heads,
|
| 602 |
+
)
|
| 603 |
+
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
|
| 604 |
+
|
| 605 |
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb, grid_thw) -> torch.Tensor:
|
| 606 |
+
hidden_states = hidden_states + self.attn(
|
| 607 |
+
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
| 608 |
+
)
|
| 609 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 610 |
+
return hidden_states
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
class Qwen2VLBatchVisionBlock(nn.Module):
|
| 614 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
| 615 |
+
super().__init__()
|
| 616 |
+
self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 617 |
+
self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
|
| 618 |
+
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
|
| 619 |
+
|
| 620 |
+
self.attn = QWEN2_VL_VISION_BATCH_ATTENTION_CLASSES[attn_implementation](
|
| 621 |
+
config.embed_dim, num_heads=config.num_heads,
|
| 622 |
+
)
|
| 623 |
+
self.mlp = VisionMlp(config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
|
| 624 |
+
|
| 625 |
+
def forward(
|
| 626 |
+
self,
|
| 627 |
+
hidden_states: torch.Tensor, # [batch_size, seq_len, dim]
|
| 628 |
+
attention_mask: torch.Tensor = None, # [batch_size, 1, 1, seq_len]
|
| 629 |
+
rotary_pos_emb: torch.Tensor = None # [batch_size, seq_len, head_dim//2]
|
| 630 |
+
) -> torch.Tensor:
|
| 631 |
+
# Attention
|
| 632 |
+
hidden_states = hidden_states + self.attn(
|
| 633 |
+
self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
|
| 634 |
+
)
|
| 635 |
+
# MLP
|
| 636 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 637 |
+
return hidden_states
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
| 641 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 642 |
+
attention_mask: torch.Tensor,
|
| 643 |
+
sequence_length: int,
|
| 644 |
+
target_length: int,
|
| 645 |
+
dtype: torch.dtype,
|
| 646 |
+
device: torch.device,
|
| 647 |
+
min_dtype: float,
|
| 648 |
+
cache_position: torch.Tensor,
|
| 649 |
+
batch_size: int,
|
| 650 |
+
):
|
| 651 |
+
"""
|
| 652 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 653 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 654 |
+
|
| 655 |
+
Args:
|
| 656 |
+
attention_mask (`torch.Tensor`):
|
| 657 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
| 658 |
+
sequence_length (`int`):
|
| 659 |
+
The sequence length being processed.
|
| 660 |
+
target_length (`int`):
|
| 661 |
+
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
| 662 |
+
dtype (`torch.dtype`):
|
| 663 |
+
The dtype to use for the 4D attention mask.
|
| 664 |
+
device (`torch.device`):
|
| 665 |
+
The device to place the 4D attention mask on.
|
| 666 |
+
min_dtype (`float`):
|
| 667 |
+
The minimum value representable with the dtype `dtype`.
|
| 668 |
+
cache_position (`torch.Tensor`):
|
| 669 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 670 |
+
batch_size (`torch.Tensor`):
|
| 671 |
+
Batch size.
|
| 672 |
+
"""
|
| 673 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 674 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 675 |
+
causal_mask = attention_mask
|
| 676 |
+
else:
|
| 677 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
| 678 |
+
if sequence_length != 1:
|
| 679 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 680 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 681 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 682 |
+
if attention_mask is not None:
|
| 683 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 684 |
+
mask_length = attention_mask.shape[-1]
|
| 685 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
| 686 |
+
padding_mask = padding_mask == 0
|
| 687 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 688 |
+
padding_mask, min_dtype
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
return causal_mask
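# Usage sketch (values and variable names are illustrative, not from the original file):
# expand a 2D padding mask into the 4D additive mask built by the helper above.
_example_mask_2d = torch.tensor([[1, 1, 1, 0]])  # one sequence of length 4, last position padded
_example_causal_4d = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=_example_mask_2d,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(torch.float32).min,
    cache_position=torch.arange(4),
    batch_size=1,
)
# _example_causal_4d has shape (1, 1, 4, 4); allowed positions hold 0.0, masked ones hold min_dtype.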
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
|
| 695 |
+
class Qwen2RMSNorm(nn.Module):
|
| 696 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 697 |
+
"""
|
| 698 |
+
Qwen2RMSNorm is equivalent to T5LayerNorm
|
| 699 |
+
"""
|
| 700 |
+
super().__init__()
|
| 701 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 702 |
+
self.variance_epsilon = eps
|
| 703 |
+
|
| 704 |
+
def forward(self, hidden_states):
|
| 705 |
+
input_dtype = hidden_states.dtype
|
| 706 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 707 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 708 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 709 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 710 |
+
|
| 711 |
+
def extra_repr(self):
|
| 712 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
|
| 716 |
+
class Qwen2MLP(nn.Module):
|
| 717 |
+
def __init__(self, config):
|
| 718 |
+
super().__init__()
|
| 719 |
+
self.hidden_size = config.hidden_size
|
| 720 |
+
self.intermediate_size = config.intermediate_size
|
| 721 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 722 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 723 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 724 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 725 |
+
|
| 726 |
+
def forward(self, hidden_state):
|
| 727 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 731 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 732 |
+
"""
|
| 733 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 734 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 735 |
+
"""
|
| 736 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 737 |
+
if n_rep == 1:
|
| 738 |
+
return hidden_states
|
| 739 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 740 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
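# Quick equivalence check (shapes are illustrative, not from the original file): repeat_kv
# reproduces torch.repeat_interleave along the key/value head dimension.
_example_kv = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(_example_kv, 3), torch.repeat_interleave(_example_kv, repeats=3, dim=1))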
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class Qwen2VLPreTrainedModel(PreTrainedModel):
|
| 744 |
+
config_class = Qwen2VLConfig
|
| 745 |
+
base_model_prefix = "model"
|
| 746 |
+
supports_gradient_checkpointing = True
|
| 747 |
+
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
|
| 748 |
+
_skip_keys_device_placement = "past_key_values"
|
| 749 |
+
_supports_flash_attn_2 = True
|
| 750 |
+
_supports_sdpa = True
|
| 751 |
+
_supports_cache_class = True
|
| 752 |
+
_supports_static_cache = True
|
| 753 |
+
|
| 754 |
+
def _init_weights(self, module):
|
| 755 |
+
std = self.config.initializer_range
|
| 756 |
+
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
| 757 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 758 |
+
if module.bias is not None:
|
| 759 |
+
module.bias.data.zero_()
|
| 760 |
+
elif isinstance(module, nn.Embedding):
|
| 761 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 762 |
+
if module.padding_idx is not None:
|
| 763 |
+
module.weight.data[module.padding_idx].zero_()
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
| 767 |
+
config_class = Qwen2VLVisionConfig
|
| 768 |
+
_no_split_modules = ["Qwen2VLVisionBlock"]
|
| 769 |
+
|
| 770 |
+
def __init__(self, config) -> None:
|
| 771 |
+
super().__init__(config)
|
| 772 |
+
self.spatial_merge_size = config.spatial_merge_size
|
| 773 |
+
|
| 774 |
+
self.patch_embed = PatchEmbed(
|
| 775 |
+
patch_size=config.patch_size,
|
| 776 |
+
temporal_patch_size=config.temporal_patch_size,
|
| 777 |
+
in_channels=config.in_channels,
|
| 778 |
+
embed_dim=config.embed_dim,
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
head_dim = config.embed_dim // config.num_heads
|
| 782 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 783 |
+
|
| 784 |
+
self.blocks = nn.ModuleList(
|
| 785 |
+
[Qwen2VLVisionBlock(config, config.attn_implementation) for _ in range(config.depth)]
|
| 786 |
+
)
|
| 787 |
+
self.merger = PatchMerger(
|
| 788 |
+
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
def get_dtype(self) -> torch.dtype:
|
| 792 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
| 793 |
+
|
| 794 |
+
def get_device(self) -> torch.device:
|
| 795 |
+
return self.blocks[0].mlp.fc2.weight.device
|
| 796 |
+
|
| 797 |
+
def rot_pos_emb(self, grid_thw):
|
| 798 |
+
pos_ids = []
|
| 799 |
+
for t, h, w in grid_thw:
|
| 800 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 801 |
+
hpos_ids = hpos_ids.reshape(
|
| 802 |
+
h // self.spatial_merge_size,
|
| 803 |
+
self.spatial_merge_size,
|
| 804 |
+
w // self.spatial_merge_size,
|
| 805 |
+
self.spatial_merge_size,
|
| 806 |
+
)
|
| 807 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
| 808 |
+
hpos_ids = hpos_ids.flatten()
|
| 809 |
+
|
| 810 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 811 |
+
wpos_ids = wpos_ids.reshape(
|
| 812 |
+
h // self.spatial_merge_size,
|
| 813 |
+
self.spatial_merge_size,
|
| 814 |
+
w // self.spatial_merge_size,
|
| 815 |
+
self.spatial_merge_size,
|
| 816 |
+
)
|
| 817 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
| 818 |
+
wpos_ids = wpos_ids.flatten()
|
| 819 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 820 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 821 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 822 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 823 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
| 824 |
+
return rotary_pos_emb
|
| 825 |
+
|
| 826 |
+
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
| 827 |
+
hidden_states = self.patch_embed(hidden_states)
|
| 828 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
| 829 |
+
|
| 830 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
| 831 |
+
dim=0, dtype=torch.int32
|
| 832 |
+
)
|
| 833 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 834 |
+
|
| 835 |
+
for blk in self.blocks:
|
| 836 |
+
hidden_states = blk(hidden_states,
|
| 837 |
+
cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb,
|
| 838 |
+
grid_thw=grid_thw)
|
| 839 |
+
|
| 840 |
+
hidden_states = self.merger(hidden_states, grid_thw)
|
| 841 |
+
return hidden_states
|
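For reference, a minimal standalone sketch of how `cu_seqlens` is derived from `grid_thw` inside `Qwen2VisionTransformerPretrainedModel.forward` above (the patch grids below are illustrative, not taken from this repository):

import torch
import torch.nn.functional as F

# Two images with (t, h, w) patch grids of (1, 32, 32) and (1, 16, 24).
grid_thw = torch.tensor([[1, 32, 32], [1, 16, 24]])

# Same expression as in forward(): per-image token counts, cumulated and left-padded with 0,
# so that block-diagonal attention stays within each image.
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)  # -> tensor([0, 1024, 1408], dtype=torch.int32)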
modeling_rope_utils.py
ADDED
|
@@ -0,0 +1,561 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from typing import Optional, Tuple
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 20 |
+
from transformers.utils import is_torch_available, logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_torch_available():
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _compute_default_rope_parameters(
|
| 31 |
+
config: Optional[PretrainedConfig] = None,
|
| 32 |
+
device: Optional["torch.device"] = None,
|
| 33 |
+
seq_len: Optional[int] = None,
|
| 34 |
+
**rope_kwargs,
|
| 35 |
+
) -> Tuple["torch.Tensor", float]:
|
| 36 |
+
"""
|
| 37 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 38 |
+
Args:
|
| 39 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 40 |
+
The model configuration.
|
| 41 |
+
device (`torch.device`):
|
| 42 |
+
The device to use for initialization of the inverse frequencies.
|
| 43 |
+
seq_len (`int`, *optional*):
|
| 44 |
+
The current sequence length. Unused for this type of RoPE.
|
| 45 |
+
rope_kwargs (`Dict`, *optional*):
|
| 46 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 47 |
+
Returns:
|
| 48 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 49 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 50 |
+
"""
|
| 51 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 52 |
+
raise ValueError(
|
| 53 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 54 |
+
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 55 |
+
)
|
| 56 |
+
if len(rope_kwargs) > 0:
|
| 57 |
+
base = rope_kwargs["base"]
|
| 58 |
+
dim = rope_kwargs["dim"]
|
| 59 |
+
elif config is not None:
|
| 60 |
+
base = config.rope_theta
|
| 61 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 62 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 63 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 64 |
+
|
| 65 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 66 |
+
|
| 67 |
+
# Compute the inverse frequencies
|
| 68 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
| 69 |
+
return inv_freq, attention_factor
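# Standalone sketch (uses the backward-compatible kwargs path; numbers are illustrative):
# the "default" rule builds inv_freq[i] = 1 / base ** (2 * i / dim) and always returns a
# post-processing scaling factor of 1.0.
_example_inv_freq, _example_scaling = _compute_default_rope_parameters(base=10000.0, dim=128)
# _example_inv_freq has shape (64,); _example_scaling == 1.0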
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _compute_linear_scaling_rope_parameters(
|
| 73 |
+
config: Optional[PretrainedConfig] = None,
|
| 74 |
+
device: Optional["torch.device"] = None,
|
| 75 |
+
seq_len: Optional[int] = None,
|
| 76 |
+
**rope_kwargs,
|
| 77 |
+
) -> Tuple["torch.Tensor", float]:
|
| 78 |
+
"""
|
| 79 |
+
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
|
| 80 |
+
Args:
|
| 81 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 82 |
+
The model configuration.
|
| 83 |
+
device (`torch.device`):
|
| 84 |
+
The device to use for initialization of the inverse frequencies.
|
| 85 |
+
seq_len (`int`, *optional*):
|
| 86 |
+
The current sequence length. Unused for this type of RoPE.
|
| 87 |
+
rope_kwargs (`Dict`, *optional*):
|
| 88 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 89 |
+
Returns:
|
| 90 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 91 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 92 |
+
"""
|
| 93 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 94 |
+
raise ValueError(
|
| 95 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 96 |
+
f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 97 |
+
)
|
| 98 |
+
if len(rope_kwargs) > 0:
|
| 99 |
+
factor = rope_kwargs["factor"]
|
| 100 |
+
elif config is not None:
|
| 101 |
+
factor = config.rope_scaling["factor"]
|
| 102 |
+
|
| 103 |
+
# Gets the default RoPE parameters
|
| 104 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
| 105 |
+
|
| 106 |
+
# Then applies linear scaling to the frequencies.
|
| 107 |
+
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
|
| 108 |
+
# applying scaling to the inverse frequencies is equivalent.
|
| 109 |
+
inv_freq /= factor
|
| 110 |
+
return inv_freq, attention_factor
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _compute_dynamic_ntk_parameters(
|
| 114 |
+
config: Optional[PretrainedConfig] = None,
|
| 115 |
+
device: Optional["torch.device"] = None,
|
| 116 |
+
seq_len: Optional[int] = None,
|
| 117 |
+
**rope_kwargs,
|
| 118 |
+
) -> Tuple["torch.Tensor", float]:
|
| 119 |
+
"""
|
| 120 |
+
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
|
| 121 |
+
Args:
|
| 122 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 123 |
+
The model configuration.
|
| 124 |
+
device (`torch.device`):
|
| 125 |
+
The device to use for initialization of the inverse frequencies.
|
| 126 |
+
seq_len (`int`, *optional*):
|
| 127 |
+
The current sequence length, used to update the dynamic RoPE at inference time.
|
| 128 |
+
rope_kwargs (`Dict`, *optional*):
|
| 129 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 130 |
+
Returns:
|
| 131 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 132 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 133 |
+
"""
|
| 134 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
| 135 |
+
if config is not None and len(rope_kwargs) > 0:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
| 138 |
+
f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
| 139 |
+
)
|
| 140 |
+
if len(rope_kwargs) > 0:
|
| 141 |
+
base = rope_kwargs["base"]
|
| 142 |
+
dim = rope_kwargs["dim"]
|
| 143 |
+
max_position_embeddings = rope_kwargs["max_position_embeddings"]
|
| 144 |
+
factor = rope_kwargs["factor"]
|
| 145 |
+
elif config is not None:
|
| 146 |
+
base = config.rope_theta
|
| 147 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 148 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 149 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 150 |
+
max_position_embeddings = config.max_position_embeddings
|
| 151 |
+
factor = config.rope_scaling["factor"]
|
| 152 |
+
|
| 153 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 154 |
+
|
| 155 |
+
# seq_len: default to max_position_embeddings, e.g. at init time
|
| 156 |
+
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
| 157 |
+
|
| 158 |
+
# Compute the inverse frequencies
|
| 159 |
+
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
| 160 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
| 161 |
+
return inv_freq, attention_factor
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _compute_yarn_parameters(
|
| 165 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
| 166 |
+
) -> Tuple["torch.Tensor", float]:
|
| 167 |
+
"""
|
| 168 |
+
Computes the inverse frequencies with NTK scaling. Please refer to the
|
| 169 |
+
[original paper](https://arxiv.org/abs/2309.00071)
|
| 170 |
+
Args:
|
| 171 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 172 |
+
The model configuration.
|
| 173 |
+
device (`torch.device`):
|
| 174 |
+
The device to use for initialization of the inverse frequencies.
|
| 175 |
+
seq_len (`int`, *optional*):
|
| 176 |
+
The current sequence length. Unused for this type of RoPE.
|
| 177 |
+
rope_kwargs (`Dict`, *optional*):
|
| 178 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 179 |
+
Returns:
|
| 180 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 181 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 182 |
+
"""
|
| 183 |
+
# No need to keep BC with yarn, unreleased when this new pattern was created.
|
| 184 |
+
if len(rope_kwargs) > 0:
|
| 185 |
+
raise ValueError(
|
| 186 |
+
f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
base = config.rope_theta
|
| 190 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 191 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 192 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 193 |
+
max_position_embeddings = config.max_position_embeddings
|
| 194 |
+
factor = config.rope_scaling["factor"]
|
| 195 |
+
|
| 196 |
+
# Sets the attention factor as suggested in the paper
|
| 197 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
| 198 |
+
if attention_factor is None:
|
| 199 |
+
attention_factor = 0.1 * math.log(factor) + 1.0
|
| 200 |
+
|
| 201 |
+
# Optional config options
|
| 202 |
+
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
| 203 |
+
beta_fast = config.rope_scaling.get("beta_fast") or 32
|
| 204 |
+
beta_slow = config.rope_scaling.get("beta_slow") or 1
|
| 205 |
+
|
| 206 |
+
# Compute the inverse frequencies
|
| 207 |
+
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
|
| 208 |
+
"""Inverse dimension formula to find the dimension based on the number of rotations"""
|
| 209 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
| 210 |
+
|
| 211 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
|
| 212 |
+
"""Find dimension range bounds based on rotations"""
|
| 213 |
+
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
| 214 |
+
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
| 215 |
+
return max(low, 0), min(high, dim - 1)
|
| 216 |
+
|
| 217 |
+
def linear_ramp_factor(min, max, dim):
|
| 218 |
+
if min == max:
|
| 219 |
+
max += 0.001 # Prevent singularity
|
| 220 |
+
|
| 221 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
| 222 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 223 |
+
return ramp_func
|
| 224 |
+
|
| 225 |
+
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
| 226 |
+
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
| 227 |
+
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
| 228 |
+
inv_freq_extrapolation = 1.0 / pos_freqs
|
| 229 |
+
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
| 230 |
+
|
| 231 |
+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
| 232 |
+
|
| 233 |
+
# Get n-dimensional rotational scaling corrected for extrapolation
|
| 234 |
+
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
| 235 |
+
inv_freq = (
|
| 236 |
+
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
| 237 |
+
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
return inv_freq, attention_factor
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _compute_longrope_parameters(
|
| 244 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
| 245 |
+
) -> Tuple["torch.Tensor", float]:
|
| 246 |
+
"""
|
| 247 |
+
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
|
| 248 |
+
[original implementation](https://github.com/microsoft/LongRoPE)
|
| 249 |
+
Args:
|
| 250 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 251 |
+
The model configuration.
|
| 252 |
+
device (`torch.device`):
|
| 253 |
+
The device to use for initialization of the inverse frequencies.
|
| 254 |
+
seq_len (`int`, *optional*):
|
| 255 |
+
The current sequence length. Unused for this type of RoPE.
|
| 256 |
+
rope_kwargs (`Dict`, *optional*):
|
| 257 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 258 |
+
Returns:
|
| 259 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 260 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 261 |
+
"""
|
| 262 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
| 263 |
+
# No need to keep BC with longrope, unreleased when this new pattern was created.
|
| 264 |
+
if len(rope_kwargs) > 0:
|
| 265 |
+
raise ValueError(
|
| 266 |
+
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
|
| 267 |
+
f"{rope_kwargs}"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
base = config.rope_theta
|
| 271 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 272 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 273 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 274 |
+
long_factor = config.rope_scaling["long_factor"]
|
| 275 |
+
short_factor = config.rope_scaling["short_factor"]
|
| 276 |
+
factor = config.rope_scaling.get("factor")
|
| 277 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
| 278 |
+
|
| 279 |
+
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
|
| 280 |
+
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
| 281 |
+
# values to compute the default attention scaling factor, instead of using `factor`.
|
| 282 |
+
if hasattr(config, "original_max_position_embeddings"):
|
| 283 |
+
max_position_embeddings = config.original_max_position_embeddings
|
| 284 |
+
expanded_max_position_embeddings = config.max_position_embeddings
|
| 285 |
+
factor = expanded_max_position_embeddings / max_position_embeddings
|
| 286 |
+
else:
|
| 287 |
+
max_position_embeddings = config.max_position_embeddings
|
| 288 |
+
expanded_max_position_embeddings = max_position_embeddings * factor
|
| 289 |
+
|
| 290 |
+
# Sets the attention factor as suggested in the paper
|
| 291 |
+
if attention_factor is None:
|
| 292 |
+
if factor <= 1.0:
|
| 293 |
+
attention_factor = 1.0
|
| 294 |
+
else:
|
| 295 |
+
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
|
| 296 |
+
|
| 297 |
+
# Compute the inverse frequencies -- scaled based on the target sequence length
|
| 298 |
+
if expanded_max_position_embeddings > max_position_embeddings:
|
| 299 |
+
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
| 300 |
+
else:
|
| 301 |
+
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
| 302 |
+
inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
|
| 303 |
+
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
|
| 304 |
+
|
| 305 |
+
return inv_freq, attention_factor
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _compute_llama3_parameters(
|
| 309 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
| 310 |
+
) -> Tuple["torch.Tensor", float]:
|
| 311 |
+
"""
|
| 312 |
+
Computes the inverse frequencies for llama 3.1.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
config ([`~transformers.PretrainedConfig`]):
|
| 316 |
+
The model configuration.
|
| 317 |
+
device (`torch.device`):
|
| 318 |
+
The device to use for initialization of the inverse frequencies.
|
| 319 |
+
seq_len (`int`, *optional*):
|
| 320 |
+
The current sequence length. Unused for this type of RoPE.
|
| 321 |
+
rope_kwargs (`Dict`, *optional*):
|
| 322 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
| 323 |
+
Returns:
|
| 324 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 325 |
+
post-processing scaling factor applied to the computed cos/sin.
|
| 326 |
+
"""
|
| 327 |
+
# Gets the default RoPE parameters
|
| 328 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
| 329 |
+
|
| 330 |
+
factor = config.rope_scaling["factor"] # `8` in the original implementation
|
| 331 |
+
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
|
| 332 |
+
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
|
| 333 |
+
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
|
| 334 |
+
|
| 335 |
+
low_freq_wavelen = old_context_len / low_freq_factor
|
| 336 |
+
high_freq_wavelen = old_context_len / high_freq_factor
|
| 337 |
+
|
| 338 |
+
wavelen = 2 * math.pi / inv_freq
|
| 339 |
+
# wavelen < high_freq_wavelen: do nothing
|
| 340 |
+
# wavelen > low_freq_wavelen: divide by factor
|
| 341 |
+
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
| 342 |
+
# otherwise: interpolate between the two, using a smooth factor
|
| 343 |
+
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
| 344 |
+
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
| 345 |
+
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
| 346 |
+
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
| 347 |
+
|
| 348 |
+
return inv_freq_llama, attention_factor
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
| 352 |
+
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
|
| 353 |
+
# parameterizations, as long as the callable has the same signature.
|
| 354 |
+
ROPE_INIT_FUNCTIONS = {
|
| 355 |
+
"default": _compute_default_rope_parameters,
|
| 356 |
+
"linear": _compute_linear_scaling_rope_parameters,
|
| 357 |
+
"dynamic": _compute_dynamic_ntk_parameters,
|
| 358 |
+
"yarn": _compute_yarn_parameters,
|
| 359 |
+
"longrope": _compute_longrope_parameters,
|
| 360 |
+
"llama3": _compute_llama3_parameters,
|
| 361 |
+
}
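# Dispatch sketch (kwargs path, illustrative numbers; a real caller would typically pass
# `config=` instead): a rotary-embedding module looks up the initializer by rope_type and
# applies the returned attention scaling to its cos/sin outputs.
_example_rope_init_fn = ROPE_INIT_FUNCTIONS["linear"]
_example_scaled_inv_freq, _example_attention_scaling = _example_rope_init_fn(base=10000.0, dim=64, factor=2.0)
# _example_scaled_inv_freq is the default inv_freq divided by 2.0; _example_attention_scaling stays 1.0.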
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
| 365 |
+
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
| 366 |
+
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
| 367 |
+
if "type" in received_keys:
|
| 368 |
+
received_keys -= {"type"}
|
| 369 |
+
required_keys.add("rope_type")
|
| 370 |
+
|
| 371 |
+
missing_keys = required_keys - received_keys
|
| 372 |
+
if missing_keys:
|
| 373 |
+
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
|
| 374 |
+
|
| 375 |
+
if optional_keys is not None:
|
| 376 |
+
unused_keys = received_keys - required_keys - optional_keys
|
| 377 |
+
else:
|
| 378 |
+
unused_keys = received_keys - required_keys
|
| 379 |
+
if unused_keys:
|
| 380 |
+
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def _validate_default_rope_parameters(config: PretrainedConfig):
|
| 384 |
+
rope_scaling = config.rope_scaling
|
| 385 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 386 |
+
required_keys = {"rope_type"}
|
| 387 |
+
received_keys = set(rope_scaling.keys())
|
| 388 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
| 392 |
+
rope_scaling = config.rope_scaling
|
| 393 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 394 |
+
required_keys = {"rope_type", "factor"}
|
| 395 |
+
received_keys = set(rope_scaling.keys())
|
| 396 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
| 397 |
+
|
| 398 |
+
factor = rope_scaling["factor"]
|
| 399 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 400 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
| 404 |
+
rope_scaling = config.rope_scaling
|
| 405 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 406 |
+
required_keys = {"rope_type", "factor"}
|
| 407 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
| 408 |
+
optional_keys = {"original_max_position_embeddings"}
|
| 409 |
+
received_keys = set(rope_scaling.keys())
|
| 410 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
| 411 |
+
|
| 412 |
+
factor = rope_scaling["factor"]
|
| 413 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 414 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _validate_yarn_parameters(config: PretrainedConfig):
|
| 418 |
+
rope_scaling = config.rope_scaling
|
| 419 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 420 |
+
required_keys = {"rope_type", "factor"}
|
| 421 |
+
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
| 422 |
+
received_keys = set(rope_scaling.keys())
|
| 423 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
| 424 |
+
|
| 425 |
+
factor = rope_scaling["factor"]
|
| 426 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 427 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 428 |
+
|
| 429 |
+
attention_factor = rope_scaling.get("attention_factor")
|
| 430 |
+
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
|
| 431 |
+
logger.warning(
|
| 432 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
| 433 |
+
)
|
| 434 |
+
beta_fast = rope_scaling.get("beta_fast")
|
| 435 |
+
if beta_fast is not None and not isinstance(beta_fast, float):
|
| 436 |
+
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
|
| 437 |
+
beta_slow = rope_scaling.get("beta_slow")
|
| 438 |
+
if beta_slow is not None and not isinstance(beta_slow, float):
|
| 439 |
+
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
|
| 440 |
+
|
| 441 |
+
if (beta_fast or 32) < (beta_slow or 1):
|
| 442 |
+
logger.warning(
|
| 443 |
+
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
|
| 444 |
+
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _validate_longrope_parameters(config: PretrainedConfig):
|
| 449 |
+
rope_scaling = config.rope_scaling
|
| 450 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 451 |
+
required_keys = {"rope_type", "short_factor", "long_factor"}
|
| 452 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
| 453 |
+
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
| 454 |
+
received_keys = set(rope_scaling.keys())
|
| 455 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
| 456 |
+
|
| 457 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
| 458 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 459 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 460 |
+
|
| 461 |
+
short_factor = rope_scaling.get("short_factor")
|
| 462 |
+
if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
|
| 463 |
+
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
|
| 464 |
+
if not len(short_factor) == dim // 2:
|
| 465 |
+
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
|
| 466 |
+
|
| 467 |
+
long_factor = rope_scaling.get("long_factor")
|
| 468 |
+
if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
|
| 469 |
+
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
|
| 470 |
+
if not len(long_factor) == dim // 2:
|
| 471 |
+
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
|
| 472 |
+
|
| 473 |
+
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
|
| 474 |
+
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
|
| 475 |
+
# unique to longrope (= undesirable)
|
| 476 |
+
if hasattr(config, "original_max_position_embeddings"):
|
| 477 |
+
logger.warning_once(
|
| 478 |
+
"This model has set a `original_max_position_embeddings` field, to be used together with "
|
| 479 |
+
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
|
| 480 |
+
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
|
| 481 |
+
"as it is compatible with most model architectures."
|
| 482 |
+
)
|
| 483 |
+
else:
|
| 484 |
+
factor = rope_scaling.get("factor")
|
| 485 |
+
if factor is None:
|
| 486 |
+
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
|
| 487 |
+
elif not isinstance(factor, float) or factor < 1.0:
|
| 488 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 489 |
+
|
| 490 |
+
attention_factor = rope_scaling.get("attention_factor")
|
| 491 |
+
if attention_factor is not None:
|
| 492 |
+
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
| 493 |
+
logger.warning(
|
| 494 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def _validate_llama3_parameters(config: PretrainedConfig):
|
| 499 |
+
rope_scaling = config.rope_scaling
|
| 500 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
| 501 |
+
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
| 502 |
+
received_keys = set(rope_scaling.keys())
|
| 503 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
| 504 |
+
|
| 505 |
+
factor = rope_scaling["factor"]
|
| 506 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
| 507 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
| 508 |
+
|
| 509 |
+
low_freq_factor = rope_scaling["low_freq_factor"]
|
| 510 |
+
high_freq_factor = rope_scaling["high_freq_factor"]
|
| 511 |
+
if low_freq_factor is None or not isinstance(low_freq_factor, float):
|
| 512 |
+
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
| 513 |
+
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
| 514 |
+
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
| 515 |
+
if high_freq_factor <= low_freq_factor:
|
| 516 |
+
logger.warning(
|
| 517 |
+
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
| 518 |
+
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
|
| 522 |
+
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
|
| 523 |
+
logger.warning(
|
| 524 |
+
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
|
| 525 |
+
f"{original_max_position_embeddings}"
|
| 526 |
+
)
|
| 527 |
+
if original_max_position_embeddings >= config.max_position_embeddings:
|
| 528 |
+
logger.warning(
|
| 529 |
+
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
|
| 530 |
+
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
|
| 535 |
+
ROPE_VALIDATION_FUNCTIONS = {
|
| 536 |
+
"default": _validate_default_rope_parameters,
|
| 537 |
+
"linear": _validate_linear_scaling_rope_parameters,
|
| 538 |
+
"dynamic": _validate_dynamic_scaling_rope_parameters,
|
| 539 |
+
"yarn": _validate_yarn_parameters,
|
| 540 |
+
"longrope": _validate_longrope_parameters,
|
| 541 |
+
"llama3": _validate_llama3_parameters,
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def rope_config_validation(config: PretrainedConfig):
|
| 546 |
+
"""
|
| 547 |
+
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
| 548 |
+
"""
|
| 549 |
+
rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
|
| 550 |
+
if rope_scaling is None:
|
| 551 |
+
return
|
| 552 |
+
|
| 553 |
+
# BC: "rope_type" was originally "type"
|
| 554 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
| 555 |
+
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
| 556 |
+
if validation_fn is not None:
|
| 557 |
+
validation_fn(config)
|
| 558 |
+
else:
|
| 559 |
+
logger.warning(
|
| 560 |
+
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
| 561 |
+
)
|
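A hedged usage sketch for `rope_config_validation` (the config object and values below are ad-hoc illustrations; real models attach `rope_scaling` through their own config classes):

from transformers import PretrainedConfig

cfg = PretrainedConfig()
cfg.max_position_embeddings = 131072
cfg.rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
rope_config_validation(cfg)  # logs warnings for malformed fields instead of raising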
processing_illume.py
ADDED
|
@@ -0,0 +1,329 @@
| 1 |
+
"""
|
| 2 |
+
Processor class for ILLUME_plus with dualvitok and dualvitok-sdxl-decoder.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from typing import List, Union
|
| 7 |
+
|
| 8 |
+
from transformers import AutoProcessor, AutoImageProcessor
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from typing import Unpack
|
| 12 |
+
except ImportError:
|
| 13 |
+
from typing_extensions import Unpack
|
| 14 |
+
|
| 15 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 16 |
+
from .image_utils import ImageInput
|
| 17 |
+
from transformers.processing_utils import (
|
| 18 |
+
ProcessingKwargs,
|
| 19 |
+
ProcessorMixin,
|
| 20 |
+
)
|
| 21 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 22 |
+
from transformers.utils import logging
|
| 23 |
+
|
| 24 |
+
from PIL import Image
|
| 25 |
+
import re # Added for parsing image tokens
|
| 26 |
+
from typing import List, Tuple
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from .configuration_illume import ILLUMEConfig
|
| 30 |
+
from .image_processing_movqgan import MoVQImageProcessor
|
| 31 |
+
from .image_processing_dualvitok import DualViTokImageProcessor
|
| 32 |
+
from .aspect_ratio_utils import AspectRatioCrop, RATIOS, unpad_and_resize_back
|
| 33 |
+
from .inference_utils import parse_interleaved_text_image, calculate_image_token_num
|
| 34 |
+
from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ILLUMEProcessorKwargs(ProcessingKwargs, total=False):
|
| 40 |
+
_defaults = {
|
| 41 |
+
"text_kwargs": {
|
| 42 |
+
"padding": False,
|
| 43 |
+
},
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ILLUMEProcessor(ProcessorMixin):
|
| 48 |
+
r"""
|
| 49 |
+
Constructs an ILLUME processor which wraps an image processor and a Qwen2 tokenizer into a single processor.
|
| 50 |
+
[`ILLUMEProcessor`] offers all the functionalities of [`ILLUMEImageProcessor`] and [`Qwen2TokenizerFast`]. See the
|
| 51 |
+
[`~ILLUMEProcessor.__call__`] and [`~ILLUMEProcessor.decode`] for more information.
|
| 52 |
+
Args:
|
| 53 |
+
image_processor ([`IllumeImageProcessor`], *optional*):
|
| 54 |
+
The image processor is a required input.
|
| 55 |
+
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
| 56 |
+
The tokenizer is a required input.
|
| 57 |
+
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
| 58 |
+
in a chat into a tokenizable string.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
attributes = ["image_processor", "tokenizer"]
|
| 62 |
+
valid_kwargs = ["chat_template"]
|
| 63 |
+
image_processor_class = "AutoImageProcessor"
|
| 64 |
+
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
| 65 |
+
_re_placeholder = re.compile(r"<image_out>|<image>")
|
| 66 |
+
|
| 67 |
+
def __init__(self, image_processor=None, tokenizer=None, chat_template=None,
|
| 68 |
+
crop_percent_thresh=0.2, anyres_indicator_base=64, **kwargs):
|
| 69 |
+
super().__init__(image_processor=image_processor, tokenizer=tokenizer, chat_template=chat_template)
|
| 70 |
+
self.vision_tokenizer = None
|
| 71 |
+
self.diffusion_vision_detokenizer = None
|
| 72 |
+
self.crop_percent_thresh = crop_percent_thresh
|
| 73 |
+
self.anyres_indicator_base = anyres_indicator_base
|
| 74 |
+
|
| 75 |
+
def set_vision_tokenizer(self, tokenizer):
|
| 76 |
+
if self.vision_tokenizer and tokenizer:
|
| 77 |
+
logger.info('You are resetting vision tokenizer!')
|
| 78 |
+
return
|
| 79 |
+
self.vision_tokenizer = tokenizer
|
| 80 |
+
logger.info('Setting vision tokenizer!')
|
| 81 |
+
|
| 82 |
+
def load_diffusion_vision_detokenizer(self, diffusion_decoder,
|
| 83 |
+
torch_dtype=torch.float16,
|
| 84 |
+
add_watermarker=False,
|
| 85 |
+
device='cuda',
|
| 86 |
+
):
|
| 87 |
+
if self.diffusion_vision_detokenizer:
|
| 88 |
+
logger.info('You are resetting diffusion vision detokenizer!')
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
if self.vision_tokenizer is None:
|
| 92 |
+
raise ValueError("Vision tokenizer is not set. Please set the vision tokenizer by using `processor.set_vision_tokenizer`")
|
| 93 |
+
|
| 94 |
+
self.diffusion_vision_detokenizer = StableDiffusionXLDecoderPipeline.from_pretrained(diffusion_decoder,
|
| 95 |
+
torch_dtype=torch_dtype,
|
| 96 |
+
add_watermarker=add_watermarker,
|
| 97 |
+
vq_model=self.vision_tokenizer,
|
| 98 |
+
vq_image_processor=self.image_processor).to(device)
|
| 99 |
+
logger.info('Setting diffusion vision detokenizer!')
|
| 100 |
+
|
| 101 |
+
def get_ratio_tag_from_ratio(self, ratio):
|
| 102 |
+
h, w = ratio
|
| 103 |
+
h_indicator, w_indicator = h // self.anyres_indicator_base, w // self.anyres_indicator_base
|
| 104 |
+
ratio_tag = f"<height_{h_indicator}><width_{w_indicator}>"
|
| 105 |
+
return ratio_tag
|
| 106 |
+
|
| 107 |
+
@torch.no_grad()
|
| 108 |
+
def _encode_with_dualvitok(self, img):
|
| 109 |
+
# img is a PIL image or np.ndarray
|
| 110 |
+
px = self.image_processor(img, return_tensors="pt")["pixel_values"].to(self.vision_tokenizer.device)
|
| 111 |
+
(_, _, idx_sem, _), (_, _, idx_pix) = self.vision_tokenizer.encode(px)
|
| 112 |
+
return idx_sem[0].cpu().tolist(), idx_pix[0].cpu().tolist()
|
| 113 |
+
|
| 114 |
+
def transform_image_nearest_resolution_ratio(self, image, ratios=RATIOS):
|
| 115 |
+
arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
|
| 116 |
+
image, original_size, target_size, flag_matched = arc(image, is_inference=True)
|
| 117 |
+
return image
|
| 118 |
+
|
| 119 |
+
def convert_image_to_token_string(self, image, ratios=RATIOS):
|
| 120 |
+
arc = AspectRatioCrop(ratios, crop_percent_thresh=self.crop_percent_thresh)
|
| 121 |
+
image, original_size, target_size, flag_matched = arc(image, is_inference=True)
|
| 122 |
+
ratio_tag = self.get_ratio_tag_from_ratio(target_size)
|
| 123 |
+
|
| 124 |
+
image_embed_inds = self._encode_with_dualvitok(image)
|
| 125 |
+
return ratio_tag + self.encode_image_token_into_code(image_embed_inds)
|
| 126 |
+
|
| 127 |
+
def unpad_and_resize_back(self, padded_image, original_width, original_height):
|
| 128 |
+
return unpad_and_resize_back(padded_image, original_width, original_height)
|
| 129 |
+
|
| 130 |
+
def encode_image_token_into_code(self, image_embed_inds,
|
| 131 |
+
add_token_name="<|image_level{}_{}|>",
|
| 132 |
+
selected_vision_tokenizer_levels=None):
|
| 133 |
+
'''
|
| 134 |
+
Args:
|
| 135 |
+
image_embed_inds: 3D list, vision token ids for each tokenizer level
|
| 136 |
+
add_token_name: tag name for vision tokens
|
| 137 |
+
Returns:
|
| 138 |
+
image_token_return: str
|
| 139 |
+
'''
|
| 140 |
+
|
| 141 |
+
if selected_vision_tokenizer_levels is not None:
|
| 142 |
+
image_embed_inds_new = []
|
| 143 |
+
for level in selected_vision_tokenizer_levels:
|
| 144 |
+
image_embed_inds_new.append(image_embed_inds[level])
|
| 145 |
+
image_embed_inds = image_embed_inds_new
|
| 146 |
+
|
| 147 |
+
image_token_name_list = []
|
| 148 |
+
for level, image_embed_ind in enumerate(image_embed_inds):
|
| 149 |
+
image_token_name = []
|
| 150 |
+
for row in image_embed_ind:
|
| 151 |
+
image_token_name.append([add_token_name.format(level, ind) for ind in row])
|
| 152 |
+
|
| 153 |
+
image_token_name_list.append("<start_of_level{}>".format(level))
|
| 154 |
+
for row in image_token_name:
|
| 155 |
+
row.append("<end_of_line>")
|
| 156 |
+
|
| 157 |
+
for row in image_token_name:
|
| 158 |
+
image_token_name_list.extend(row)
|
| 159 |
+
|
| 160 |
+
image_token_name_list.append("<end_of_level{}>".format(level))
|
| 161 |
+
|
| 162 |
+
image_token_return = "".join(image_token_name_list)
|
| 163 |
+
image_token_return = "<start_of_image>" + image_token_return + "<end_of_image>"
|
| 164 |
+
return image_token_return
|
| 165 |
+
|
| 166 |
+
@torch.no_grad()
|
| 167 |
+
def decode_images(self, image_inds_list, target_resolution=(512, 512), return_type='pil',
|
| 168 |
+
use_diffusion=False, diffusion_cfg_scale=2.0, diffusion_num_inference_steps=20, **kwargs):
|
| 169 |
+
|
| 170 |
+
token_nums, _, h1, w1, h2, w2 = calculate_image_token_num(*target_resolution)
|
| 171 |
+
|
| 172 |
+
decoded_images = []
|
| 173 |
+
for image_inds in image_inds_list:
|
| 174 |
+
semantic_code = torch.as_tensor([image_inds[0]])
|
| 175 |
+
texture_code = torch.as_tensor([image_inds[1]])
|
| 176 |
+
if use_diffusion:
|
| 177 |
+
if self.diffusion_vision_detokenizer is None:
|
| 178 |
+
raise RuntimeError(
|
| 179 |
+
"diffusion_vision_detokenizer is not set. Please set the diffusion decoder by using `pipe.load_diffusion_vision_detokenizer`")
|
| 180 |
+
|
| 181 |
+
semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
|
| 182 |
+
texture_code = texture_code.view(texture_code.shape[0], h2, w2)
|
| 183 |
+
|
| 184 |
+
diffusion_outputs = self.diffusion_vision_detokenizer(
|
| 185 |
+
vq_indices=(semantic_code, texture_code),
|
| 186 |
+
height=target_resolution[0] * 2,
|
| 187 |
+
width=target_resolution[1] * 2,
|
| 188 |
+
guidance_scale=diffusion_cfg_scale,
|
| 189 |
+
num_inference_steps=diffusion_num_inference_steps,
|
| 190 |
+
output_type=return_type,
|
| 191 |
+
**kwargs
|
| 192 |
+
)
|
| 193 |
+
samples = diffusion_outputs.images
|
| 194 |
+
image = samples[0]
|
| 195 |
+
else:
|
| 196 |
+
if self.vision_tokenizer is None:
|
| 197 |
+
raise RuntimeError(
|
| 198 |
+
"vision_detokenizer is not set. Please set the vision decoder by using `pipe.set_vision_detokenizer`")
|
| 199 |
+
|
| 200 |
+
semantic_code = semantic_code.view(semantic_code.shape[0], h1, w1)
|
| 201 |
+
texture_code = texture_code.view(texture_code.shape[0], h2, w2)
|
| 202 |
+
|
| 203 |
+
samples = self.vision_tokenizer.decode_code(semantic_code, texture_code)
|
| 204 |
+
|
| 205 |
+
if return_type == 'pil':
|
| 206 |
+
sample = \
|
| 207 |
+
torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).cpu().to(torch.uint8).numpy()[0]
|
| 208 |
+
image = Image.fromarray(sample)
|
| 209 |
+
else: # return numpy range -1 to 1.
|
| 210 |
+
image = samples.permute(0, 2, 3, 1).cpu().numpy()[0]
|
| 211 |
+
decoded_images.append(image)
|
| 212 |
+
|
| 213 |
+
return decoded_images
|
| 214 |
+
|
| 215 |
+
def parse_text_image(self, text, image_placeholder='<image_out>'):
|
| 216 |
+
generated_text, image_embed_inds_list, list_image_token_parts = parse_interleaved_text_image(text, num_levels=2,
|
| 217 |
+
image_placeholder=image_placeholder)
|
| 218 |
+
return generated_text, image_embed_inds_list, list_image_token_parts
|
| 219 |
+
|
| 220 |
+
def _encode_out_placeholder(self, img):
|
| 221 |
+
"""
|
| 222 |
+
Encode one image with DualViTok and return a string
|
| 223 |
+
that can replace the <image_out> marker in the text.
|
| 224 |
+
"""
|
| 225 |
+
return self.convert_image_to_token_string(img)
|
| 226 |
+
|
| 227 |
+
def __call__(
|
| 228 |
+
self,
|
| 229 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 230 |
+
images: ImageInput = None,
|
| 231 |
+
**kwargs: Unpack[ILLUMEProcessorKwargs],
|
| 232 |
+
) -> BatchFeature:
|
| 233 |
+
"""
|
| 234 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
| 235 |
+
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
| 236 |
+
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
|
| 237 |
+
DualViTokImageProcessor's [`~DualViTokImageProcessor.__call__`] if `vision_infos` is not `None`.
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 241 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 242 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 243 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
| 244 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 245 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 246 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 247 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 248 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 249 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 250 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 251 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 252 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 256 |
+
|
| 257 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 258 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 259 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 260 |
+
`None`).
|
| 261 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 262 |
+
"""
|
| 263 |
+
output_kwargs = self._merge_kwargs(
|
| 264 |
+
ILLUMEProcessorKwargs,
|
| 265 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 266 |
+
**kwargs,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
if not isinstance(text, list):
|
| 270 |
+
text = [text]
|
| 271 |
+
|
| 272 |
+
if isinstance(images, str):
|
| 273 |
+
images = [images]
|
| 274 |
+
elif images and isinstance(images[0], list):
|
| 275 |
+
# flatten List[List[PIL.Image.Image]]
|
| 276 |
+
images = [item for sublist in images for item in sublist]
|
| 277 |
+
|
| 278 |
+
_ = output_kwargs["text_kwargs"].pop("padding_side", None)
|
| 279 |
+
try:
|
| 280 |
+
text = self.apply_chat_template(text, add_generation_prompt=True, padding=True)
|
| 281 |
+
except Exception as e:
|
| 282 |
+
logger.info('Warning: input texts have been applied chat templates!')
|
| 283 |
+
|
| 284 |
+
if images is None:
|
| 285 |
+
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
| 286 |
+
return BatchFeature(data={**text_inputs})
|
| 287 |
+
else:
|
| 288 |
+
imgs_in, new_text, used = [], [], 0
|
| 289 |
+
|
| 290 |
+
if not isinstance(text, list):
|
| 291 |
+
text = [text]
|
| 292 |
+
|
| 293 |
+
for s in text: # walk each prompt
|
| 294 |
+
out, i = [], 0
|
| 295 |
+
for m in self._re_placeholder.finditer(s): # every placeholder
|
| 296 |
+
out.append(s[i:m.start()])
|
| 297 |
+
if used >= len(images):
|
| 298 |
+
raise ValueError("not enough images for placeholders")
|
| 299 |
+
img = images[used]
|
| 300 |
+
used += 1
|
| 301 |
+
if m.group() == "<image_out>":
|
| 302 |
+
out.append(self.convert_image_to_token_string(img)) # replace
|
| 303 |
+
else: # <image>
|
| 304 |
+
out.append("<image>")
|
| 305 |
+
imgs_in.append(img) # keep for pixel feats
|
| 306 |
+
i = m.end()
|
| 307 |
+
out.append(s[i:])
|
| 308 |
+
new_text.append("".join(out))
|
| 309 |
+
|
| 310 |
+
if used != len(images):
|
| 311 |
+
raise ValueError(f"too many images for placeholders. used {used} vs len(images) {len(images)}. {text}")
|
| 312 |
+
|
| 313 |
+
text_inputs = self.tokenizer(new_text, **output_kwargs["text_kwargs"])
|
| 314 |
+
image_inputs = self.image_processor.preprocess(imgs_in, **output_kwargs["images_kwargs"]) if imgs_in else {}
|
| 315 |
+
|
| 316 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|
| 317 |
+
|
| 318 |
+
def batch_decode(self, sequences, *args, **kwargs):
|
| 319 |
+
return [self.decode(seq, *args, **kwargs)
|
| 320 |
+
for i, seq in enumerate(sequences)]
|
| 321 |
+
|
| 322 |
+
def decode(self, *args, **kwargs):
|
| 323 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 324 |
+
|
| 325 |
+
@property
|
| 326 |
+
def model_input_names(self):
|
| 327 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 328 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 329 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
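A minimal usage sketch for the processor above, for orientation only. The repository path, the auto class used for loading, the `model.vision_tower` attribute name, and the generation settings are assumptions for illustration, not part of the uploaded files.

# Illustrative sketch, assuming trust_remote_code loading and that the model
# exposes its DualViTok as `model.vision_tower`; adjust names to the actual checkpoint.
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

repo = "path/to/this/checkpoint"  # hypothetical local path or Hub id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")

# Wire the vision tokenizer into the processor so <image_out> placeholders
# can be expanded into discrete image-token strings.
processor.set_vision_tokenizer(model.vision_tower)  # assumed attribute name

prompt = "Describe this picture: <image>"
image = Image.open("example.jpg")
inputs = processor(text=prompt, images=[image], return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output_ids[0], skip_special_tokens=True))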
pytorch_model-00001-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7c85c4110b8939d12d79ac65e4d0985cec132c92ef481e36b3923e5e090ba21
size 4970246339
pytorch_model-00002-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:813720e5102d637c585854db25e80984d03fb73490e72795247d917933d8f821
size 4932780328
pytorch_model-00003-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab7b0e070fc80137d7d9e43951d89a4ea82c961fb0a3dffeba4e73be033d07e1
size 4991527403
pytorch_model-00004-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:89e7232078f06cb9df7470b148d01e6de8826c936d49ed56401c7e91cddea297
size 3699762846
pytorch_model.bin.index.json
ADDED
|
@@ -0,0 +1,889 @@
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 18594005440
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"language_model.lm_head.weight": "pytorch_model-00004-of-00004.bin",
|
| 7 |
+
"language_model.model.embed_tokens.weight": "pytorch_model-00001-of-00004.bin",
|
| 8 |
+
"language_model.model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
|
| 9 |
+
"language_model.model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 10 |
+
"language_model.model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 11 |
+
"language_model.model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 12 |
+
"language_model.model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
|
| 13 |
+
"language_model.model.layers.0.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 14 |
+
"language_model.model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 15 |
+
"language_model.model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 16 |
+
"language_model.model.layers.0.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 17 |
+
"language_model.model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 18 |
+
"language_model.model.layers.0.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 19 |
+
"language_model.model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 20 |
+
"language_model.model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
|
| 21 |
+
"language_model.model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 22 |
+
"language_model.model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 23 |
+
"language_model.model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 24 |
+
"language_model.model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
|
| 25 |
+
"language_model.model.layers.1.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 26 |
+
"language_model.model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 27 |
+
"language_model.model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 28 |
+
"language_model.model.layers.1.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 29 |
+
"language_model.model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 30 |
+
"language_model.model.layers.1.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 31 |
+
"language_model.model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 32 |
+
"language_model.model.layers.10.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 33 |
+
"language_model.model.layers.10.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 34 |
+
"language_model.model.layers.10.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 35 |
+
"language_model.model.layers.10.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 36 |
+
"language_model.model.layers.10.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 37 |
+
"language_model.model.layers.10.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 38 |
+
"language_model.model.layers.10.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 39 |
+
"language_model.model.layers.10.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 40 |
+
"language_model.model.layers.10.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 41 |
+
"language_model.model.layers.10.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 42 |
+
"language_model.model.layers.10.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 43 |
+
"language_model.model.layers.10.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 44 |
+
"language_model.model.layers.11.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 45 |
+
"language_model.model.layers.11.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 46 |
+
"language_model.model.layers.11.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 47 |
+
"language_model.model.layers.11.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 48 |
+
"language_model.model.layers.11.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 49 |
+
"language_model.model.layers.11.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 50 |
+
"language_model.model.layers.11.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 51 |
+
"language_model.model.layers.11.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 52 |
+
"language_model.model.layers.11.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 53 |
+
"language_model.model.layers.11.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 54 |
+
"language_model.model.layers.11.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 55 |
+
"language_model.model.layers.11.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 56 |
+
"language_model.model.layers.12.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 57 |
+
"language_model.model.layers.12.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 58 |
+
"language_model.model.layers.12.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 59 |
+
"language_model.model.layers.12.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 60 |
+
"language_model.model.layers.12.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 61 |
+
"language_model.model.layers.12.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 62 |
+
"language_model.model.layers.12.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 63 |
+
"language_model.model.layers.12.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 64 |
+
"language_model.model.layers.12.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 65 |
+
"language_model.model.layers.12.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 66 |
+
"language_model.model.layers.12.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 67 |
+
"language_model.model.layers.12.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 68 |
+
"language_model.model.layers.13.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 69 |
+
"language_model.model.layers.13.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 70 |
+
"language_model.model.layers.13.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 71 |
+
"language_model.model.layers.13.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 72 |
+
"language_model.model.layers.13.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 73 |
+
"language_model.model.layers.13.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 74 |
+
"language_model.model.layers.13.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 75 |
+
"language_model.model.layers.13.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 76 |
+
"language_model.model.layers.13.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 77 |
+
"language_model.model.layers.13.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 78 |
+
"language_model.model.layers.13.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 79 |
+
"language_model.model.layers.13.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 80 |
+
"language_model.model.layers.14.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 81 |
+
"language_model.model.layers.14.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 82 |
+
"language_model.model.layers.14.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 83 |
+
"language_model.model.layers.14.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 84 |
+
"language_model.model.layers.14.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 85 |
+
"language_model.model.layers.14.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 86 |
+
"language_model.model.layers.14.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 87 |
+
"language_model.model.layers.14.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 88 |
+
"language_model.model.layers.14.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 89 |
+
"language_model.model.layers.14.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 90 |
+
"language_model.model.layers.14.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 91 |
+
"language_model.model.layers.14.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 92 |
+
"language_model.model.layers.15.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 93 |
+
"language_model.model.layers.15.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 94 |
+
"language_model.model.layers.15.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 95 |
+
"language_model.model.layers.15.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 96 |
+
"language_model.model.layers.15.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 97 |
+
"language_model.model.layers.15.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 98 |
+
"language_model.model.layers.15.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 99 |
+
"language_model.model.layers.15.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 100 |
+
"language_model.model.layers.15.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 101 |
+
"language_model.model.layers.15.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 102 |
+
"language_model.model.layers.15.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 103 |
+
"language_model.model.layers.15.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 104 |
+
"language_model.model.layers.16.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 105 |
+
"language_model.model.layers.16.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 106 |
+
"language_model.model.layers.16.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 107 |
+
"language_model.model.layers.16.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 108 |
+
"language_model.model.layers.16.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 109 |
+
"language_model.model.layers.16.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 110 |
+
"language_model.model.layers.16.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 111 |
+
"language_model.model.layers.16.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 112 |
+
"language_model.model.layers.16.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 113 |
+
"language_model.model.layers.16.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 114 |
+
"language_model.model.layers.16.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 115 |
+
"language_model.model.layers.16.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 116 |
+
"language_model.model.layers.17.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 117 |
+
"language_model.model.layers.17.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 118 |
+
"language_model.model.layers.17.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 119 |
+
"language_model.model.layers.17.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 120 |
+
"language_model.model.layers.17.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 121 |
+
"language_model.model.layers.17.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 122 |
+
"language_model.model.layers.17.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 123 |
+
"language_model.model.layers.17.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 124 |
+
"language_model.model.layers.17.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 125 |
+
"language_model.model.layers.17.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 126 |
+
"language_model.model.layers.17.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 127 |
+
"language_model.model.layers.17.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 128 |
+
"language_model.model.layers.18.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 129 |
+
"language_model.model.layers.18.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 130 |
+
"language_model.model.layers.18.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 131 |
+
"language_model.model.layers.18.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 132 |
+
"language_model.model.layers.18.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 133 |
+
"language_model.model.layers.18.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 134 |
+
"language_model.model.layers.18.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 135 |
+
"language_model.model.layers.18.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 136 |
+
"language_model.model.layers.18.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 137 |
+
"language_model.model.layers.18.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 138 |
+
"language_model.model.layers.18.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 139 |
+
"language_model.model.layers.18.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 140 |
+
"language_model.model.layers.19.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 141 |
+
"language_model.model.layers.19.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 142 |
+
"language_model.model.layers.19.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 143 |
+
"language_model.model.layers.19.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 144 |
+
"language_model.model.layers.19.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 145 |
+
"language_model.model.layers.19.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 146 |
+
"language_model.model.layers.19.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 147 |
+
"language_model.model.layers.19.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 148 |
+
"language_model.model.layers.19.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 149 |
+
"language_model.model.layers.19.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 150 |
+
"language_model.model.layers.19.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 151 |
+
"language_model.model.layers.19.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 152 |
+
"language_model.model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00004.bin",
|
| 153 |
+
"language_model.model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 154 |
+
"language_model.model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 155 |
+
"language_model.model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 156 |
+
"language_model.model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00004.bin",
|
| 157 |
+
"language_model.model.layers.2.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 158 |
+
"language_model.model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 159 |
+
"language_model.model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 160 |
+
"language_model.model.layers.2.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 161 |
+
"language_model.model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 162 |
+
"language_model.model.layers.2.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 163 |
+
"language_model.model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 164 |
+
"language_model.model.layers.20.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 165 |
+
"language_model.model.layers.20.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 166 |
+
"language_model.model.layers.20.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 167 |
+
"language_model.model.layers.20.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 168 |
+
"language_model.model.layers.20.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 169 |
+
"language_model.model.layers.20.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 170 |
+
"language_model.model.layers.20.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 171 |
+
"language_model.model.layers.20.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 172 |
+
"language_model.model.layers.20.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 173 |
+
"language_model.model.layers.20.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 174 |
+
"language_model.model.layers.20.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 175 |
+
"language_model.model.layers.20.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 176 |
+
"language_model.model.layers.21.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 177 |
+
"language_model.model.layers.21.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 178 |
+
"language_model.model.layers.21.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 179 |
+
"language_model.model.layers.21.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 180 |
+
"language_model.model.layers.21.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 181 |
+
"language_model.model.layers.21.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 182 |
+
"language_model.model.layers.21.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 183 |
+
"language_model.model.layers.21.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 184 |
+
"language_model.model.layers.21.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 185 |
+
"language_model.model.layers.21.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 186 |
+
"language_model.model.layers.21.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 187 |
+
"language_model.model.layers.21.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 188 |
+
"language_model.model.layers.22.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 189 |
+
"language_model.model.layers.22.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 190 |
+
"language_model.model.layers.22.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 191 |
+
"language_model.model.layers.22.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 192 |
+
"language_model.model.layers.22.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 193 |
+
"language_model.model.layers.22.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 194 |
+
"language_model.model.layers.22.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 195 |
+
"language_model.model.layers.22.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 196 |
+
"language_model.model.layers.22.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 197 |
+
"language_model.model.layers.22.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 198 |
+
"language_model.model.layers.22.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 199 |
+
"language_model.model.layers.22.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 200 |
+
"language_model.model.layers.23.input_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 201 |
+
"language_model.model.layers.23.mlp.down_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 202 |
+
"language_model.model.layers.23.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 203 |
+
"language_model.model.layers.23.mlp.up_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 204 |
+
"language_model.model.layers.23.post_attention_layernorm.weight": "pytorch_model-00003-of-00004.bin",
|
| 205 |
+
"language_model.model.layers.23.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 206 |
+
"language_model.model.layers.23.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 207 |
+
"language_model.model.layers.23.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 208 |
+
"language_model.model.layers.23.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 209 |
+
"language_model.model.layers.23.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 210 |
+
"language_model.model.layers.23.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 211 |
+
"language_model.model.layers.23.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 212 |
+
"language_model.model.layers.24.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 213 |
+
"language_model.model.layers.24.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 214 |
+
"language_model.model.layers.24.mlp.gate_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 215 |
+
"language_model.model.layers.24.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 216 |
+
"language_model.model.layers.24.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 217 |
+
"language_model.model.layers.24.self_attn.k_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 218 |
+
"language_model.model.layers.24.self_attn.k_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 219 |
+
"language_model.model.layers.24.self_attn.o_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 220 |
+
"language_model.model.layers.24.self_attn.q_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 221 |
+
"language_model.model.layers.24.self_attn.q_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 222 |
+
"language_model.model.layers.24.self_attn.v_proj.bias": "pytorch_model-00003-of-00004.bin",
|
| 223 |
+
"language_model.model.layers.24.self_attn.v_proj.weight": "pytorch_model-00003-of-00004.bin",
|
| 224 |
+
"language_model.model.layers.25.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 225 |
+
"language_model.model.layers.25.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 226 |
+
"language_model.model.layers.25.mlp.gate_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 227 |
+
"language_model.model.layers.25.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 228 |
+
"language_model.model.layers.25.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 229 |
+
"language_model.model.layers.25.self_attn.k_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 230 |
+
"language_model.model.layers.25.self_attn.k_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 231 |
+
"language_model.model.layers.25.self_attn.o_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 232 |
+
"language_model.model.layers.25.self_attn.q_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 233 |
+
"language_model.model.layers.25.self_attn.q_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 234 |
+
"language_model.model.layers.25.self_attn.v_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 235 |
+
"language_model.model.layers.25.self_attn.v_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 236 |
+
"language_model.model.layers.26.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 237 |
+
"language_model.model.layers.26.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 238 |
+
"language_model.model.layers.26.mlp.gate_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 239 |
+
"language_model.model.layers.26.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 240 |
+
"language_model.model.layers.26.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 241 |
+
"language_model.model.layers.26.self_attn.k_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 242 |
+
"language_model.model.layers.26.self_attn.k_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 243 |
+
"language_model.model.layers.26.self_attn.o_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 244 |
+
"language_model.model.layers.26.self_attn.q_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 245 |
+
"language_model.model.layers.26.self_attn.q_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 246 |
+
"language_model.model.layers.26.self_attn.v_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 247 |
+
"language_model.model.layers.26.self_attn.v_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 248 |
+
"language_model.model.layers.27.input_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 249 |
+
"language_model.model.layers.27.mlp.down_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 250 |
+
"language_model.model.layers.27.mlp.gate_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 251 |
+
"language_model.model.layers.27.mlp.up_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 252 |
+
"language_model.model.layers.27.post_attention_layernorm.weight": "pytorch_model-00004-of-00004.bin",
|
| 253 |
+
"language_model.model.layers.27.self_attn.k_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 254 |
+
"language_model.model.layers.27.self_attn.k_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 255 |
+
"language_model.model.layers.27.self_attn.o_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 256 |
+
"language_model.model.layers.27.self_attn.q_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 257 |
+
"language_model.model.layers.27.self_attn.q_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 258 |
+
"language_model.model.layers.27.self_attn.v_proj.bias": "pytorch_model-00004-of-00004.bin",
|
| 259 |
+
"language_model.model.layers.27.self_attn.v_proj.weight": "pytorch_model-00004-of-00004.bin",
|
| 260 |
+
"language_model.model.layers.3.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 261 |
+
"language_model.model.layers.3.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 262 |
+
"language_model.model.layers.3.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 263 |
+
"language_model.model.layers.3.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 264 |
+
"language_model.model.layers.3.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 265 |
+
"language_model.model.layers.3.self_attn.k_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 266 |
+
"language_model.model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 267 |
+
"language_model.model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 268 |
+
"language_model.model.layers.3.self_attn.q_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 269 |
+
"language_model.model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 270 |
+
"language_model.model.layers.3.self_attn.v_proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 271 |
+
"language_model.model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 272 |
+
"language_model.model.layers.4.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 273 |
+
"language_model.model.layers.4.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 274 |
+
"language_model.model.layers.4.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 275 |
+
"language_model.model.layers.4.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 276 |
+
"language_model.model.layers.4.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 277 |
+
"language_model.model.layers.4.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 278 |
+
"language_model.model.layers.4.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 279 |
+
"language_model.model.layers.4.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 280 |
+
"language_model.model.layers.4.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 281 |
+
"language_model.model.layers.4.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 282 |
+
"language_model.model.layers.4.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 283 |
+
"language_model.model.layers.4.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 284 |
+
"language_model.model.layers.5.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 285 |
+
"language_model.model.layers.5.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 286 |
+
"language_model.model.layers.5.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 287 |
+
"language_model.model.layers.5.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 288 |
+
"language_model.model.layers.5.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 289 |
+
"language_model.model.layers.5.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 290 |
+
"language_model.model.layers.5.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 291 |
+
"language_model.model.layers.5.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 292 |
+
"language_model.model.layers.5.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 293 |
+
"language_model.model.layers.5.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 294 |
+
"language_model.model.layers.5.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 295 |
+
"language_model.model.layers.5.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 296 |
+
"language_model.model.layers.6.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 297 |
+
"language_model.model.layers.6.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 298 |
+
"language_model.model.layers.6.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 299 |
+
"language_model.model.layers.6.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 300 |
+
"language_model.model.layers.6.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 301 |
+
"language_model.model.layers.6.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 302 |
+
"language_model.model.layers.6.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 303 |
+
"language_model.model.layers.6.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 304 |
+
"language_model.model.layers.6.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 305 |
+
"language_model.model.layers.6.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 306 |
+
"language_model.model.layers.6.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 307 |
+
"language_model.model.layers.6.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 308 |
+
"language_model.model.layers.7.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 309 |
+
"language_model.model.layers.7.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 310 |
+
"language_model.model.layers.7.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 311 |
+
"language_model.model.layers.7.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 312 |
+
"language_model.model.layers.7.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 313 |
+
"language_model.model.layers.7.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 314 |
+
"language_model.model.layers.7.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 315 |
+
"language_model.model.layers.7.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 316 |
+
"language_model.model.layers.7.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 317 |
+
"language_model.model.layers.7.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 318 |
+
"language_model.model.layers.7.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 319 |
+
"language_model.model.layers.7.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 320 |
+
"language_model.model.layers.8.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 321 |
+
"language_model.model.layers.8.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 322 |
+
"language_model.model.layers.8.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 323 |
+
"language_model.model.layers.8.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 324 |
+
"language_model.model.layers.8.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 325 |
+
"language_model.model.layers.8.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 326 |
+
"language_model.model.layers.8.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 327 |
+
"language_model.model.layers.8.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 328 |
+
"language_model.model.layers.8.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 329 |
+
"language_model.model.layers.8.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 330 |
+
"language_model.model.layers.8.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 331 |
+
"language_model.model.layers.8.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 332 |
+
"language_model.model.layers.9.input_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 333 |
+
"language_model.model.layers.9.mlp.down_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 334 |
+
"language_model.model.layers.9.mlp.gate_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 335 |
+
"language_model.model.layers.9.mlp.up_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 336 |
+
"language_model.model.layers.9.post_attention_layernorm.weight": "pytorch_model-00002-of-00004.bin",
|
| 337 |
+
"language_model.model.layers.9.self_attn.k_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 338 |
+
"language_model.model.layers.9.self_attn.k_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 339 |
+
"language_model.model.layers.9.self_attn.o_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 340 |
+
"language_model.model.layers.9.self_attn.q_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 341 |
+
"language_model.model.layers.9.self_attn.q_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 342 |
+
"language_model.model.layers.9.self_attn.v_proj.bias": "pytorch_model-00002-of-00004.bin",
|
| 343 |
+
"language_model.model.layers.9.self_attn.v_proj.weight": "pytorch_model-00002-of-00004.bin",
|
| 344 |
+
"language_model.model.norm.weight": "pytorch_model-00004-of-00004.bin",
|
| 345 |
+
"mm_projector.projector_1.0.bias": "pytorch_model-00001-of-00004.bin",
|
| 346 |
+
"mm_projector.projector_1.0.weight": "pytorch_model-00001-of-00004.bin",
|
| 347 |
+
"mm_projector.projector_1.2.bias": "pytorch_model-00001-of-00004.bin",
|
| 348 |
+
"mm_projector.projector_1.2.weight": "pytorch_model-00001-of-00004.bin",
|
| 349 |
+
"mm_projector.projector_2.0.bias": "pytorch_model-00001-of-00004.bin",
|
| 350 |
+
"mm_projector.projector_2.0.weight": "pytorch_model-00001-of-00004.bin",
|
| 351 |
+
"mm_projector.projector_2.2.bias": "pytorch_model-00001-of-00004.bin",
|
| 352 |
+
"mm_projector.projector_2.2.weight": "pytorch_model-00001-of-00004.bin",
|
| 353 |
+
"vision_tower.pixel_encoder.conv_in.bias": "pytorch_model-00001-of-00004.bin",
|
| 354 |
+
"vision_tower.pixel_encoder.conv_in.weight": "pytorch_model-00001-of-00004.bin",
|
| 355 |
+
"vision_tower.pixel_encoder.conv_out.bias": "pytorch_model-00001-of-00004.bin",
|
| 356 |
+
"vision_tower.pixel_encoder.conv_out.weight": "pytorch_model-00001-of-00004.bin",
|
| 357 |
+
"vision_tower.pixel_encoder.down.0.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 358 |
+
"vision_tower.pixel_encoder.down.0.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 359 |
+
"vision_tower.pixel_encoder.down.0.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 360 |
+
"vision_tower.pixel_encoder.down.0.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 361 |
+
"vision_tower.pixel_encoder.down.0.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 362 |
+
"vision_tower.pixel_encoder.down.0.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 363 |
+
"vision_tower.pixel_encoder.down.0.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 364 |
+
"vision_tower.pixel_encoder.down.0.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 365 |
+
"vision_tower.pixel_encoder.down.0.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 366 |
+
"vision_tower.pixel_encoder.down.0.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 367 |
+
"vision_tower.pixel_encoder.down.0.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 368 |
+
"vision_tower.pixel_encoder.down.0.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 369 |
+
"vision_tower.pixel_encoder.down.0.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 370 |
+
"vision_tower.pixel_encoder.down.0.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 371 |
+
"vision_tower.pixel_encoder.down.0.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 372 |
+
"vision_tower.pixel_encoder.down.0.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 373 |
+
"vision_tower.pixel_encoder.down.0.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
|
| 374 |
+
"vision_tower.pixel_encoder.down.0.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
|
| 375 |
+
"vision_tower.pixel_encoder.down.1.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 376 |
+
"vision_tower.pixel_encoder.down.1.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 377 |
+
"vision_tower.pixel_encoder.down.1.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 378 |
+
"vision_tower.pixel_encoder.down.1.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 379 |
+
"vision_tower.pixel_encoder.down.1.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 380 |
+
"vision_tower.pixel_encoder.down.1.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 381 |
+
"vision_tower.pixel_encoder.down.1.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 382 |
+
"vision_tower.pixel_encoder.down.1.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 383 |
+
"vision_tower.pixel_encoder.down.1.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 384 |
+
"vision_tower.pixel_encoder.down.1.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 385 |
+
"vision_tower.pixel_encoder.down.1.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 386 |
+
"vision_tower.pixel_encoder.down.1.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 387 |
+
"vision_tower.pixel_encoder.down.1.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 388 |
+
"vision_tower.pixel_encoder.down.1.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 389 |
+
"vision_tower.pixel_encoder.down.1.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 390 |
+
"vision_tower.pixel_encoder.down.1.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 391 |
+
"vision_tower.pixel_encoder.down.1.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
|
| 392 |
+
"vision_tower.pixel_encoder.down.1.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
|
| 393 |
+
"vision_tower.pixel_encoder.down.2.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 394 |
+
"vision_tower.pixel_encoder.down.2.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 395 |
+
"vision_tower.pixel_encoder.down.2.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 396 |
+
"vision_tower.pixel_encoder.down.2.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 397 |
+
"vision_tower.pixel_encoder.down.2.block.0.nin_shortcut.bias": "pytorch_model-00001-of-00004.bin",
|
| 398 |
+
"vision_tower.pixel_encoder.down.2.block.0.nin_shortcut.weight": "pytorch_model-00001-of-00004.bin",
|
| 399 |
+
"vision_tower.pixel_encoder.down.2.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 400 |
+
"vision_tower.pixel_encoder.down.2.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 401 |
+
"vision_tower.pixel_encoder.down.2.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 402 |
+
"vision_tower.pixel_encoder.down.2.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 403 |
+
"vision_tower.pixel_encoder.down.2.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 404 |
+
"vision_tower.pixel_encoder.down.2.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 405 |
+
"vision_tower.pixel_encoder.down.2.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 406 |
+
"vision_tower.pixel_encoder.down.2.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 407 |
+
"vision_tower.pixel_encoder.down.2.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 408 |
+
"vision_tower.pixel_encoder.down.2.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 409 |
+
"vision_tower.pixel_encoder.down.2.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 410 |
+
"vision_tower.pixel_encoder.down.2.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 411 |
+
"vision_tower.pixel_encoder.down.2.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
|
| 412 |
+
"vision_tower.pixel_encoder.down.2.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
|
| 413 |
+
"vision_tower.pixel_encoder.down.3.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 414 |
+
"vision_tower.pixel_encoder.down.3.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 415 |
+
"vision_tower.pixel_encoder.down.3.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 416 |
+
"vision_tower.pixel_encoder.down.3.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 417 |
+
"vision_tower.pixel_encoder.down.3.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 418 |
+
"vision_tower.pixel_encoder.down.3.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 419 |
+
"vision_tower.pixel_encoder.down.3.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 420 |
+
"vision_tower.pixel_encoder.down.3.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 421 |
+
"vision_tower.pixel_encoder.down.3.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 422 |
+
"vision_tower.pixel_encoder.down.3.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 423 |
+
"vision_tower.pixel_encoder.down.3.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 424 |
+
"vision_tower.pixel_encoder.down.3.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 425 |
+
"vision_tower.pixel_encoder.down.3.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 426 |
+
"vision_tower.pixel_encoder.down.3.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 427 |
+
"vision_tower.pixel_encoder.down.3.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 428 |
+
"vision_tower.pixel_encoder.down.3.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 429 |
+
"vision_tower.pixel_encoder.down.3.downsample.conv.bias": "pytorch_model-00001-of-00004.bin",
|
| 430 |
+
"vision_tower.pixel_encoder.down.3.downsample.conv.weight": "pytorch_model-00001-of-00004.bin",
|
| 431 |
+
"vision_tower.pixel_encoder.down.4.attn.0.k.bias": "pytorch_model-00001-of-00004.bin",
|
| 432 |
+
"vision_tower.pixel_encoder.down.4.attn.0.k.weight": "pytorch_model-00001-of-00004.bin",
|
| 433 |
+
"vision_tower.pixel_encoder.down.4.attn.0.norm.bias": "pytorch_model-00001-of-00004.bin",
|
| 434 |
+
"vision_tower.pixel_encoder.down.4.attn.0.norm.weight": "pytorch_model-00001-of-00004.bin",
|
| 435 |
+
"vision_tower.pixel_encoder.down.4.attn.0.proj_out.bias": "pytorch_model-00001-of-00004.bin",
|
| 436 |
+
"vision_tower.pixel_encoder.down.4.attn.0.proj_out.weight": "pytorch_model-00001-of-00004.bin",
|
| 437 |
+
"vision_tower.pixel_encoder.down.4.attn.0.q.bias": "pytorch_model-00001-of-00004.bin",
|
| 438 |
+
"vision_tower.pixel_encoder.down.4.attn.0.q.weight": "pytorch_model-00001-of-00004.bin",
|
| 439 |
+
"vision_tower.pixel_encoder.down.4.attn.0.v.bias": "pytorch_model-00001-of-00004.bin",
|
| 440 |
+
"vision_tower.pixel_encoder.down.4.attn.0.v.weight": "pytorch_model-00001-of-00004.bin",
|
| 441 |
+
"vision_tower.pixel_encoder.down.4.attn.1.k.bias": "pytorch_model-00001-of-00004.bin",
|
| 442 |
+
"vision_tower.pixel_encoder.down.4.attn.1.k.weight": "pytorch_model-00001-of-00004.bin",
|
| 443 |
+
"vision_tower.pixel_encoder.down.4.attn.1.norm.bias": "pytorch_model-00001-of-00004.bin",
|
| 444 |
+
"vision_tower.pixel_encoder.down.4.attn.1.norm.weight": "pytorch_model-00001-of-00004.bin",
|
| 445 |
+
"vision_tower.pixel_encoder.down.4.attn.1.proj_out.bias": "pytorch_model-00001-of-00004.bin",
|
| 446 |
+
"vision_tower.pixel_encoder.down.4.attn.1.proj_out.weight": "pytorch_model-00001-of-00004.bin",
|
| 447 |
+
"vision_tower.pixel_encoder.down.4.attn.1.q.bias": "pytorch_model-00001-of-00004.bin",
|
| 448 |
+
"vision_tower.pixel_encoder.down.4.attn.1.q.weight": "pytorch_model-00001-of-00004.bin",
|
| 449 |
+
"vision_tower.pixel_encoder.down.4.attn.1.v.bias": "pytorch_model-00001-of-00004.bin",
|
| 450 |
+
"vision_tower.pixel_encoder.down.4.attn.1.v.weight": "pytorch_model-00001-of-00004.bin",
|
| 451 |
+
"vision_tower.pixel_encoder.down.4.block.0.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 452 |
+
"vision_tower.pixel_encoder.down.4.block.0.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 453 |
+
"vision_tower.pixel_encoder.down.4.block.0.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 454 |
+
"vision_tower.pixel_encoder.down.4.block.0.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 455 |
+
"vision_tower.pixel_encoder.down.4.block.0.nin_shortcut.bias": "pytorch_model-00001-of-00004.bin",
|
| 456 |
+
"vision_tower.pixel_encoder.down.4.block.0.nin_shortcut.weight": "pytorch_model-00001-of-00004.bin",
|
| 457 |
+
"vision_tower.pixel_encoder.down.4.block.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 458 |
+
"vision_tower.pixel_encoder.down.4.block.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 459 |
+
"vision_tower.pixel_encoder.down.4.block.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 460 |
+
"vision_tower.pixel_encoder.down.4.block.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 461 |
+
"vision_tower.pixel_encoder.down.4.block.1.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 462 |
+
"vision_tower.pixel_encoder.down.4.block.1.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 463 |
+
"vision_tower.pixel_encoder.down.4.block.1.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 464 |
+
"vision_tower.pixel_encoder.down.4.block.1.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 465 |
+
"vision_tower.pixel_encoder.down.4.block.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 466 |
+
"vision_tower.pixel_encoder.down.4.block.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 467 |
+
"vision_tower.pixel_encoder.down.4.block.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 468 |
+
"vision_tower.pixel_encoder.down.4.block.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 469 |
+
"vision_tower.pixel_encoder.mid.attn_1.k.bias": "pytorch_model-00001-of-00004.bin",
|
| 470 |
+
"vision_tower.pixel_encoder.mid.attn_1.k.weight": "pytorch_model-00001-of-00004.bin",
|
| 471 |
+
"vision_tower.pixel_encoder.mid.attn_1.norm.bias": "pytorch_model-00001-of-00004.bin",
|
| 472 |
+
"vision_tower.pixel_encoder.mid.attn_1.norm.weight": "pytorch_model-00001-of-00004.bin",
|
| 473 |
+
"vision_tower.pixel_encoder.mid.attn_1.proj_out.bias": "pytorch_model-00001-of-00004.bin",
|
| 474 |
+
"vision_tower.pixel_encoder.mid.attn_1.proj_out.weight": "pytorch_model-00001-of-00004.bin",
|
| 475 |
+
"vision_tower.pixel_encoder.mid.attn_1.q.bias": "pytorch_model-00001-of-00004.bin",
|
| 476 |
+
"vision_tower.pixel_encoder.mid.attn_1.q.weight": "pytorch_model-00001-of-00004.bin",
|
| 477 |
+
"vision_tower.pixel_encoder.mid.attn_1.v.bias": "pytorch_model-00001-of-00004.bin",
|
| 478 |
+
"vision_tower.pixel_encoder.mid.attn_1.v.weight": "pytorch_model-00001-of-00004.bin",
|
| 479 |
+
"vision_tower.pixel_encoder.mid.block_1.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 480 |
+
"vision_tower.pixel_encoder.mid.block_1.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 481 |
+
"vision_tower.pixel_encoder.mid.block_1.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 482 |
+
"vision_tower.pixel_encoder.mid.block_1.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 483 |
+
"vision_tower.pixel_encoder.mid.block_1.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 484 |
+
"vision_tower.pixel_encoder.mid.block_1.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 485 |
+
"vision_tower.pixel_encoder.mid.block_1.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 486 |
+
"vision_tower.pixel_encoder.mid.block_1.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 487 |
+
"vision_tower.pixel_encoder.mid.block_2.conv1.bias": "pytorch_model-00001-of-00004.bin",
|
| 488 |
+
"vision_tower.pixel_encoder.mid.block_2.conv1.weight": "pytorch_model-00001-of-00004.bin",
|
| 489 |
+
"vision_tower.pixel_encoder.mid.block_2.conv2.bias": "pytorch_model-00001-of-00004.bin",
|
| 490 |
+
"vision_tower.pixel_encoder.mid.block_2.conv2.weight": "pytorch_model-00001-of-00004.bin",
|
| 491 |
+
"vision_tower.pixel_encoder.mid.block_2.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 492 |
+
"vision_tower.pixel_encoder.mid.block_2.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 493 |
+
"vision_tower.pixel_encoder.mid.block_2.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 494 |
+
"vision_tower.pixel_encoder.mid.block_2.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 495 |
+
"vision_tower.pixel_encoder.norm_out.bias": "pytorch_model-00001-of-00004.bin",
|
| 496 |
+
"vision_tower.pixel_encoder.norm_out.weight": "pytorch_model-00001-of-00004.bin",
|
| 497 |
+
"vision_tower.semantic_encoder.blocks.0.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 498 |
+
"vision_tower.semantic_encoder.blocks.0.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 499 |
+
"vision_tower.semantic_encoder.blocks.0.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 500 |
+
"vision_tower.semantic_encoder.blocks.0.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 501 |
+
"vision_tower.semantic_encoder.blocks.0.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 502 |
+
"vision_tower.semantic_encoder.blocks.0.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 503 |
+
"vision_tower.semantic_encoder.blocks.0.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 504 |
+
"vision_tower.semantic_encoder.blocks.0.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 505 |
+
"vision_tower.semantic_encoder.blocks.0.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 506 |
+
"vision_tower.semantic_encoder.blocks.0.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 507 |
+
"vision_tower.semantic_encoder.blocks.0.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 508 |
+
"vision_tower.semantic_encoder.blocks.0.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 509 |
+
"vision_tower.semantic_encoder.blocks.1.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 510 |
+
"vision_tower.semantic_encoder.blocks.1.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 511 |
+
"vision_tower.semantic_encoder.blocks.1.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 512 |
+
"vision_tower.semantic_encoder.blocks.1.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 513 |
+
"vision_tower.semantic_encoder.blocks.1.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 514 |
+
"vision_tower.semantic_encoder.blocks.1.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 515 |
+
"vision_tower.semantic_encoder.blocks.1.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 516 |
+
"vision_tower.semantic_encoder.blocks.1.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 517 |
+
"vision_tower.semantic_encoder.blocks.1.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 518 |
+
"vision_tower.semantic_encoder.blocks.1.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 519 |
+
"vision_tower.semantic_encoder.blocks.1.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 520 |
+
"vision_tower.semantic_encoder.blocks.1.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 521 |
+
"vision_tower.semantic_encoder.blocks.10.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 522 |
+
"vision_tower.semantic_encoder.blocks.10.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 523 |
+
"vision_tower.semantic_encoder.blocks.10.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 524 |
+
"vision_tower.semantic_encoder.blocks.10.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 525 |
+
"vision_tower.semantic_encoder.blocks.10.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 526 |
+
"vision_tower.semantic_encoder.blocks.10.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 527 |
+
"vision_tower.semantic_encoder.blocks.10.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 528 |
+
"vision_tower.semantic_encoder.blocks.10.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 529 |
+
"vision_tower.semantic_encoder.blocks.10.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 530 |
+
"vision_tower.semantic_encoder.blocks.10.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 531 |
+
"vision_tower.semantic_encoder.blocks.10.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 532 |
+
"vision_tower.semantic_encoder.blocks.10.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 533 |
+
"vision_tower.semantic_encoder.blocks.11.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 534 |
+
"vision_tower.semantic_encoder.blocks.11.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 535 |
+
"vision_tower.semantic_encoder.blocks.11.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 536 |
+
"vision_tower.semantic_encoder.blocks.11.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 537 |
+
"vision_tower.semantic_encoder.blocks.11.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 538 |
+
"vision_tower.semantic_encoder.blocks.11.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 539 |
+
"vision_tower.semantic_encoder.blocks.11.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 540 |
+
"vision_tower.semantic_encoder.blocks.11.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 541 |
+
"vision_tower.semantic_encoder.blocks.11.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 542 |
+
"vision_tower.semantic_encoder.blocks.11.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 543 |
+
"vision_tower.semantic_encoder.blocks.11.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 544 |
+
"vision_tower.semantic_encoder.blocks.11.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 545 |
+
"vision_tower.semantic_encoder.blocks.12.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 546 |
+
"vision_tower.semantic_encoder.blocks.12.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 547 |
+
"vision_tower.semantic_encoder.blocks.12.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 548 |
+
"vision_tower.semantic_encoder.blocks.12.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 549 |
+
"vision_tower.semantic_encoder.blocks.12.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 550 |
+
"vision_tower.semantic_encoder.blocks.12.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 551 |
+
"vision_tower.semantic_encoder.blocks.12.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 552 |
+
"vision_tower.semantic_encoder.blocks.12.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 553 |
+
"vision_tower.semantic_encoder.blocks.12.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 554 |
+
"vision_tower.semantic_encoder.blocks.12.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 555 |
+
"vision_tower.semantic_encoder.blocks.12.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 556 |
+
"vision_tower.semantic_encoder.blocks.12.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 557 |
+
"vision_tower.semantic_encoder.blocks.13.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 558 |
+
"vision_tower.semantic_encoder.blocks.13.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 559 |
+
"vision_tower.semantic_encoder.blocks.13.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 560 |
+
"vision_tower.semantic_encoder.blocks.13.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 561 |
+
"vision_tower.semantic_encoder.blocks.13.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 562 |
+
"vision_tower.semantic_encoder.blocks.13.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 563 |
+
"vision_tower.semantic_encoder.blocks.13.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 564 |
+
"vision_tower.semantic_encoder.blocks.13.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 565 |
+
"vision_tower.semantic_encoder.blocks.13.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 566 |
+
"vision_tower.semantic_encoder.blocks.13.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 567 |
+
"vision_tower.semantic_encoder.blocks.13.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 568 |
+
"vision_tower.semantic_encoder.blocks.13.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 569 |
+
"vision_tower.semantic_encoder.blocks.14.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 570 |
+
"vision_tower.semantic_encoder.blocks.14.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 571 |
+
"vision_tower.semantic_encoder.blocks.14.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 572 |
+
"vision_tower.semantic_encoder.blocks.14.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 573 |
+
"vision_tower.semantic_encoder.blocks.14.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 574 |
+
"vision_tower.semantic_encoder.blocks.14.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 575 |
+
"vision_tower.semantic_encoder.blocks.14.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 576 |
+
"vision_tower.semantic_encoder.blocks.14.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 577 |
+
"vision_tower.semantic_encoder.blocks.14.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 578 |
+
"vision_tower.semantic_encoder.blocks.14.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 579 |
+
"vision_tower.semantic_encoder.blocks.14.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 580 |
+
"vision_tower.semantic_encoder.blocks.14.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 581 |
+
"vision_tower.semantic_encoder.blocks.15.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 582 |
+
"vision_tower.semantic_encoder.blocks.15.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 583 |
+
"vision_tower.semantic_encoder.blocks.15.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 584 |
+
"vision_tower.semantic_encoder.blocks.15.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 585 |
+
"vision_tower.semantic_encoder.blocks.15.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 586 |
+
"vision_tower.semantic_encoder.blocks.15.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 587 |
+
"vision_tower.semantic_encoder.blocks.15.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 588 |
+
"vision_tower.semantic_encoder.blocks.15.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 589 |
+
"vision_tower.semantic_encoder.blocks.15.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 590 |
+
"vision_tower.semantic_encoder.blocks.15.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 591 |
+
"vision_tower.semantic_encoder.blocks.15.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 592 |
+
"vision_tower.semantic_encoder.blocks.15.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 593 |
+
"vision_tower.semantic_encoder.blocks.16.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 594 |
+
"vision_tower.semantic_encoder.blocks.16.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 595 |
+
"vision_tower.semantic_encoder.blocks.16.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 596 |
+
"vision_tower.semantic_encoder.blocks.16.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 597 |
+
"vision_tower.semantic_encoder.blocks.16.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 598 |
+
"vision_tower.semantic_encoder.blocks.16.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 599 |
+
"vision_tower.semantic_encoder.blocks.16.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 600 |
+
"vision_tower.semantic_encoder.blocks.16.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 601 |
+
"vision_tower.semantic_encoder.blocks.16.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 602 |
+
"vision_tower.semantic_encoder.blocks.16.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 603 |
+
"vision_tower.semantic_encoder.blocks.16.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 604 |
+
"vision_tower.semantic_encoder.blocks.16.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 605 |
+
"vision_tower.semantic_encoder.blocks.17.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 606 |
+
"vision_tower.semantic_encoder.blocks.17.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 607 |
+
"vision_tower.semantic_encoder.blocks.17.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 608 |
+
"vision_tower.semantic_encoder.blocks.17.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 609 |
+
"vision_tower.semantic_encoder.blocks.17.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 610 |
+
"vision_tower.semantic_encoder.blocks.17.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 611 |
+
"vision_tower.semantic_encoder.blocks.17.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 612 |
+
"vision_tower.semantic_encoder.blocks.17.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 613 |
+
"vision_tower.semantic_encoder.blocks.17.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 614 |
+
"vision_tower.semantic_encoder.blocks.17.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 615 |
+
"vision_tower.semantic_encoder.blocks.17.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 616 |
+
"vision_tower.semantic_encoder.blocks.17.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 617 |
+
"vision_tower.semantic_encoder.blocks.18.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 618 |
+
"vision_tower.semantic_encoder.blocks.18.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 619 |
+
"vision_tower.semantic_encoder.blocks.18.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 620 |
+
"vision_tower.semantic_encoder.blocks.18.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 621 |
+
"vision_tower.semantic_encoder.blocks.18.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 622 |
+
"vision_tower.semantic_encoder.blocks.18.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 623 |
+
"vision_tower.semantic_encoder.blocks.18.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 624 |
+
"vision_tower.semantic_encoder.blocks.18.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 625 |
+
"vision_tower.semantic_encoder.blocks.18.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 626 |
+
"vision_tower.semantic_encoder.blocks.18.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 627 |
+
"vision_tower.semantic_encoder.blocks.18.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 628 |
+
"vision_tower.semantic_encoder.blocks.18.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 629 |
+
"vision_tower.semantic_encoder.blocks.19.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 630 |
+
"vision_tower.semantic_encoder.blocks.19.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 631 |
+
"vision_tower.semantic_encoder.blocks.19.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 632 |
+
"vision_tower.semantic_encoder.blocks.19.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 633 |
+
"vision_tower.semantic_encoder.blocks.19.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 634 |
+
"vision_tower.semantic_encoder.blocks.19.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 635 |
+
"vision_tower.semantic_encoder.blocks.19.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 636 |
+
"vision_tower.semantic_encoder.blocks.19.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 637 |
+
"vision_tower.semantic_encoder.blocks.19.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 638 |
+
"vision_tower.semantic_encoder.blocks.19.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 639 |
+
"vision_tower.semantic_encoder.blocks.19.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 640 |
+
"vision_tower.semantic_encoder.blocks.19.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 641 |
+
"vision_tower.semantic_encoder.blocks.2.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 642 |
+
"vision_tower.semantic_encoder.blocks.2.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 643 |
+
"vision_tower.semantic_encoder.blocks.2.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 644 |
+
"vision_tower.semantic_encoder.blocks.2.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 645 |
+
"vision_tower.semantic_encoder.blocks.2.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 646 |
+
"vision_tower.semantic_encoder.blocks.2.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 647 |
+
"vision_tower.semantic_encoder.blocks.2.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 648 |
+
"vision_tower.semantic_encoder.blocks.2.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 649 |
+
"vision_tower.semantic_encoder.blocks.2.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 650 |
+
"vision_tower.semantic_encoder.blocks.2.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 651 |
+
"vision_tower.semantic_encoder.blocks.2.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 652 |
+
"vision_tower.semantic_encoder.blocks.2.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 653 |
+
"vision_tower.semantic_encoder.blocks.20.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 654 |
+
"vision_tower.semantic_encoder.blocks.20.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 655 |
+
"vision_tower.semantic_encoder.blocks.20.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 656 |
+
"vision_tower.semantic_encoder.blocks.20.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 657 |
+
"vision_tower.semantic_encoder.blocks.20.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 658 |
+
"vision_tower.semantic_encoder.blocks.20.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 659 |
+
"vision_tower.semantic_encoder.blocks.20.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 660 |
+
"vision_tower.semantic_encoder.blocks.20.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 661 |
+
"vision_tower.semantic_encoder.blocks.20.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 662 |
+
"vision_tower.semantic_encoder.blocks.20.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 663 |
+
"vision_tower.semantic_encoder.blocks.20.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 664 |
+
"vision_tower.semantic_encoder.blocks.20.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 665 |
+
"vision_tower.semantic_encoder.blocks.21.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 666 |
+
"vision_tower.semantic_encoder.blocks.21.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 667 |
+
"vision_tower.semantic_encoder.blocks.21.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 668 |
+
"vision_tower.semantic_encoder.blocks.21.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 669 |
+
"vision_tower.semantic_encoder.blocks.21.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 670 |
+
"vision_tower.semantic_encoder.blocks.21.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 671 |
+
"vision_tower.semantic_encoder.blocks.21.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 672 |
+
"vision_tower.semantic_encoder.blocks.21.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 673 |
+
"vision_tower.semantic_encoder.blocks.21.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 674 |
+
"vision_tower.semantic_encoder.blocks.21.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 675 |
+
"vision_tower.semantic_encoder.blocks.21.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 676 |
+
"vision_tower.semantic_encoder.blocks.21.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 677 |
+
"vision_tower.semantic_encoder.blocks.22.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 678 |
+
"vision_tower.semantic_encoder.blocks.22.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 679 |
+
"vision_tower.semantic_encoder.blocks.22.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 680 |
+
"vision_tower.semantic_encoder.blocks.22.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 681 |
+
"vision_tower.semantic_encoder.blocks.22.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 682 |
+
"vision_tower.semantic_encoder.blocks.22.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 683 |
+
"vision_tower.semantic_encoder.blocks.22.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 684 |
+
"vision_tower.semantic_encoder.blocks.22.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 685 |
+
"vision_tower.semantic_encoder.blocks.22.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 686 |
+
"vision_tower.semantic_encoder.blocks.22.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 687 |
+
"vision_tower.semantic_encoder.blocks.22.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 688 |
+
"vision_tower.semantic_encoder.blocks.22.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 689 |
+
"vision_tower.semantic_encoder.blocks.23.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 690 |
+
"vision_tower.semantic_encoder.blocks.23.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 691 |
+
"vision_tower.semantic_encoder.blocks.23.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 692 |
+
"vision_tower.semantic_encoder.blocks.23.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 693 |
+
"vision_tower.semantic_encoder.blocks.23.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 694 |
+
"vision_tower.semantic_encoder.blocks.23.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 695 |
+
"vision_tower.semantic_encoder.blocks.23.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 696 |
+
"vision_tower.semantic_encoder.blocks.23.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 697 |
+
"vision_tower.semantic_encoder.blocks.23.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 698 |
+
"vision_tower.semantic_encoder.blocks.23.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 699 |
+
"vision_tower.semantic_encoder.blocks.23.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 700 |
+
"vision_tower.semantic_encoder.blocks.23.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 701 |
+
"vision_tower.semantic_encoder.blocks.24.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 702 |
+
"vision_tower.semantic_encoder.blocks.24.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 703 |
+
"vision_tower.semantic_encoder.blocks.24.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 704 |
+
"vision_tower.semantic_encoder.blocks.24.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 705 |
+
"vision_tower.semantic_encoder.blocks.24.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 706 |
+
"vision_tower.semantic_encoder.blocks.24.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 707 |
+
"vision_tower.semantic_encoder.blocks.24.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 708 |
+
"vision_tower.semantic_encoder.blocks.24.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 709 |
+
"vision_tower.semantic_encoder.blocks.24.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 710 |
+
"vision_tower.semantic_encoder.blocks.24.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 711 |
+
"vision_tower.semantic_encoder.blocks.24.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 712 |
+
"vision_tower.semantic_encoder.blocks.24.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 713 |
+
"vision_tower.semantic_encoder.blocks.25.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 714 |
+
"vision_tower.semantic_encoder.blocks.25.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 715 |
+
"vision_tower.semantic_encoder.blocks.25.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 716 |
+
"vision_tower.semantic_encoder.blocks.25.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 717 |
+
"vision_tower.semantic_encoder.blocks.25.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 718 |
+
"vision_tower.semantic_encoder.blocks.25.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 719 |
+
"vision_tower.semantic_encoder.blocks.25.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 720 |
+
"vision_tower.semantic_encoder.blocks.25.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 721 |
+
"vision_tower.semantic_encoder.blocks.25.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 722 |
+
"vision_tower.semantic_encoder.blocks.25.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 723 |
+
"vision_tower.semantic_encoder.blocks.25.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 724 |
+
"vision_tower.semantic_encoder.blocks.25.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 725 |
+
"vision_tower.semantic_encoder.blocks.26.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 726 |
+
"vision_tower.semantic_encoder.blocks.26.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 727 |
+
"vision_tower.semantic_encoder.blocks.26.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 728 |
+
"vision_tower.semantic_encoder.blocks.26.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 729 |
+
"vision_tower.semantic_encoder.blocks.26.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 730 |
+
"vision_tower.semantic_encoder.blocks.26.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 731 |
+
"vision_tower.semantic_encoder.blocks.26.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 732 |
+
"vision_tower.semantic_encoder.blocks.26.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 733 |
+
"vision_tower.semantic_encoder.blocks.26.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 734 |
+
"vision_tower.semantic_encoder.blocks.26.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 735 |
+
"vision_tower.semantic_encoder.blocks.26.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 736 |
+
"vision_tower.semantic_encoder.blocks.26.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 737 |
+
"vision_tower.semantic_encoder.blocks.27.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 738 |
+
"vision_tower.semantic_encoder.blocks.27.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 739 |
+
"vision_tower.semantic_encoder.blocks.27.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 740 |
+
"vision_tower.semantic_encoder.blocks.27.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 741 |
+
"vision_tower.semantic_encoder.blocks.27.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 742 |
+
"vision_tower.semantic_encoder.blocks.27.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 743 |
+
"vision_tower.semantic_encoder.blocks.27.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 744 |
+
"vision_tower.semantic_encoder.blocks.27.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 745 |
+
"vision_tower.semantic_encoder.blocks.27.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 746 |
+
"vision_tower.semantic_encoder.blocks.27.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 747 |
+
"vision_tower.semantic_encoder.blocks.27.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 748 |
+
"vision_tower.semantic_encoder.blocks.27.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 749 |
+
"vision_tower.semantic_encoder.blocks.28.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 750 |
+
"vision_tower.semantic_encoder.blocks.28.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 751 |
+
"vision_tower.semantic_encoder.blocks.28.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 752 |
+
"vision_tower.semantic_encoder.blocks.28.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 753 |
+
"vision_tower.semantic_encoder.blocks.28.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 754 |
+
"vision_tower.semantic_encoder.blocks.28.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 755 |
+
"vision_tower.semantic_encoder.blocks.28.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 756 |
+
"vision_tower.semantic_encoder.blocks.28.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 757 |
+
"vision_tower.semantic_encoder.blocks.28.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 758 |
+
"vision_tower.semantic_encoder.blocks.28.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 759 |
+
"vision_tower.semantic_encoder.blocks.28.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 760 |
+
"vision_tower.semantic_encoder.blocks.28.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 761 |
+
"vision_tower.semantic_encoder.blocks.29.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 762 |
+
"vision_tower.semantic_encoder.blocks.29.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 763 |
+
"vision_tower.semantic_encoder.blocks.29.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 764 |
+
"vision_tower.semantic_encoder.blocks.29.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 765 |
+
"vision_tower.semantic_encoder.blocks.29.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 766 |
+
"vision_tower.semantic_encoder.blocks.29.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 767 |
+
"vision_tower.semantic_encoder.blocks.29.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 768 |
+
"vision_tower.semantic_encoder.blocks.29.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 769 |
+
"vision_tower.semantic_encoder.blocks.29.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 770 |
+
"vision_tower.semantic_encoder.blocks.29.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 771 |
+
"vision_tower.semantic_encoder.blocks.29.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 772 |
+
"vision_tower.semantic_encoder.blocks.29.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 773 |
+
"vision_tower.semantic_encoder.blocks.3.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 774 |
+
"vision_tower.semantic_encoder.blocks.3.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 775 |
+
"vision_tower.semantic_encoder.blocks.3.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 776 |
+
"vision_tower.semantic_encoder.blocks.3.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 777 |
+
"vision_tower.semantic_encoder.blocks.3.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 778 |
+
"vision_tower.semantic_encoder.blocks.3.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 779 |
+
"vision_tower.semantic_encoder.blocks.3.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 780 |
+
"vision_tower.semantic_encoder.blocks.3.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 781 |
+
"vision_tower.semantic_encoder.blocks.3.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 782 |
+
"vision_tower.semantic_encoder.blocks.3.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 783 |
+
"vision_tower.semantic_encoder.blocks.3.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 784 |
+
"vision_tower.semantic_encoder.blocks.3.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 785 |
+
"vision_tower.semantic_encoder.blocks.30.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 786 |
+
"vision_tower.semantic_encoder.blocks.30.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 787 |
+
"vision_tower.semantic_encoder.blocks.30.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 788 |
+
"vision_tower.semantic_encoder.blocks.30.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 789 |
+
"vision_tower.semantic_encoder.blocks.30.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 790 |
+
"vision_tower.semantic_encoder.blocks.30.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 791 |
+
"vision_tower.semantic_encoder.blocks.30.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 792 |
+
"vision_tower.semantic_encoder.blocks.30.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 793 |
+
"vision_tower.semantic_encoder.blocks.30.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 794 |
+
"vision_tower.semantic_encoder.blocks.30.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 795 |
+
"vision_tower.semantic_encoder.blocks.30.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 796 |
+
"vision_tower.semantic_encoder.blocks.30.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 797 |
+
"vision_tower.semantic_encoder.blocks.31.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 798 |
+
"vision_tower.semantic_encoder.blocks.31.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 799 |
+
"vision_tower.semantic_encoder.blocks.31.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 800 |
+
"vision_tower.semantic_encoder.blocks.31.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 801 |
+
"vision_tower.semantic_encoder.blocks.31.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 802 |
+
"vision_tower.semantic_encoder.blocks.31.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 803 |
+
"vision_tower.semantic_encoder.blocks.31.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 804 |
+
"vision_tower.semantic_encoder.blocks.31.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 805 |
+
"vision_tower.semantic_encoder.blocks.31.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 806 |
+
"vision_tower.semantic_encoder.blocks.31.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 807 |
+
"vision_tower.semantic_encoder.blocks.31.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 808 |
+
"vision_tower.semantic_encoder.blocks.31.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 809 |
+
"vision_tower.semantic_encoder.blocks.4.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 810 |
+
"vision_tower.semantic_encoder.blocks.4.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 811 |
+
"vision_tower.semantic_encoder.blocks.4.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 812 |
+
"vision_tower.semantic_encoder.blocks.4.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 813 |
+
"vision_tower.semantic_encoder.blocks.4.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 814 |
+
"vision_tower.semantic_encoder.blocks.4.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 815 |
+
"vision_tower.semantic_encoder.blocks.4.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 816 |
+
"vision_tower.semantic_encoder.blocks.4.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 817 |
+
"vision_tower.semantic_encoder.blocks.4.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 818 |
+
"vision_tower.semantic_encoder.blocks.4.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 819 |
+
"vision_tower.semantic_encoder.blocks.4.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 820 |
+
"vision_tower.semantic_encoder.blocks.4.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 821 |
+
"vision_tower.semantic_encoder.blocks.5.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 822 |
+
"vision_tower.semantic_encoder.blocks.5.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 823 |
+
"vision_tower.semantic_encoder.blocks.5.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 824 |
+
"vision_tower.semantic_encoder.blocks.5.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 825 |
+
"vision_tower.semantic_encoder.blocks.5.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 826 |
+
"vision_tower.semantic_encoder.blocks.5.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 827 |
+
"vision_tower.semantic_encoder.blocks.5.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 828 |
+
"vision_tower.semantic_encoder.blocks.5.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 829 |
+
"vision_tower.semantic_encoder.blocks.5.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 830 |
+
"vision_tower.semantic_encoder.blocks.5.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 831 |
+
"vision_tower.semantic_encoder.blocks.5.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 832 |
+
"vision_tower.semantic_encoder.blocks.5.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 833 |
+
"vision_tower.semantic_encoder.blocks.6.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 834 |
+
"vision_tower.semantic_encoder.blocks.6.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 835 |
+
"vision_tower.semantic_encoder.blocks.6.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 836 |
+
"vision_tower.semantic_encoder.blocks.6.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 837 |
+
"vision_tower.semantic_encoder.blocks.6.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 838 |
+
"vision_tower.semantic_encoder.blocks.6.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 839 |
+
"vision_tower.semantic_encoder.blocks.6.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 840 |
+
"vision_tower.semantic_encoder.blocks.6.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 841 |
+
"vision_tower.semantic_encoder.blocks.6.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 842 |
+
"vision_tower.semantic_encoder.blocks.6.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 843 |
+
"vision_tower.semantic_encoder.blocks.6.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 844 |
+
"vision_tower.semantic_encoder.blocks.6.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 845 |
+
"vision_tower.semantic_encoder.blocks.7.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 846 |
+
"vision_tower.semantic_encoder.blocks.7.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 847 |
+
"vision_tower.semantic_encoder.blocks.7.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 848 |
+
"vision_tower.semantic_encoder.blocks.7.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 849 |
+
"vision_tower.semantic_encoder.blocks.7.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 850 |
+
"vision_tower.semantic_encoder.blocks.7.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 851 |
+
"vision_tower.semantic_encoder.blocks.7.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 852 |
+
"vision_tower.semantic_encoder.blocks.7.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 853 |
+
"vision_tower.semantic_encoder.blocks.7.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 854 |
+
"vision_tower.semantic_encoder.blocks.7.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 855 |
+
"vision_tower.semantic_encoder.blocks.7.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 856 |
+
"vision_tower.semantic_encoder.blocks.7.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 857 |
+
"vision_tower.semantic_encoder.blocks.8.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 858 |
+
"vision_tower.semantic_encoder.blocks.8.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 859 |
+
"vision_tower.semantic_encoder.blocks.8.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 860 |
+
"vision_tower.semantic_encoder.blocks.8.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 861 |
+
"vision_tower.semantic_encoder.blocks.8.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 862 |
+
"vision_tower.semantic_encoder.blocks.8.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 863 |
+
"vision_tower.semantic_encoder.blocks.8.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 864 |
+
"vision_tower.semantic_encoder.blocks.8.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 865 |
+
"vision_tower.semantic_encoder.blocks.8.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 866 |
+
"vision_tower.semantic_encoder.blocks.8.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 867 |
+
"vision_tower.semantic_encoder.blocks.8.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 868 |
+
"vision_tower.semantic_encoder.blocks.8.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 869 |
+
"vision_tower.semantic_encoder.blocks.9.attn.proj.bias": "pytorch_model-00001-of-00004.bin",
|
| 870 |
+
"vision_tower.semantic_encoder.blocks.9.attn.proj.weight": "pytorch_model-00001-of-00004.bin",
|
| 871 |
+
"vision_tower.semantic_encoder.blocks.9.attn.qkv.bias": "pytorch_model-00001-of-00004.bin",
|
| 872 |
+
"vision_tower.semantic_encoder.blocks.9.attn.qkv.weight": "pytorch_model-00001-of-00004.bin",
|
| 873 |
+
"vision_tower.semantic_encoder.blocks.9.mlp.fc1.bias": "pytorch_model-00001-of-00004.bin",
|
| 874 |
+
"vision_tower.semantic_encoder.blocks.9.mlp.fc1.weight": "pytorch_model-00001-of-00004.bin",
|
| 875 |
+
"vision_tower.semantic_encoder.blocks.9.mlp.fc2.bias": "pytorch_model-00001-of-00004.bin",
|
| 876 |
+
"vision_tower.semantic_encoder.blocks.9.mlp.fc2.weight": "pytorch_model-00001-of-00004.bin",
|
| 877 |
+
"vision_tower.semantic_encoder.blocks.9.norm1.bias": "pytorch_model-00001-of-00004.bin",
|
| 878 |
+
"vision_tower.semantic_encoder.blocks.9.norm1.weight": "pytorch_model-00001-of-00004.bin",
|
| 879 |
+
"vision_tower.semantic_encoder.blocks.9.norm2.bias": "pytorch_model-00001-of-00004.bin",
|
| 880 |
+
"vision_tower.semantic_encoder.blocks.9.norm2.weight": "pytorch_model-00001-of-00004.bin",
|
| 881 |
+
"vision_tower.semantic_encoder.merger.ln_q.bias": "pytorch_model-00001-of-00004.bin",
|
| 882 |
+
"vision_tower.semantic_encoder.merger.ln_q.weight": "pytorch_model-00001-of-00004.bin",
|
| 883 |
+
"vision_tower.semantic_encoder.merger.mlp.0.bias": "pytorch_model-00001-of-00004.bin",
|
| 884 |
+
"vision_tower.semantic_encoder.merger.mlp.0.weight": "pytorch_model-00001-of-00004.bin",
|
| 885 |
+
"vision_tower.semantic_encoder.merger.mlp.2.bias": "pytorch_model-00001-of-00004.bin",
|
| 886 |
+
"vision_tower.semantic_encoder.merger.mlp.2.weight": "pytorch_model-00001-of-00004.bin",
|
| 887 |
+
"vision_tower.semantic_encoder.patch_embed.proj.weight": "pytorch_model-00001-of-00004.bin"
|
| 888 |
+
}
|
| 889 |
+
}
|
sdxl_decoder_pipe.py
ADDED
|
@@ -0,0 +1,901 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modify from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
|
| 2 |
+
import inspect
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from einops import repeat, rearrange
|
| 12 |
+
|
| 13 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 14 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 15 |
+
|
| 16 |
+
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
| 17 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 18 |
+
|
| 19 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 20 |
+
import PIL.Image
|
| 21 |
+
|
| 22 |
+
from diffusers.models.attention_processor import (
|
| 23 |
+
AttnProcessor2_0,
|
| 24 |
+
FusedAttnProcessor2_0,
|
| 25 |
+
XFormersAttnProcessor,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
from diffusers.utils import (
|
| 29 |
+
USE_PEFT_BACKEND,
|
| 30 |
+
deprecate,
|
| 31 |
+
is_invisible_watermark_available,
|
| 32 |
+
is_torch_xla_available,
|
| 33 |
+
logging,
|
| 34 |
+
replace_example_docstring,
|
| 35 |
+
scale_lora_layers,
|
| 36 |
+
unscale_lora_layers,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 40 |
+
from diffusers.loaders import (
|
| 41 |
+
FromSingleFileMixin,
|
| 42 |
+
IPAdapterMixin,
|
| 43 |
+
StableDiffusionXLLoraLoaderMixin,
|
| 44 |
+
TextualInversionLoaderMixin,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
if is_invisible_watermark_available():
|
| 48 |
+
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
| 49 |
+
|
| 50 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
| 51 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline, \
|
| 52 |
+
retrieve_timesteps, rescale_noise_cfg
|
| 53 |
+
|
| 54 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode
|
| 55 |
+
|
| 56 |
+
if is_torch_xla_available():
|
| 57 |
+
import torch_xla.core.xla_model as xm
|
| 58 |
+
|
| 59 |
+
XLA_AVAILABLE = True
|
| 60 |
+
else:
|
| 61 |
+
XLA_AVAILABLE = False
|
| 62 |
+
|
| 63 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class StableDiffusionXLDecoderPipelineOutput(StableDiffusionXLPipelineOutput):
|
| 68 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 69 |
+
indices_semantic: Optional[torch.Tensor] = None
|
| 70 |
+
indices_pixel: Optional[torch.Tensor] = None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def expand_dims_like(x, y):
|
| 74 |
+
while x.dim() != y.dim():
|
| 75 |
+
x = x.unsqueeze(-1)
|
| 76 |
+
return x
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class AbstractEmbModel(nn.Module):
|
| 80 |
+
def __init__(self):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self._is_trainable = None
|
| 83 |
+
self._ucg_rate = None
|
| 84 |
+
self._input_key = None
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def is_trainable(self) -> bool:
|
| 88 |
+
return self._is_trainable
|
| 89 |
+
|
| 90 |
+
@property
|
| 91 |
+
def ucg_rate(self) -> Union[float, torch.Tensor]:
|
| 92 |
+
return self._ucg_rate
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def input_key(self) -> str:
|
| 96 |
+
return self._input_key
|
| 97 |
+
|
| 98 |
+
@is_trainable.setter
|
| 99 |
+
def is_trainable(self, value: bool):
|
| 100 |
+
self._is_trainable = value
|
| 101 |
+
|
| 102 |
+
@ucg_rate.setter
|
| 103 |
+
def ucg_rate(self, value: Union[float, torch.Tensor]):
|
| 104 |
+
self._ucg_rate = value
|
| 105 |
+
|
| 106 |
+
@input_key.setter
|
| 107 |
+
def input_key(self, value: str):
|
| 108 |
+
self._input_key = value
|
| 109 |
+
|
| 110 |
+
@is_trainable.deleter
|
| 111 |
+
def is_trainable(self):
|
| 112 |
+
del self._is_trainable
|
| 113 |
+
|
| 114 |
+
@ucg_rate.deleter
|
| 115 |
+
def ucg_rate(self):
|
| 116 |
+
del self._ucg_rate
|
| 117 |
+
|
| 118 |
+
@input_key.deleter
|
| 119 |
+
def input_key(self):
|
| 120 |
+
del self._input_key
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DualViTok2ImageEmbedder(AbstractEmbModel):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
image_processor=None,
|
| 127 |
+
vq_model=None,
|
| 128 |
+
device="cuda",
|
| 129 |
+
dtype=torch.float32,
|
| 130 |
+
freeze=True,
|
| 131 |
+
image_size=0,
|
| 132 |
+
resize_factor=1,
|
| 133 |
+
not_bicubic=True,
|
| 134 |
+
return_sequence=False,
|
| 135 |
+
grid_feature_scale=1,
|
| 136 |
+
texture_drop_prob=0,
|
| 137 |
+
semantic_drop_prob=0,
|
| 138 |
+
pixel_channel=32,
|
| 139 |
+
semantic_channel=32,
|
| 140 |
+
):
|
| 141 |
+
super().__init__()
|
| 142 |
+
vq_model.to(device=device, dtype=dtype)
|
| 143 |
+
vq_model.eval()
|
| 144 |
+
|
| 145 |
+
self.processor = image_processor
|
| 146 |
+
|
| 147 |
+
self.model = vq_model
|
| 148 |
+
self.device = device
|
| 149 |
+
if freeze:
|
| 150 |
+
self.freeze()
|
| 151 |
+
|
| 152 |
+
if image_size > 0:
|
| 153 |
+
preprocessor = [
|
| 154 |
+
Resize(image_size) if not_bicubic else Resize(image_size, interpolation=InterpolationMode.BICUBIC)]
|
| 155 |
+
preprocessor += [
|
| 156 |
+
CenterCrop(image_size),
|
| 157 |
+
]
|
| 158 |
+
self.preprocessor = Compose(preprocessor)
|
| 159 |
+
self.image_size = image_size
|
| 160 |
+
self.resize_factor = resize_factor
|
| 161 |
+
self.not_bicubic = not_bicubic
|
| 162 |
+
self.return_sequence = return_sequence
|
| 163 |
+
self.grid_feature_scale = grid_feature_scale
|
| 164 |
+
self.texture_drop_prob = texture_drop_prob
|
| 165 |
+
self.semantic_drop_prob = semantic_drop_prob
|
| 166 |
+
self.pixel_channel = pixel_channel
|
| 167 |
+
self.semantic_channel = semantic_channel
|
| 168 |
+
|
| 169 |
+
def freeze(self):
|
| 170 |
+
self.model = self.model.eval()
|
| 171 |
+
for param in self.parameters():
|
| 172 |
+
param.requires_grad = False
|
| 173 |
+
|
| 174 |
+
def vq_encode(self, image):
|
| 175 |
+
if image.ndim == 5:
|
| 176 |
+
assert image.size(1) == 1
|
| 177 |
+
image = image.squeeze(1)
|
| 178 |
+
bs, _, h, w = image.shape
|
| 179 |
+
|
| 180 |
+
if self.image_size > 0:
|
| 181 |
+
image = self.preprocessor(image)
|
| 182 |
+
else:
|
| 183 |
+
assert self.resize_factor > 0
|
| 184 |
+
preprocessor = Resize((int(h * self.resize_factor), int(w * self.resize_factor))) if self.not_bicubic else \
|
| 185 |
+
Resize((int(h * self.resize_factor), int(w * self.resize_factor)),
|
| 186 |
+
interpolation=InterpolationMode.BICUBIC)
|
| 187 |
+
image = preprocessor(image)
|
| 188 |
+
|
| 189 |
+
inputs = dict(image=image)
|
| 190 |
+
inputs = self.model.get_input(inputs)
|
| 191 |
+
|
| 192 |
+
(quant_semantic, diff_semantic, indices_semantic, target_semantic), \
|
| 193 |
+
(quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
|
| 194 |
+
return indices_semantic, indices_pixel
|
| 195 |
+
|
| 196 |
+
def vq_encode_code(self, image):
|
| 197 |
+
# vq_encode already returns the two index tensors, so unpack them directly
|
| 198 |
+
indices_semantic, indices_pixel = self.vq_encode(image)
|
| 199 |
+
return indices_semantic, indices_pixel
|
| 200 |
+
|
| 201 |
+
def vq_decode_code(self, indices_semantic, indices_pixel):
|
| 202 |
+
return self.model.decode_code(indices_semantic, indices_pixel)
|
| 203 |
+
|
| 204 |
+
def forward(self, image, return_indices=False):
|
| 205 |
+
if image.ndim == 5:
|
| 206 |
+
assert image.size(1) == 1
|
| 207 |
+
image = image.squeeze(1)
|
| 208 |
+
bs, _, h, w = image.shape
|
| 209 |
+
|
| 210 |
+
if self.image_size > 0:
|
| 211 |
+
image = self.preprocessor(image)
|
| 212 |
+
else:
|
| 213 |
+
assert self.resize_factor > 0
|
| 214 |
+
preprocessor = Resize((int(h * self.resize_factor), int(w * self.resize_factor))) if self.not_bicubic else \
|
| 215 |
+
Resize((int(h * self.resize_factor), int(w * self.resize_factor)),
|
| 216 |
+
interpolation=InterpolationMode.BICUBIC)
|
| 217 |
+
image = preprocessor(image)
|
| 218 |
+
|
| 219 |
+
inputs = dict(image=image)
|
| 220 |
+
inputs = self.model.get_input(inputs)
|
| 221 |
+
|
| 222 |
+
(quant_semantic, diff_semantic, indices_semantic, target_semantic), \
|
| 223 |
+
(quant_pixel, diff_pixel, indices_pixel) = self.model.encode(**inputs)
|
| 224 |
+
|
| 225 |
+
feature = self.model.merge_quants(quant_semantic, quant_pixel)
|
| 226 |
+
|
| 227 |
+
if self.return_sequence:
|
| 228 |
+
feature = rearrange(feature, 'b c h w -> b h w c')
|
| 229 |
+
_, this_h, this_w, _ = feature.shape
|
| 230 |
+
feature = feature.view(bs, this_h * this_w, -1)
|
| 231 |
+
else:
|
| 232 |
+
feature = feature * self.grid_feature_scale
|
| 233 |
+
|
| 234 |
+
if return_indices:
|
| 235 |
+
return feature, indices_semantic, indices_pixel
|
| 236 |
+
|
| 237 |
+
return feature
|
| 238 |
+
|
| 239 |
+
def encode(self, img):
|
| 240 |
+
return self(img)
|
| 241 |
+
|
| 242 |
+
def indices_to_codes(self, semantic_indices, texture_indices):
|
| 243 |
+
quant_semantic, quant_texture = self.model.indices_to_codes(semantic_indices, texture_indices)
|
| 244 |
+
return self.model.merge_quants(quant_semantic, quant_texture)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class StableDiffusionXLDecoderPipeline(
|
| 248 |
+
DiffusionPipeline,
|
| 249 |
+
StableDiffusionMixin,
|
| 250 |
+
FromSingleFileMixin,
|
| 251 |
+
StableDiffusionXLLoraLoaderMixin,
|
| 252 |
+
TextualInversionLoaderMixin,
|
| 253 |
+
):
|
| 254 |
+
model_cpu_offload_seq = "vq_model_embedder->unet->vae"
|
| 255 |
+
_optional_components = [
|
| 256 |
+
"vq_model_embedder",
|
| 257 |
+
]
|
| 258 |
+
_callback_tensor_inputs = [
|
| 259 |
+
"latents",
|
| 260 |
+
"prompt_embeds",
|
| 261 |
+
"negative_prompt_embeds",
|
| 262 |
+
"add_text_embeds",
|
| 263 |
+
"add_time_ids",
|
| 264 |
+
"negative_pooled_prompt_embeds",
|
| 265 |
+
"negative_add_time_ids",
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
def __init__(
|
| 269 |
+
self,
|
| 270 |
+
vae: AutoencoderKL,
|
| 271 |
+
unet: UNet2DConditionModel,
|
| 272 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 273 |
+
force_zeros_for_empty_prompt: bool = True,
|
| 274 |
+
add_watermarker: Optional[bool] = None,
|
| 275 |
+
vq_image_processor=None,
|
| 276 |
+
vq_model=None,
|
| 277 |
+
):
|
| 278 |
+
super().__init__()
|
| 279 |
+
|
| 280 |
+
self.register_modules(
|
| 281 |
+
vae=vae,
|
| 282 |
+
unet=unet,
|
| 283 |
+
scheduler=scheduler,
|
| 284 |
+
)
|
| 285 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
| 286 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 287 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 288 |
+
|
| 289 |
+
self.default_sample_size = self.unet.config.sample_size
|
| 290 |
+
|
| 291 |
+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
| 292 |
+
|
| 293 |
+
if add_watermarker:
|
| 294 |
+
self.watermark = StableDiffusionXLWatermarker()
|
| 295 |
+
else:
|
| 296 |
+
self.watermark = None
|
| 297 |
+
|
| 298 |
+
self.empty_prompt_embeds = torch.zeros([1, 77, 2048]).to(device=unet.device, dtype=unet.dtype)
|
| 299 |
+
self.empty_pooled_prompt_embeds = torch.zeros([1, 1280]).to(device=unet.device, dtype=unet.dtype)
|
| 300 |
+
self.dualvitok_channels = vq_model.pixel_channel + vq_model.semantic_channel
|
| 301 |
+
|
| 302 |
+
self.resolution_group = ['(1024, 1024)', '(768, 1024)', '(1024, 768)', '(512, 2048)', '(2048, 512)',
|
| 303 |
+
'(640, 1920)', '(1920, 640)', '(768, 1536)', '(1536, 768)', '(768, 1152)',
|
| 304 |
+
'(1152, 768)', '(512, 512)']
|
| 305 |
+
|
| 306 |
+
embedder_kwargs = dict(image_size=0,
|
| 307 |
+
resize_factor=1,
|
| 308 |
+
return_sequence=False,
|
| 309 |
+
grid_feature_scale=1)
|
| 310 |
+
if isinstance(vq_model, DualViTok2ImageEmbedder):
|
| 311 |
+
self.vq_model_embedder = vq_model
|
| 312 |
+
else:
|
| 313 |
+
self.vq_model_embedder = DualViTok2ImageEmbedder(vq_image_processor, vq_model, **embedder_kwargs)
|
| 314 |
+
|
| 315 |
+
def vq_encode(self, image):
|
| 316 |
+
return self.vq_model_embedder.encode(image)
|
| 317 |
+
|
| 318 |
+
def vq_encode_code(self, image):
|
| 319 |
+
return self.vq_model_embedder.vq_encode_code(image)
|
| 320 |
+
|
| 321 |
+
def vq_decode_code(self, *args, **kwargs):
|
| 322 |
+
return self.vq_model_embedder.vq_decode_code(*args, **kwargs)
|
| 323 |
+
|
| 324 |
+
def indices_to_codes(self, *args, **kwargs):
|
| 325 |
+
return self.vq_model_embedder.indices_to_codes(*args, **kwargs)
|
| 326 |
+
|
| 327 |
+
def _get_add_time_ids(
|
| 328 |
+
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None,
|
| 329 |
+
resolution_index=None,
|
| 330 |
+
):
|
| 331 |
+
add_time_ids = [resolution_index] * 6
|
| 332 |
+
|
| 333 |
+
passed_add_embed_dim = (
|
| 334 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
| 335 |
+
)
|
| 336 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
| 337 |
+
|
| 338 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
| 339 |
+
raise ValueError(
|
| 340 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
| 344 |
+
return add_time_ids
|
| 345 |
+
|
| 346 |
+
def check_inputs(
|
| 347 |
+
self,
|
| 348 |
+
height,
|
| 349 |
+
width,
|
| 350 |
+
callback_steps,
|
| 351 |
+
callback_on_step_end_tensor_inputs=None,
|
| 352 |
+
):
|
| 353 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 354 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 355 |
+
|
| 356 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
| 357 |
+
raise ValueError(
|
| 358 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 359 |
+
f" {type(callback_steps)}."
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 363 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 364 |
+
):
|
| 365 |
+
raise ValueError(
|
| 366 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
| 370 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 371 |
+
shape = (
|
| 372 |
+
batch_size,
|
| 373 |
+
num_channels_latents,
|
| 374 |
+
int(height) // self.vae_scale_factor,
|
| 375 |
+
int(width) // self.vae_scale_factor,
|
| 376 |
+
)
|
| 377 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 378 |
+
raise ValueError(
|
| 379 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 380 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
if latents is None:
|
| 384 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 385 |
+
else:
|
| 386 |
+
latents = latents.to(device)
|
| 387 |
+
|
| 388 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 389 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 390 |
+
return latents
|
| 391 |
+
|
| 392 |
+
def upcast_vae(self):
|
| 393 |
+
dtype = self.vae.dtype
|
| 394 |
+
self.vae.to(dtype=torch.float32)
|
| 395 |
+
use_torch_2_0_or_xformers = isinstance(
|
| 396 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
| 397 |
+
(
|
| 398 |
+
AttnProcessor2_0,
|
| 399 |
+
XFormersAttnProcessor,
|
| 400 |
+
FusedAttnProcessor2_0,
|
| 401 |
+
),
|
| 402 |
+
)
|
| 403 |
+
# if xformers or torch_2_0 is used attention block does not need
|
| 404 |
+
# to be in float32 which can save lots of memory
|
| 405 |
+
if use_torch_2_0_or_xformers:
|
| 406 |
+
self.vae.post_quant_conv.to(dtype)
|
| 407 |
+
self.vae.decoder.conv_in.to(dtype)
|
| 408 |
+
self.vae.decoder.mid_block.to(dtype)
|
| 409 |
+
|
| 410 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
| 411 |
+
def get_guidance_scale_embedding(
|
| 412 |
+
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
| 413 |
+
) -> torch.Tensor:
|
| 414 |
+
"""
|
| 415 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
w (`torch.Tensor`):
|
| 419 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
| 420 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
| 421 |
+
Dimension of the embeddings to generate.
|
| 422 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
| 423 |
+
Data type of the generated embeddings.
|
| 424 |
+
|
| 425 |
+
Returns:
|
| 426 |
+
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
| 427 |
+
"""
|
| 428 |
+
assert len(w.shape) == 1
|
| 429 |
+
w = w * 1000.0
|
| 430 |
+
|
| 431 |
+
half_dim = embedding_dim // 2
|
| 432 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 433 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 434 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 435 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 436 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 437 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 438 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
| 439 |
+
return emb
|
| 440 |
+
|
| 441 |
+
@property
|
| 442 |
+
def guidance_scale(self):
|
| 443 |
+
return self._guidance_scale
|
| 444 |
+
|
| 445 |
+
@property
|
| 446 |
+
def guidance_rescale(self):
|
| 447 |
+
return self._guidance_rescale
|
| 448 |
+
|
| 449 |
+
@property
|
| 450 |
+
def clip_skip(self):
|
| 451 |
+
return self._clip_skip
|
| 452 |
+
|
| 453 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 454 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 455 |
+
# corresponds to doing no classifier free guidance.
|
| 456 |
+
@property
|
| 457 |
+
def do_classifier_free_guidance(self):
|
| 458 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
| 459 |
+
|
| 460 |
+
@property
|
| 461 |
+
def cross_attention_kwargs(self):
|
| 462 |
+
return self._cross_attention_kwargs
|
| 463 |
+
|
| 464 |
+
@property
|
| 465 |
+
def denoising_end(self):
|
| 466 |
+
return self._denoising_end
|
| 467 |
+
|
| 468 |
+
@property
|
| 469 |
+
def num_timesteps(self):
|
| 470 |
+
return self._num_timesteps
|
| 471 |
+
|
| 472 |
+
@property
|
| 473 |
+
def interrupt(self):
|
| 474 |
+
return self._interrupt
|
| 475 |
+
|
| 476 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 477 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 478 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 479 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 480 |
+
# and should be between [0, 1]
|
| 481 |
+
|
| 482 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 483 |
+
extra_step_kwargs = {}
|
| 484 |
+
if accepts_eta:
|
| 485 |
+
extra_step_kwargs["eta"] = eta
|
| 486 |
+
|
| 487 |
+
# check if the scheduler accepts generator
|
| 488 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 489 |
+
if accepts_generator:
|
| 490 |
+
extra_step_kwargs["generator"] = generator
|
| 491 |
+
return extra_step_kwargs
|
| 492 |
+
|
| 493 |
+
@torch.no_grad()
|
| 494 |
+
def __call__(
|
| 495 |
+
self,
|
| 496 |
+
vq_indices: Optional[List] = None,
|
| 497 |
+
vq_embeds: Optional[torch.Tensor] = None,
|
| 498 |
+
images: Optional[PipelineImageInput] = None,
|
| 499 |
+
height: Optional[int] = None,
|
| 500 |
+
width: Optional[int] = None,
|
| 501 |
+
num_inference_steps: int = 50,
|
| 502 |
+
timesteps: List[int] = None,
|
| 503 |
+
sigmas: List[float] = None,
|
| 504 |
+
denoising_end: Optional[float] = None,
|
| 505 |
+
guidance_scale: float = 2.0,
|
| 506 |
+
eta: float = 0.0,
|
| 507 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 508 |
+
latents: Optional[torch.Tensor] = None,
|
| 509 |
+
output_type: Optional[str] = "pil",
|
| 510 |
+
return_dict: bool = True,
|
| 511 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 512 |
+
guidance_rescale: float = 0.0,
|
| 513 |
+
original_size: Optional[Tuple[int, int]] = None,
|
| 514 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 515 |
+
target_size: Optional[Tuple[int, int]] = None,
|
| 516 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
| 517 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 518 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
| 519 |
+
clip_skip: Optional[int] = None,
|
| 520 |
+
callback_on_step_end: Optional[
|
| 521 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 522 |
+
] = None,
|
| 523 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 524 |
+
**kwargs,
|
| 525 |
+
):
|
| 526 |
+
r"""
|
| 527 |
+
Function invoked when calling the pipeline for generation.
|
| 528 |
+
|
| 529 |
+
Args:
|
| 530 |
+
vq_indices (`Optional[List]`, *optional*):
|
| 531 |
+
The VQ indices for semantic and pixel tokens. Should be a tuple of (semantic_indices, pixel_indices).
|
| 532 |
+
images (`Optional[PipelineImageInput]`, *optional*):
|
| 533 |
+
Input images in range [-1, 1] as torch.Tensor with shape (batch_size, channels, height, width).
|
| 534 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 535 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 536 |
+
Anything below 512 pixels won't work well for
|
| 537 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 538 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 539 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 540 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 541 |
+
Anything below 512 pixels won't work well for
|
| 542 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 543 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 544 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 545 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 546 |
+
expense of slower inference.
|
| 547 |
+
timesteps (`List[int]`, *optional*):
|
| 548 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 549 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 550 |
+
passed will be used. Must be in descending order.
|
| 551 |
+
sigmas (`List[float]`, *optional*):
|
| 552 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 553 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 554 |
+
will be used.
|
| 555 |
+
denoising_end (`float`, *optional*):
|
| 556 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
| 557 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
| 558 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
| 559 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
| 560 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
| 561 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
| 562 |
+
guidance_scale (`float`, *optional*, defaults to 2.0):
|
| 563 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 564 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 565 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 566 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
| 567 |
+
usually at the expense of lower image quality.
|
| 568 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 569 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 570 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 571 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 572 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 573 |
+
to make generation deterministic.
|
| 574 |
+
latents (`torch.Tensor`, *optional*):
|
| 575 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 576 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 577 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 578 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 579 |
+
The output format of the generated image. Choose between
|
| 580 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 581 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 582 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 583 |
+
of a plain tuple.
|
| 584 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 585 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 586 |
+
`self.processor` in
|
| 587 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 588 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 589 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
| 590 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
| 591 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
| 592 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
| 593 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 594 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
| 595 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
| 596 |
+
explained in section 2.2 of
|
| 597 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 598 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 599 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
| 600 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
| 601 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 602 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 603 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 604 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
| 605 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
| 606 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 607 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 608 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 609 |
+
each denoising step during inference, with the following arguments: `callback_on_step_end(self:
|
| 610 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 611 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 612 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 613 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 614 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 615 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 616 |
+
|
| 617 |
+
Examples:
|
| 618 |
+
|
| 619 |
+
Returns:
|
| 620 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
| 621 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
| 622 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 623 |
+
"""
|
| 624 |
+
|
| 625 |
+
callback = kwargs.pop("callback", None)
|
| 626 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 627 |
+
|
| 628 |
+
if callback is not None:
|
| 629 |
+
deprecate(
|
| 630 |
+
"callback",
|
| 631 |
+
"1.0.0",
|
| 632 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 633 |
+
)
|
| 634 |
+
if callback_steps is not None:
|
| 635 |
+
deprecate(
|
| 636 |
+
"callback_steps",
|
| 637 |
+
"1.0.0",
|
| 638 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 642 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 643 |
+
|
| 644 |
+
# 0. Default height and width to unet
|
| 645 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 646 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 647 |
+
|
| 648 |
+
original_size = original_size or (height, width)
|
| 649 |
+
target_size = target_size or (height, width)
|
| 650 |
+
|
| 651 |
+
# 1. Check inputs. Raise error if not correct
|
| 652 |
+
self.check_inputs(
|
| 653 |
+
height,
|
| 654 |
+
width,
|
| 655 |
+
callback_steps,
|
| 656 |
+
callback_on_step_end_tensor_inputs,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
self._guidance_scale = guidance_scale
|
| 660 |
+
self._guidance_rescale = guidance_rescale
|
| 661 |
+
self._clip_skip = clip_skip
|
| 662 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 663 |
+
self._denoising_end = denoising_end
|
| 664 |
+
self._interrupt = False
|
| 665 |
+
|
| 666 |
+
# 2. encode vq_embeds
|
| 667 |
+
assert images is not None or vq_indices is not None or vq_embeds is not None
|
| 668 |
+
batch_size = len(images) if images is not None else (len(vq_indices[0]) if vq_indices is not None else len(vq_embeds))
|
| 669 |
+
|
| 670 |
+
if images is not None:
|
| 671 |
+
vq_embeds, indices_semantic, indices_pixel = self.vq_model_embedder(images, return_indices=True)
|
| 672 |
+
elif vq_indices is not None:
|
| 673 |
+
indices_semantic, indices_pixel = vq_indices[0], vq_indices[1]
|
| 674 |
+
vq_embeds = self.vq_model_embedder.indices_to_codes(vq_indices[0], vq_indices[1])
|
| 675 |
+
elif vq_embeds is not None:
|
| 676 |
+
if isinstance(vq_embeds, list):
|
| 677 |
+
vq_embeds = self.vq_model_embedder.model.merge_quants(*vq_embeds)  # merge [quant_semantic, quant_pixel]
|
| 678 |
+
indices_semantic, indices_pixel = None, None
|
| 679 |
+
else:
|
| 680 |
+
raise ValueError("No valid input provided")
|
| 681 |
+
|
| 682 |
+
device = self._execution_device
|
| 683 |
+
|
| 684 |
+
# 3. Encode input prompt
|
| 685 |
+
lora_scale = (
|
| 686 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
prompt_embeds = repeat(self.empty_prompt_embeds, '1 l c -> b l c', b=batch_size)
|
| 690 |
+
pooled_prompt_embeds = repeat(self.empty_pooled_prompt_embeds, '1 c -> b c', b=batch_size)
|
| 691 |
+
|
| 692 |
+
negative_prompt_embeds = prompt_embeds
|
| 693 |
+
negative_pooled_prompt_embeds = pooled_prompt_embeds
|
| 694 |
+
|
| 695 |
+
# 4. Prepare timesteps
|
| 696 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 697 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
# 5. Prepare latent variables
|
| 701 |
+
# num_channels_latents = self.unet.config.in_channels
|
| 702 |
+
num_channels_latents = 4
|
| 703 |
+
latents = self.prepare_latents(
|
| 704 |
+
batch_size,
|
| 705 |
+
num_channels_latents,
|
| 706 |
+
height,
|
| 707 |
+
width,
|
| 708 |
+
prompt_embeds.dtype,
|
| 709 |
+
device,
|
| 710 |
+
generator,
|
| 711 |
+
latents,
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 715 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 716 |
+
|
| 717 |
+
# 7. Prepare added time ids & embeddings
|
| 718 |
+
add_text_embeds = pooled_prompt_embeds
|
| 719 |
+
text_encoder_projection_dim = 1280
|
| 720 |
+
|
| 721 |
+
resolution = f'({width}, {height})'
|
| 722 |
+
assert resolution in self.resolution_group, f"resolution is not in the resolution group. Got {resolution}. Candidates: {self.resolution_group}"
|
| 723 |
+
resolution_index = self.resolution_group.index(resolution)
|
| 724 |
+
# resolution_index = None
|
| 725 |
+
|
| 726 |
+
add_time_ids = self._get_add_time_ids(
|
| 727 |
+
original_size,
|
| 728 |
+
crops_coords_top_left,
|
| 729 |
+
target_size,
|
| 730 |
+
dtype=prompt_embeds.dtype,
|
| 731 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 732 |
+
resolution_index=resolution_index,
|
| 733 |
+
)
|
| 734 |
+
if negative_original_size is not None and negative_target_size is not None:
|
| 735 |
+
negative_add_time_ids = self._get_add_time_ids(
|
| 736 |
+
negative_original_size,
|
| 737 |
+
negative_crops_coords_top_left,
|
| 738 |
+
negative_target_size,
|
| 739 |
+
dtype=prompt_embeds.dtype,
|
| 740 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 741 |
+
)
|
| 742 |
+
else:
|
| 743 |
+
negative_add_time_ids = add_time_ids
|
| 744 |
+
|
| 745 |
+
if self.do_classifier_free_guidance:
|
| 746 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 747 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
| 748 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 749 |
+
|
| 750 |
+
prompt_embeds = prompt_embeds.to(device)
|
| 751 |
+
add_text_embeds = add_text_embeds.to(device)
|
| 752 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)
|
| 753 |
+
|
| 754 |
+
# 8. Denoising loop
|
| 755 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 756 |
+
|
| 757 |
+
# 8.1 Apply denoising_end
|
| 758 |
+
if (
|
| 759 |
+
self.denoising_end is not None
|
| 760 |
+
and isinstance(self.denoising_end, float)
|
| 761 |
+
and self.denoising_end > 0
|
| 762 |
+
and self.denoising_end < 1
|
| 763 |
+
):
|
| 764 |
+
discrete_timestep_cutoff = int(
|
| 765 |
+
round(
|
| 766 |
+
self.scheduler.config.num_train_timesteps
|
| 767 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
| 768 |
+
)
|
| 769 |
+
)
|
| 770 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
| 771 |
+
timesteps = timesteps[:num_inference_steps]
|
| 772 |
+
|
| 773 |
+
# 9. Optionally get Guidance Scale Embedding
|
| 774 |
+
timestep_cond = None
|
| 775 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 776 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size)
|
| 777 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 778 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 779 |
+
).to(device=device, dtype=latents.dtype)
|
| 780 |
+
|
| 781 |
+
self._num_timesteps = len(timesteps)
|
| 782 |
+
# with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 783 |
+
for i, t in enumerate(timesteps):
|
| 784 |
+
if self.interrupt:
|
| 785 |
+
continue
|
| 786 |
+
|
| 787 |
+
# expand the latents if we are doing classifier free guidance
|
| 788 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 789 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 790 |
+
|
| 791 |
+
vq_embeds = vq_embeds.to(latent_model_input) if vq_embeds.size(
|
| 792 |
+
-1) == latent_model_input.size(
|
| 793 |
+
-1) else \
|
| 794 |
+
torch.nn.functional.interpolate(vq_embeds.to(latent_model_input),
|
| 795 |
+
size=latent_model_input.shape[-2:])
|
| 796 |
+
vq_embeds_input = torch.cat([torch.zeros_like(vq_embeds),
|
| 797 |
+
vq_embeds]) if self.do_classifier_free_guidance else vq_embeds
|
| 798 |
+
|
| 799 |
+
# predict the noise residual
|
| 800 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
| 801 |
+
|
| 802 |
+
latent_model_input = torch.cat([latent_model_input, vq_embeds_input], dim=1)
|
| 803 |
+
noise_pred = self.unet(
|
| 804 |
+
latent_model_input,
|
| 805 |
+
t,
|
| 806 |
+
encoder_hidden_states=prompt_embeds,
|
| 807 |
+
timestep_cond=timestep_cond,
|
| 808 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 809 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 810 |
+
return_dict=False,
|
| 811 |
+
)[0]
|
| 812 |
+
|
| 813 |
+
# perform guidance
|
| 814 |
+
if self.do_classifier_free_guidance:
|
| 815 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 816 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 817 |
+
|
| 818 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 819 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 820 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=self.guidance_rescale)
|
| 821 |
+
|
| 822 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 823 |
+
latents_dtype = latents.dtype
|
| 824 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 825 |
+
if latents.dtype != latents_dtype:
|
| 826 |
+
if torch.backends.mps.is_available():
|
| 827 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 828 |
+
latents = latents.to(latents_dtype)
|
| 829 |
+
|
| 830 |
+
if callback_on_step_end is not None:
|
| 831 |
+
callback_kwargs = {}
|
| 832 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 833 |
+
callback_kwargs[k] = locals()[k]
|
| 834 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 835 |
+
|
| 836 |
+
latents = callback_outputs.pop("latents", latents)
|
| 837 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 838 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
| 839 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
| 840 |
+
|
| 841 |
+
# call the callback, if provided
|
| 842 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 843 |
+
# progress_bar.update()
|
| 844 |
+
if callback is not None and i % callback_steps == 0:
|
| 845 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 846 |
+
callback(step_idx, t, latents)
|
| 847 |
+
|
| 848 |
+
if XLA_AVAILABLE:
|
| 849 |
+
xm.mark_step()
|
| 850 |
+
|
| 851 |
+
if not output_type == "latent":
|
| 852 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
| 853 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 854 |
+
|
| 855 |
+
if needs_upcasting:
|
| 856 |
+
self.upcast_vae()
|
| 857 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 858 |
+
elif latents.dtype != self.vae.dtype:
|
| 859 |
+
if torch.backends.mps.is_available():
|
| 860 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 861 |
+
self.vae = self.vae.to(latents.dtype)
|
| 862 |
+
|
| 863 |
+
# unscale/denormalize the latents
|
| 864 |
+
# denormalize with the mean and std if available and not None
|
| 865 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
| 866 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
| 867 |
+
if has_latents_mean and has_latents_std:
|
| 868 |
+
latents_mean = (
|
| 869 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
| 870 |
+
)
|
| 871 |
+
latents_std = (
|
| 872 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
| 873 |
+
)
|
| 874 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
| 875 |
+
else:
|
| 876 |
+
latents = latents / self.vae.config.scaling_factor
|
| 877 |
+
|
| 878 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 879 |
+
|
| 880 |
+
# cast back to fp16 if needed
|
| 881 |
+
if needs_upcasting:
|
| 882 |
+
self.vae.to(dtype=torch.float16)
|
| 883 |
+
else:
|
| 884 |
+
image = latents
|
| 885 |
+
|
| 886 |
+
if not output_type == "latent":
|
| 887 |
+
# apply watermark if available
|
| 888 |
+
if self.watermark is not None:
|
| 889 |
+
image = self.watermark.apply_watermark(image)
|
| 890 |
+
|
| 891 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 892 |
+
|
| 893 |
+
# Offload all models
|
| 894 |
+
self.maybe_free_model_hooks()
|
| 895 |
+
|
| 896 |
+
if not return_dict:
|
| 897 |
+
return (image,)
|
| 898 |
+
|
| 899 |
+
return StableDiffusionXLDecoderPipelineOutput(images=image,
|
| 900 |
+
indices_semantic=indices_semantic,
|
| 901 |
+
indices_pixel=indices_pixel)
|
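For reference, a minimal usage sketch of the StableDiffusionXLDecoderPipeline defined above: it decodes DualViTok (semantic, pixel) token indices into an image, conditioned only on the VQ features (the text-prompt embeddings are fixed zeros). The component variables (`vae`, `unet`, `scheduler`, `vq_image_processor`, `vq_model`) and the way they are loaded are assumptions, not part of this file.

import torch
from sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline

# Assumed to be instantiated elsewhere: an SDXL AutoencoderKL (`vae`), a
# UNet2DConditionModel whose input channels include the DualViTok feature
# channels (`unet`), a diffusers scheduler (`scheduler`), plus the DualViTok
# image processor and model (`vq_image_processor`, `vq_model`).
pipe = StableDiffusionXLDecoderPipeline(
    vae=vae,
    unet=unet,
    scheduler=scheduler,
    vq_image_processor=vq_image_processor,
    vq_model=vq_model,
).to("cuda")

# Decode previously generated (semantic, pixel) indices at a supported resolution.
out = pipe(
    vq_indices=(indices_semantic, indices_pixel),
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=2.0,
)
image = out.images[0]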
special_tokens_map.json
ADDED
|
@@ -0,0 +1,31 @@
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|im_start|>",
|
| 4 |
+
"<|im_end|>",
|
| 5 |
+
"<|object_ref_start|>",
|
| 6 |
+
"<|object_ref_end|>",
|
| 7 |
+
"<|box_start|>",
|
| 8 |
+
"<|box_end|>",
|
| 9 |
+
"<|quad_start|>",
|
| 10 |
+
"<|quad_end|>",
|
| 11 |
+
"<|vision_start|>",
|
| 12 |
+
"<|vision_end|>",
|
| 13 |
+
"<|vision_pad|>",
|
| 14 |
+
"<|image_pad|>",
|
| 15 |
+
"<|video_pad|>"
|
| 16 |
+
],
|
| 17 |
+
"eos_token": {
|
| 18 |
+
"content": "<|im_end|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
},
|
| 24 |
+
"pad_token": {
|
| 25 |
+
"content": "<|endoftext|>",
|
| 26 |
+
"lstrip": false,
|
| 27 |
+
"normalized": false,
|
| 28 |
+
"rstrip": false,
|
| 29 |
+
"single_word": false
|
| 30 |
+
}
|
| 31 |
+
}
|
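The special-token map above registers the chat and vision delimiter tokens together with the eos/pad tokens. A quick check, assuming the tokenizer is loaded from a local clone of this repository (the path is a placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./", trust_remote_code=True)
print(tok.eos_token)  # '<|im_end|>'
print(tok.pad_token)  # '<|endoftext|>'
print([t for t in tok.additional_special_tokens if "vision" in t])
# expected: ['<|vision_start|>', '<|vision_end|>', '<|vision_pad|>']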
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:efb35e4dde47fc87a13f21a10fc3a0ac50340bc53ad9d88999616403d8498216
|
| 3 |
+
size 33100531
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db9da21d556f5a362c7d22bfe8bf4d1bd734a303db356521eeb4c96a11881970
|
| 3 |
+
size 25814172
|
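tokenizer.json and tokenizer_config.json are stored as Git LFS pointers (the `version` / `oid sha256` / `size` triplets above), so a plain clone without LFS only contains these small stubs. A sketch of fetching the real payloads with `huggingface_hub`; the repo id is a placeholder:

from huggingface_hub import hf_hub_download

for fname in ("tokenizer.json", "tokenizer_config.json"):
    local_path = hf_hub_download(repo_id="<namespace>/<model-name>", filename=fname)
    print(fname, "->", local_path)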
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|