Spaces:
Runtime error
Runtime error
Init commit (#1)
Browse files- Init commit (b5d45ec402ddb0c2646a16f1372b9717172cf9be)
Co-authored-by: Kehan Li <[email protected]>
This view is limited to 50 files because it contains too many changes.
See raw diff
- RynnEC/app.py +387 -0
- RynnEC/requirements.txt +44 -0
- RynnEC/rynnec/__init__.py +269 -0
- RynnEC/rynnec/constants.py +47 -0
- RynnEC/rynnec/mm_utils.py +733 -0
- RynnEC/rynnec/model/__init__.py +196 -0
- RynnEC/rynnec/model/encoder.py +282 -0
- RynnEC/rynnec/model/extension/__init__.py +1 -0
- RynnEC/rynnec/model/extension/sam2_base.py +298 -0
- RynnEC/rynnec/model/loss.py +597 -0
- RynnEC/rynnec/model/predictor/__init__.py +1 -0
- RynnEC/rynnec/model/predictor/sam2_predictor.py +724 -0
- RynnEC/rynnec/model/processor.py +401 -0
- RynnEC/rynnec/model/projector.py +161 -0
- RynnEC/rynnec/model/region_encoder.py +77 -0
- RynnEC/rynnec/model/rynnec_arch.py +271 -0
- RynnEC/rynnec/model/rynnec_qwen2.py +638 -0
- RynnEC/rynnec/model/sam2.py +133 -0
- RynnEC/rynnec/model/sam2_train.py +134 -0
- RynnEC/rynnec/model/utils.py +61 -0
- RynnEC/rynnec/model/videollama3_encoder/__init__.py +3 -0
- RynnEC/rynnec/model/videollama3_encoder/configuration_videollama3_encoder.py +71 -0
- RynnEC/rynnec/model/videollama3_encoder/image_processing_videollama3.py +473 -0
- RynnEC/rynnec/model/videollama3_encoder/modeling_videollama3_encoder.py +555 -0
- RynnEC/rynnec/rynnec_trainer.py +496 -0
- RynnEC/rynnec/train.py +832 -0
- RynnEC/third_parts/sam2/__init__.py +9 -0
- RynnEC/third_parts/sam2/automatic_mask_generator.py +434 -0
- RynnEC/third_parts/sam2/build_sam.py +89 -0
- RynnEC/third_parts/sam2/csrc/connected_components.cu +289 -0
- RynnEC/third_parts/sam2/modeling/__init__.py +5 -0
- RynnEC/third_parts/sam2/modeling/backbones/__init__.py +5 -0
- RynnEC/third_parts/sam2/modeling/backbones/hieradet.py +295 -0
- RynnEC/third_parts/sam2/modeling/backbones/image_encoder.py +133 -0
- RynnEC/third_parts/sam2/modeling/backbones/utils.py +95 -0
- RynnEC/third_parts/sam2/modeling/memory_attention.py +169 -0
- RynnEC/third_parts/sam2/modeling/memory_encoder.py +181 -0
- RynnEC/third_parts/sam2/modeling/position_encoding.py +221 -0
- RynnEC/third_parts/sam2/modeling/sam/__init__.py +5 -0
- RynnEC/third_parts/sam2/modeling/sam/mask_decoder.py +299 -0
- RynnEC/third_parts/sam2/modeling/sam/prompt_encoder.py +182 -0
- RynnEC/third_parts/sam2/modeling/sam/transformer.py +328 -0
- RynnEC/third_parts/sam2/modeling/sam2_base.py +830 -0
- RynnEC/third_parts/sam2/modeling/sam2_utils.py +149 -0
- RynnEC/third_parts/sam2/sam2_configs/__init__.py +5 -0
- RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_b+.yaml +113 -0
- RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_l.yaml +117 -0
- RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_s.yaml +116 -0
- RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_t.yaml +118 -0
- RynnEC/third_parts/sam2/sam2_image_predictor.py +446 -0
RynnEC/app.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import cv2
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from transformers import SamModel, SamProcessor
|
| 7 |
+
|
| 8 |
+
import spaces
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from torchvision.transforms import v2
|
| 13 |
+
|
| 14 |
+
from rynnec import disable_torch_init, model_init, mm_infer, mm_infer_segmentation
|
| 15 |
+
from rynnec.mm_utils import annToMask, load_video, load_images
|
| 16 |
+
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import numpy as np
|
| 20 |
+
import colorsys
|
| 21 |
+
import argparse
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_hsv_palette(n_colors):
|
| 25 |
+
hues = np.linspace(0, 1, int(n_colors) + 1)[1:-1]
|
| 26 |
+
s = 0.8
|
| 27 |
+
v = 0.9
|
| 28 |
+
palette = [(0.0, 0.0, 0.0)] + [
|
| 29 |
+
colorsys.hsv_to_rgb(h_i, s, v) for h_i in hues
|
| 30 |
+
]
|
| 31 |
+
return (255 * np.asarray(palette)).astype("uint8")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def colorize_masks(images, index_masks, fac: float = 0.8, draw_contour=True, edge_thickness=20):
|
| 35 |
+
max_idx = max([m.max() for m in index_masks])
|
| 36 |
+
palette = get_hsv_palette(max_idx + 1)
|
| 37 |
+
color_masks = []
|
| 38 |
+
out_frames = []
|
| 39 |
+
for img, mask in tqdm(zip(images, index_masks), desc='Visualize masks ...'):
|
| 40 |
+
clr_mask = palette[mask.astype("int")]
|
| 41 |
+
blended_img = img
|
| 42 |
+
|
| 43 |
+
blended_img = compose_img_mask(blended_img, clr_mask, fac)
|
| 44 |
+
|
| 45 |
+
if draw_contour:
|
| 46 |
+
blended_img = draw_contours_on_image(blended_img, mask, clr_mask,
|
| 47 |
+
brightness_factor=1.8,
|
| 48 |
+
alpha=0.6,
|
| 49 |
+
thickness=edge_thickness)
|
| 50 |
+
out_frames.append(blended_img)
|
| 51 |
+
|
| 52 |
+
return out_frames, color_masks
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def compose_img_mask(img, color_mask, fac: float = 0.5):
|
| 56 |
+
mask_region = (color_mask.sum(axis=-1) > 0)[..., None]
|
| 57 |
+
out_f = img.copy() / 255
|
| 58 |
+
out_f[mask_region[:, :, 0]] = fac * img[mask_region[:, :, 0]] / 255 + (1 - fac) * color_mask[mask_region[:, :, 0]] / 255
|
| 59 |
+
out_u = (255 * out_f).astype("uint8")
|
| 60 |
+
return out_u
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def draw_contours_on_image(img, index_mask, color_mask, brightness_factor=1.6, alpha=0.5, thickness=2, ignore_index=0):
|
| 64 |
+
img = img.astype("float32")
|
| 65 |
+
overlay = img.copy()
|
| 66 |
+
|
| 67 |
+
unique_indices = np.unique(index_mask)
|
| 68 |
+
if ignore_index is not None:
|
| 69 |
+
unique_indices = [idx for idx in unique_indices if idx != ignore_index]
|
| 70 |
+
|
| 71 |
+
for i in unique_indices:
|
| 72 |
+
bin_mask = (index_mask == i).astype("uint8") * 255
|
| 73 |
+
if bin_mask.sum() == 0:
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 77 |
+
|
| 78 |
+
color = color_mask[index_mask == i][0].astype("float32")
|
| 79 |
+
bright_color = np.clip(color * brightness_factor, 0, 255).tolist()
|
| 80 |
+
|
| 81 |
+
cv2.drawContours(overlay, contours, -1, bright_color, thickness)
|
| 82 |
+
|
| 83 |
+
blended = (1 - alpha) * img + alpha * overlay
|
| 84 |
+
return np.clip(blended, 0, 255).astype("uint8")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def extract_first_frame_from_video(video):
|
| 88 |
+
cap = cv2.VideoCapture(video)
|
| 89 |
+
success, frame = cap.read()
|
| 90 |
+
cap.release()
|
| 91 |
+
if success:
|
| 92 |
+
return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def extract_points_from_mask(mask_pil):
|
| 97 |
+
mask = np.asarray(mask_pil)[..., 0]
|
| 98 |
+
coords = np.nonzero(mask)
|
| 99 |
+
coords = np.stack((coords[1], coords[0]), axis=1)
|
| 100 |
+
|
| 101 |
+
return coords
|
| 102 |
+
|
| 103 |
+
def add_contour(img, mask, color=(1., 1., 1.)):
|
| 104 |
+
img = img.copy()
|
| 105 |
+
|
| 106 |
+
mask = mask.astype(np.uint8) * 255
|
| 107 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 108 |
+
cv2.drawContours(img, contours, -1, color, thickness=8)
|
| 109 |
+
|
| 110 |
+
return img
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def load_first_frame(video_path):
|
| 114 |
+
cap = cv2.VideoCapture(video_path)
|
| 115 |
+
ret, frame = cap.read()
|
| 116 |
+
cap.release()
|
| 117 |
+
if not ret:
|
| 118 |
+
raise gr.Error("Could not read the video file.")
|
| 119 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 120 |
+
image = Image.fromarray(frame)
|
| 121 |
+
return image
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def clear_masks():
|
| 125 |
+
return [], [], [], []
|
| 126 |
+
|
| 127 |
+
def clear_all():
|
| 128 |
+
return [], [], [], [], None, "", ""
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@spaces.GPU(duration=120)
|
| 132 |
+
def apply_sam(image, input_points):
|
| 133 |
+
inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
|
| 134 |
+
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
outputs = sam_model(**inputs)
|
| 137 |
+
|
| 138 |
+
masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
|
| 139 |
+
scores = outputs.iou_scores[0, 0]
|
| 140 |
+
|
| 141 |
+
mask_selection_index = scores.argmax()
|
| 142 |
+
|
| 143 |
+
mask_np = masks[mask_selection_index].numpy()
|
| 144 |
+
|
| 145 |
+
return mask_np
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@spaces.GPU(duration=120)
|
| 149 |
+
def run(mode, images, timestamps, masks, mask_ids, instruction, mask_output_video):
|
| 150 |
+
if mode == "QA":
|
| 151 |
+
response = run_text_inference(images, timestamps, masks, mask_ids, instruction)
|
| 152 |
+
else:
|
| 153 |
+
response, mask_output_video = run_seg_inference(images, timestamps, instruction)
|
| 154 |
+
return response, mask_output_video
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def run_text_inference(images, timestamps, masks, mask_ids, instruction):
|
| 158 |
+
masks = torch.from_numpy(np.stack(masks, axis=0))
|
| 159 |
+
|
| 160 |
+
if "<video>" not in instruction:
|
| 161 |
+
instruction = "<video>\n" + instruction
|
| 162 |
+
|
| 163 |
+
if len(masks) >= 2:
|
| 164 |
+
obj_str = f"<video>\nThere are {len(masks)} objects in the video: " + ", ".join([f"<object{i}> [<REGION>]" for i in range(len(masks))])
|
| 165 |
+
instruction = instruction.replace("<video>\n", obj_str)
|
| 166 |
+
else:
|
| 167 |
+
instruction = instruction.replace("<object0>", '[<REGION>]')
|
| 168 |
+
|
| 169 |
+
output = mm_infer(
|
| 170 |
+
(images, timestamps),
|
| 171 |
+
processor,
|
| 172 |
+
instruction,
|
| 173 |
+
model=model,
|
| 174 |
+
tokenizer=processor.tokenizer,
|
| 175 |
+
do_sample=False,
|
| 176 |
+
modal='video',
|
| 177 |
+
masks=masks.cuda() if masks is not None else None,
|
| 178 |
+
mask_ids=mask_ids
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
return output
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def run_seg_inference(images, timestamps, instruction):
|
| 185 |
+
output, masks = mm_infer_segmentation(
|
| 186 |
+
(images, timestamps),
|
| 187 |
+
processor,
|
| 188 |
+
instruction,
|
| 189 |
+
model=model,
|
| 190 |
+
tokenizer=processor.tokenizer,
|
| 191 |
+
do_sample=False,
|
| 192 |
+
modal='video',
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
w, h = images[0].size
|
| 196 |
+
masks = v2.Resize([h, w])(masks).cpu().numpy()
|
| 197 |
+
|
| 198 |
+
mask_list_video = []
|
| 199 |
+
|
| 200 |
+
images = [np.array(image) for image in images]
|
| 201 |
+
masks = [mask[0] for mask in masks]
|
| 202 |
+
show_images, _ = colorize_masks(images, masks)
|
| 203 |
+
for i, image in enumerate(show_images):
|
| 204 |
+
if masks[i].sum() > 1000:
|
| 205 |
+
mask_list_video.append((Image.fromarray(image), f"Frame {i}"))
|
| 206 |
+
|
| 207 |
+
return output, mask_list_video
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def generate_masks_video(image, mask_list_video, mask_raw_list_video, mask_ids, frame_idx):
|
| 211 |
+
image['image'] = image['background'].convert('RGB')
|
| 212 |
+
# del image['background'], image['composite']
|
| 213 |
+
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
| 214 |
+
|
| 215 |
+
mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
|
| 216 |
+
points = extract_points_from_mask(mask)
|
| 217 |
+
np.random.seed(0)
|
| 218 |
+
if points.shape[0] == 0:
|
| 219 |
+
raise gr.Error("No points selected")
|
| 220 |
+
|
| 221 |
+
points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
|
| 222 |
+
points = points[points_selected_indices]
|
| 223 |
+
coords = [points.tolist()]
|
| 224 |
+
mask_np = apply_sam(image['image'], coords)
|
| 225 |
+
|
| 226 |
+
mask_raw_list_video.append(mask_np)
|
| 227 |
+
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
|
| 228 |
+
|
| 229 |
+
mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
|
| 230 |
+
# Return a list containing the mask image.
|
| 231 |
+
image['layers'] = []
|
| 232 |
+
image['composite'] = image['background']
|
| 233 |
+
mask_ids.append(frame_idx)
|
| 234 |
+
return mask_list_video, image, mask_list_video, mask_raw_list_video, mask_ids
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
parser = argparse.ArgumentParser(description="VideoRefer gradio demo")
|
| 239 |
+
parser.add_argument("--model-path", type=str, default="Alibaba-DAMO-Academy/RynnEC-2B", help="Path to the model checkpoint")
|
| 240 |
+
parser.add_argument("--port", type=int, default=4001)
|
| 241 |
+
|
| 242 |
+
args_cli = parser.parse_args()
|
| 243 |
+
|
| 244 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
|
| 245 |
+
|
| 246 |
+
mask_list = gr.State([])
|
| 247 |
+
mask_raw_list = gr.State([])
|
| 248 |
+
mask_list_video = gr.State([])
|
| 249 |
+
mask_raw_list_video = gr.State([])
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
HEADER = ("""
|
| 253 |
+
<div>
|
| 254 |
+
<h1>RynnEC Demo</h1>
|
| 255 |
+
<h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
|
| 256 |
+
<h5 style="margin: 0;">If this demo please you, please give us a star ⭐ on Github or 💖 on this space.</h5>
|
| 257 |
+
</div>
|
| 258 |
+
</div>
|
| 259 |
+
<div style="display: flex; justify-content: left; margin-top: 10px;">
|
| 260 |
+
<a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
|
| 261 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
|
| 262 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
|
| 263 |
+
</div>
|
| 264 |
+
""")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
image_tips = """
|
| 268 |
+
### 💡 Tips:
|
| 269 |
+
|
| 270 |
+
🧸 Upload an image, and you can use the drawing tool✍️ to highlight the areas you're interested in.
|
| 271 |
+
|
| 272 |
+
🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
|
| 273 |
+
|
| 274 |
+
🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
|
| 275 |
+
|
| 276 |
+
📌 Click the button 'Clear Masks' to clear the current generated masks.
|
| 277 |
+
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
video_tips = """
|
| 281 |
+
### 💡 Tips:
|
| 282 |
+
🧸 Upload an video, and you can use the drawing tool✍️ to highlight the areas you're interested in the first frame.
|
| 283 |
+
|
| 284 |
+
🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
|
| 285 |
+
|
| 286 |
+
📌 Click the button 'Clear Masks' to clear the current generated masks.
|
| 287 |
+
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
with gr.TabItem("Video"):
|
| 292 |
+
with gr.Row():
|
| 293 |
+
with gr.Column():
|
| 294 |
+
video_input = gr.Video(label="Video", interactive=True)
|
| 295 |
+
frame_idx = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Select Frame", interactive=False)
|
| 296 |
+
selected_frame = gr.ImageEditor(
|
| 297 |
+
label="Annotate Frame",
|
| 298 |
+
type="pil",
|
| 299 |
+
sources=[],
|
| 300 |
+
interactive=True,
|
| 301 |
+
)
|
| 302 |
+
generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=True, variant="primary")
|
| 303 |
+
gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
|
| 304 |
+
|
| 305 |
+
with gr.Column():
|
| 306 |
+
mode_video = gr.Radio(label="Mode", choices=["QA", "Seg"], value="QA")
|
| 307 |
+
mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
|
| 308 |
+
|
| 309 |
+
query_video = gr.Textbox(label="Question", value="Please describe <object0>.", interactive=True, visible=True)
|
| 310 |
+
response_video = gr.Textbox(label="Answer", interactive=False)
|
| 311 |
+
|
| 312 |
+
submit_btn_video = gr.Button("Generate Caption", variant="primary", visible=False)
|
| 313 |
+
submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=True)
|
| 314 |
+
description_video = gr.Textbox(label="Output", visible=False)
|
| 315 |
+
|
| 316 |
+
clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
|
| 317 |
+
|
| 318 |
+
gr.Markdown(video_tips)
|
| 319 |
+
|
| 320 |
+
frames = gr.State(value=[])
|
| 321 |
+
timestamps = gr.State(value=[])
|
| 322 |
+
mask_ids = gr.State(value=[])
|
| 323 |
+
|
| 324 |
+
def on_video_upload(video_path):
|
| 325 |
+
frames, timestamps = load_video(video_path, fps=1, max_frames=128)
|
| 326 |
+
frames = [Image.fromarray(x.transpose(1, 2, 0)) for x in frames]
|
| 327 |
+
return frames, timestamps, frames[0], gr.update(value=0, maximum=len(frames) - 1, interactive=True)
|
| 328 |
+
|
| 329 |
+
def on_frame_idx_change(frame_idx, frames):
|
| 330 |
+
return frames[frame_idx]
|
| 331 |
+
|
| 332 |
+
def to_seg_mode():
|
| 333 |
+
return (
|
| 334 |
+
*[gr.update(visible=False) for _ in range(4)],
|
| 335 |
+
[]
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
def to_qa_mode():
|
| 339 |
+
return (
|
| 340 |
+
*[gr.update(visible=True) for _ in range(4)],
|
| 341 |
+
[]
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def on_mode_change(mode):
|
| 345 |
+
if mode == "QA":
|
| 346 |
+
return to_qa_mode()
|
| 347 |
+
return to_seg_mode()
|
| 348 |
+
|
| 349 |
+
mode_video.change(on_mode_change, inputs=[mode_video], outputs=[frame_idx, selected_frame, generate_mask_btn_video, response_video, mask_output_video])
|
| 350 |
+
video_input.change(on_video_upload, inputs=[video_input], outputs=[frames, timestamps, selected_frame, frame_idx])
|
| 351 |
+
frame_idx.change(on_frame_idx_change, inputs=[frame_idx, frames], outputs=[selected_frame])
|
| 352 |
+
|
| 353 |
+
generate_mask_btn_video.click(
|
| 354 |
+
fn=generate_masks_video,
|
| 355 |
+
inputs=[selected_frame, mask_list_video, mask_raw_list_video, mask_ids, frame_idx],
|
| 356 |
+
outputs=[mask_output_video, selected_frame, mask_list_video, mask_raw_list_video, mask_ids]
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
submit_btn_video1.click(
|
| 360 |
+
fn=run,
|
| 361 |
+
inputs=[mode_video, frames, timestamps, mask_raw_list_video, mask_ids, query_video, mask_output_video],
|
| 362 |
+
outputs=[response_video, mask_output_video],
|
| 363 |
+
api_name="describe_video"
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
video_input.clear(
|
| 367 |
+
fn=clear_all,
|
| 368 |
+
outputs=[mask_output_video, mask_list_video, mask_raw_list_video, mask_ids, selected_frame, query_video, response_video]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
clear_masks_btn_video.click(
|
| 372 |
+
fn=clear_masks,
|
| 373 |
+
outputs=[mask_output_video, mask_list_video, mask_raw_list_video, mask_ids]
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 377 |
+
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
| 378 |
+
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
| 379 |
+
# sam_model = sam_processor = None
|
| 380 |
+
disable_torch_init()
|
| 381 |
+
model, processor = model_init(args_cli.model_path)
|
| 382 |
+
# model = processor = None
|
| 383 |
+
|
| 384 |
+
# demo.launch()
|
| 385 |
+
demo.launch(
|
| 386 |
+
share=False,
|
| 387 |
+
)
|
RynnEC/requirements.txt
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
| 2 |
+
# basic dependencies
|
| 3 |
+
torch==2.4.0
|
| 4 |
+
torchvision==0.19.0
|
| 5 |
+
datasets==2.21.0
|
| 6 |
+
transformers==4.46.3
|
| 7 |
+
tokenizers==0.20.3
|
| 8 |
+
deepspeed==0.15.4
|
| 9 |
+
accelerate==1.0.1
|
| 10 |
+
peft==0.4.0
|
| 11 |
+
timm==1.0.3
|
| 12 |
+
numpy==1.24.4
|
| 13 |
+
# data processing
|
| 14 |
+
decord==0.6.0
|
| 15 |
+
imageio==2.34.0
|
| 16 |
+
imageio-ffmpeg==0.4.9
|
| 17 |
+
moviepy==1.0.3
|
| 18 |
+
opencv-python==4.6.0.66
|
| 19 |
+
pyarrow
|
| 20 |
+
pysubs2
|
| 21 |
+
ffmpeg-python
|
| 22 |
+
# misc
|
| 23 |
+
scikit-learn==1.2.2
|
| 24 |
+
huggingface_hub==0.23.4
|
| 25 |
+
sentencepiece==0.1.99
|
| 26 |
+
shortuuid
|
| 27 |
+
einops==0.6.1
|
| 28 |
+
einops-exts==0.0.4
|
| 29 |
+
bitsandbytes==0.43.3 # for cuda 124
|
| 30 |
+
pydantic>=2.0
|
| 31 |
+
markdown2[all]
|
| 32 |
+
gradio==3.50.0
|
| 33 |
+
gradio_client==0.6.1
|
| 34 |
+
httpx==0.24.1
|
| 35 |
+
requests
|
| 36 |
+
openai
|
| 37 |
+
uvicorn
|
| 38 |
+
fastapi
|
| 39 |
+
tensorboard
|
| 40 |
+
wandb
|
| 41 |
+
tabulate
|
| 42 |
+
hydra-core
|
| 43 |
+
pycocotools==2.0.10
|
| 44 |
+
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
RynnEC/rynnec/__init__.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import copy
|
| 3 |
+
import math
|
| 4 |
+
import warnings
|
| 5 |
+
import shutil
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from .model import load_pretrained_model
|
| 11 |
+
from .mm_utils import load_images, process_images, load_video, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, DirectResize, sam_preprocess_batch
|
| 12 |
+
from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN
|
| 13 |
+
from .model.rynnec_qwen2 import Videollama3Qwen2Processor
|
| 14 |
+
|
| 15 |
+
def disable_torch_init():
|
| 16 |
+
"""
|
| 17 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
| 18 |
+
"""
|
| 19 |
+
import torch
|
| 20 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
| 21 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def model_init(model_path=None, min_visual_tokens=None, max_visual_tokens=None, **kwargs):
|
| 25 |
+
model_path = "Alibaba-DAMO-Academy/RynnEC-2B" if model_path is None else model_path
|
| 26 |
+
model_name = get_model_name_from_path(model_path)
|
| 27 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)
|
| 28 |
+
|
| 29 |
+
if max_visual_tokens is not None:
|
| 30 |
+
image_processor.max_tokens = max_visual_tokens
|
| 31 |
+
if min_visual_tokens is not None:
|
| 32 |
+
image_processor.min_tokens = min_visual_tokens
|
| 33 |
+
|
| 34 |
+
if tokenizer.pad_token is None and tokenizer.unk_token is not None:
|
| 35 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 36 |
+
|
| 37 |
+
processor = Videollama3Qwen2Processor(image_processor, tokenizer)
|
| 38 |
+
|
| 39 |
+
return model, processor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def mm_infer(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', **kwargs):
|
| 43 |
+
|
| 44 |
+
mask_ids = kwargs.pop('mask_ids', None)
|
| 45 |
+
masks = kwargs.pop('masks', None)
|
| 46 |
+
if modal == 'image':
|
| 47 |
+
modal_token = DEFAULT_IMAGE_TOKEN
|
| 48 |
+
images = images_or_videos
|
| 49 |
+
timestamps = None
|
| 50 |
+
elif modal == 'video':
|
| 51 |
+
modal_token = DEFAULT_VIDEO_TOKEN
|
| 52 |
+
images, timestamps = images_or_videos
|
| 53 |
+
elif modal == 'text':
|
| 54 |
+
modal_token = ''
|
| 55 |
+
else:
|
| 56 |
+
raise ValueError(f"Unsupported modal: {modal}")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# 1. text preprocess (tag process & generate prompt).
|
| 60 |
+
if isinstance(instruct, str):
|
| 61 |
+
messages = [{'role': 'user', 'content': instruct}]
|
| 62 |
+
elif isinstance(instruct, list):
|
| 63 |
+
messages = copy.deepcopy(instruct)
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(f"Unsupported type of instruct: {type(instruct)}")
|
| 66 |
+
|
| 67 |
+
if all(not modal_token in message["content"] for message in messages):
|
| 68 |
+
warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!")
|
| 69 |
+
messages[0]["content"] = modal_token + messages[0]["content"]
|
| 70 |
+
|
| 71 |
+
converted_messages = []
|
| 72 |
+
for message in messages:
|
| 73 |
+
chunks = message["content"].split(modal_token)
|
| 74 |
+
converted_messages.append({
|
| 75 |
+
"role": "user",
|
| 76 |
+
"content": []
|
| 77 |
+
})
|
| 78 |
+
|
| 79 |
+
for chunk_idx in range(1, 2 * len(chunks)):
|
| 80 |
+
if chunk_idx % 2 == 1:
|
| 81 |
+
chunk = chunks[chunk_idx // 2].strip()
|
| 82 |
+
converted_messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
|
| 83 |
+
else:
|
| 84 |
+
if modal == 'image':
|
| 85 |
+
converted_messages[-1]["content"].append({"type": "image"})
|
| 86 |
+
elif modal == 'video':
|
| 87 |
+
converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})
|
| 88 |
+
|
| 89 |
+
messages = converted_messages
|
| 90 |
+
|
| 91 |
+
system_message = []
|
| 92 |
+
|
| 93 |
+
image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)
|
| 94 |
+
# TODO: attention mask?
|
| 95 |
+
messages = system_message + messages
|
| 96 |
+
data_dict = vlprocessor(
|
| 97 |
+
images=images,
|
| 98 |
+
text=messages,
|
| 99 |
+
merge_size=image_downsampling,
|
| 100 |
+
return_labels=True,
|
| 101 |
+
return_tensors="pt",
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
|
| 105 |
+
|
| 106 |
+
# images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
|
| 107 |
+
# grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]
|
| 108 |
+
|
| 109 |
+
# 3. generate response according to visual signals and prompts.
|
| 110 |
+
keywords = [tokenizer.eos_token]
|
| 111 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))
|
| 112 |
+
|
| 113 |
+
do_sample = kwargs.get('do_sample', False)
|
| 114 |
+
temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
|
| 115 |
+
top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
|
| 116 |
+
top_k = kwargs.get('top_k', 20 if do_sample else 50)
|
| 117 |
+
max_new_tokens = kwargs.get('max_new_tokens', 2048)
|
| 118 |
+
|
| 119 |
+
data_dict["modals"] = [modal]
|
| 120 |
+
data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
|
| 121 |
+
if "pixel_values" in data_dict:
|
| 122 |
+
data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
|
| 123 |
+
data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)
|
| 124 |
+
|
| 125 |
+
with torch.inference_mode():
|
| 126 |
+
output_ids = model.generate(
|
| 127 |
+
input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
|
| 128 |
+
pixel_values=data_dict["pixel_values"],
|
| 129 |
+
grid_sizes=data_dict["grid_sizes"],
|
| 130 |
+
merge_sizes=data_dict["merge_sizes"],
|
| 131 |
+
modals=data_dict["modals"],
|
| 132 |
+
do_sample=do_sample,
|
| 133 |
+
temperature=temperature,
|
| 134 |
+
max_new_tokens=max_new_tokens,
|
| 135 |
+
top_p=top_p,
|
| 136 |
+
top_k=top_k,
|
| 137 |
+
use_cache=True,
|
| 138 |
+
stopping_criteria=[stopping_criteria],
|
| 139 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 140 |
+
masks=[masks],
|
| 141 |
+
mask_ids=mask_ids
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 145 |
+
|
| 146 |
+
return outputs
|
| 147 |
+
|
| 148 |
+
def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, **kwargs):
|
| 149 |
+
|
| 150 |
+
image2maskids = kwargs.get('image2maskids', [])
|
| 151 |
+
img_size=1024
|
| 152 |
+
sam_transform = DirectResize(img_size)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if modal == 'image':
|
| 156 |
+
modal_token = DEFAULT_IMAGE_TOKEN
|
| 157 |
+
images = images_or_videos
|
| 158 |
+
timestamps = None
|
| 159 |
+
elif modal == 'video':
|
| 160 |
+
modal_token = DEFAULT_VIDEO_TOKEN
|
| 161 |
+
images, timestamps = images_or_videos
|
| 162 |
+
elif modal == 'text':
|
| 163 |
+
modal_token = ''
|
| 164 |
+
else:
|
| 165 |
+
raise ValueError(f"Unsupported modal: {modal}")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
sam_images = []
|
| 169 |
+
sam_size = None
|
| 170 |
+
if len(images)>0:
|
| 171 |
+
for image in images:
|
| 172 |
+
sam_image = sam_transform.apply_image(np.array(image))
|
| 173 |
+
sam_images.append(sam_image)
|
| 174 |
+
if sam_size is None:
|
| 175 |
+
sam_size = sam_image.shape[:2]
|
| 176 |
+
sam_images = np.array(sam_images)
|
| 177 |
+
sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
|
| 178 |
+
sam_images = sam_preprocess_batch(sam_images)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# 1. text preprocess (tag process & generate prompt).
|
| 182 |
+
if isinstance(instruct, str):
|
| 183 |
+
messages = [{'role': 'user', 'content': instruct}]
|
| 184 |
+
elif isinstance(instruct, list):
|
| 185 |
+
messages = copy.deepcopy(instruct)
|
| 186 |
+
else:
|
| 187 |
+
raise ValueError(f"Unsupported type of instruct: {type(instruct)}")
|
| 188 |
+
|
| 189 |
+
if all(not modal_token in message["content"] for message in messages):
|
| 190 |
+
warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!")
|
| 191 |
+
messages[0]["content"] = modal_token + messages[0]["content"]
|
| 192 |
+
|
| 193 |
+
converted_messages = []
|
| 194 |
+
for message in messages:
|
| 195 |
+
chunks = message["content"].split(modal_token)
|
| 196 |
+
converted_messages.append({
|
| 197 |
+
"role": "user",
|
| 198 |
+
"content": []
|
| 199 |
+
})
|
| 200 |
+
|
| 201 |
+
for chunk_idx in range(1, 2 * len(chunks)):
|
| 202 |
+
if chunk_idx % 2 == 1:
|
| 203 |
+
chunk = chunks[chunk_idx // 2].strip()
|
| 204 |
+
converted_messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
|
| 205 |
+
else:
|
| 206 |
+
if modal == 'image':
|
| 207 |
+
converted_messages[-1]["content"].append({"type": "image"})
|
| 208 |
+
elif modal == 'video':
|
| 209 |
+
converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})
|
| 210 |
+
|
| 211 |
+
messages = converted_messages
|
| 212 |
+
|
| 213 |
+
system_message = []
|
| 214 |
+
|
| 215 |
+
image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)
|
| 216 |
+
# TODO: attention mask?
|
| 217 |
+
messages = system_message + messages
|
| 218 |
+
data_dict = vlprocessor(
|
| 219 |
+
images=images,
|
| 220 |
+
text=messages,
|
| 221 |
+
merge_size=image_downsampling,
|
| 222 |
+
return_labels=True,
|
| 223 |
+
return_tensors="pt",
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
|
| 227 |
+
|
| 228 |
+
keywords = [tokenizer.eos_token]
|
| 229 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))
|
| 230 |
+
|
| 231 |
+
do_sample = kwargs.get('do_sample', False)
|
| 232 |
+
temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
|
| 233 |
+
top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
|
| 234 |
+
top_k = kwargs.get('top_k', 20 if do_sample else 50)
|
| 235 |
+
max_new_tokens = kwargs.get('max_new_tokens', 2048)
|
| 236 |
+
|
| 237 |
+
torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
|
| 238 |
+
|
| 239 |
+
data_dict["modals"] = [modal]
|
| 240 |
+
data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
|
| 241 |
+
if "pixel_values" in data_dict:
|
| 242 |
+
data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
|
| 243 |
+
data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)
|
| 244 |
+
|
| 245 |
+
with torch.inference_mode():
|
| 246 |
+
output_ids, pred_masks = model.inference(
|
| 247 |
+
input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
|
| 248 |
+
pixel_values=data_dict["pixel_values"],
|
| 249 |
+
grid_sizes=data_dict["grid_sizes"],
|
| 250 |
+
merge_sizes=data_dict["merge_sizes"],
|
| 251 |
+
modals=data_dict["modals"],
|
| 252 |
+
sam_images=[sam_images],
|
| 253 |
+
sam_size=[sam_size],
|
| 254 |
+
image2maskids=[image2maskids],
|
| 255 |
+
do_sample=do_sample,
|
| 256 |
+
temperature=temperature,
|
| 257 |
+
max_new_tokens=max_new_tokens,
|
| 258 |
+
top_p=top_p,
|
| 259 |
+
top_k=top_k,
|
| 260 |
+
use_cache=True,
|
| 261 |
+
stopping_criteria=[stopping_criteria],
|
| 262 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 263 |
+
seg_start_idx=seg_start_idx
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
| 267 |
+
pred_masks_sigmoid = pred_masks.sigmoid()>0.5
|
| 268 |
+
|
| 269 |
+
return outputs, pred_masks_sigmoid
|
RynnEC/rynnec/constants.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
| 2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
| 3 |
+
|
| 4 |
+
LOGDIR = "."
|
| 5 |
+
|
| 6 |
+
# Model Constants
|
| 7 |
+
IGNORE_INDEX = -100
|
| 8 |
+
|
| 9 |
+
# Image arguments
|
| 10 |
+
IMAGE_TOKEN_INDEX = -200
|
| 11 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 12 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
| 13 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
| 14 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
| 15 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
| 16 |
+
|
| 17 |
+
# Video arguments
|
| 18 |
+
VIDEO_TOKEN_INDEX = -201
|
| 19 |
+
DEFAULT_VIDEO_TOKEN = "<video>"
|
| 20 |
+
NUM_FRAMES = 128
|
| 21 |
+
MAX_FRAMES = 768
|
| 22 |
+
NUM_FRAMES_PER_SECOND = 1
|
| 23 |
+
|
| 24 |
+
# Region arguments
|
| 25 |
+
REGION_TOKEN = "<REGION>"
|
| 26 |
+
REGION_TOKEN_REPLACE = "<region>"
|
| 27 |
+
SEG_TOKEN = "[SEG]"
|
| 28 |
+
|
| 29 |
+
# Audio arguments
|
| 30 |
+
AUDIO_TOKEN_INDEX = -202
|
| 31 |
+
DEFAULT_AUDIO_TOKEN = "<audio>"
|
| 32 |
+
|
| 33 |
+
# Stream arguments
|
| 34 |
+
STREAM_START_TOKEN = "<|stream_start|>"
|
| 35 |
+
STREAM_END_TOKEN = "<|stream_end|>"
|
| 36 |
+
STREAM_MAX_FRAMES = 400
|
| 37 |
+
STREAM_FPS = 2
|
| 38 |
+
STREAM_IMAGE_SIZE = 224
|
| 39 |
+
STREAM_DOWNSAMPLING = 4
|
| 40 |
+
|
| 41 |
+
MODAL_INDEX_MAP = {
|
| 42 |
+
"<image>": -200,
|
| 43 |
+
"<video>": -201,
|
| 44 |
+
"<audio>": -202,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
subimage_token_num=196
|
RynnEC/rynnec/mm_utils.py
ADDED
|
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from: https://github.com/DAMO-NLP-SG/VideoLLaMA3.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# Copyright 2025 The VideoLLaMA3 team, Alibaba Group
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import ast
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
import math
|
| 20 |
+
import base64
|
| 21 |
+
import traceback
|
| 22 |
+
from io import BytesIO
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torchvision.transforms.functional as VF
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import numpy as np
|
| 29 |
+
from transformers import StoppingCriteria
|
| 30 |
+
|
| 31 |
+
import cv2
|
| 32 |
+
import imageio
|
| 33 |
+
import ffmpeg
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from decord import VideoReader, cpu
|
| 36 |
+
|
| 37 |
+
from .constants import NUM_FRAMES, MAX_FRAMES, NUM_FRAMES_PER_SECOND, MODAL_INDEX_MAP, DEFAULT_IMAGE_TOKEN
|
| 38 |
+
from pycocotools import mask as maskUtils
|
| 39 |
+
|
| 40 |
+
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DirectResize:
|
| 44 |
+
def __init__(self, target_length: int) -> None:
|
| 45 |
+
self.target_length = target_length
|
| 46 |
+
|
| 47 |
+
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
| 48 |
+
"""
|
| 49 |
+
Expects a numpy array with shape HxWxC in uint8 format.
|
| 50 |
+
"""
|
| 51 |
+
img = to_pil_image(image, mode='RGB')
|
| 52 |
+
return np.array(img.resize((self.target_length, self.target_length)))
|
| 53 |
+
|
| 54 |
+
def sam_preprocess_batch(x: torch.Tensor) -> torch.Tensor:
|
| 55 |
+
"""
|
| 56 |
+
Normalize pixel values and pad to square input for a batch of images.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
images (torch.Tensor): A batch tensor of shape [N, C, H, W].
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
torch.Tensor: A batch tensor with normalized and padded images
|
| 63 |
+
(shape: [N, C, 1024, 1024]).
|
| 64 |
+
"""
|
| 65 |
+
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1)
|
| 66 |
+
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1)
|
| 67 |
+
img_size = 1024
|
| 68 |
+
|
| 69 |
+
# Normalize colors
|
| 70 |
+
x = (x - pixel_mean) / pixel_std
|
| 71 |
+
|
| 72 |
+
# Pad
|
| 73 |
+
h, w = x.shape[-2:]
|
| 74 |
+
padh = img_size - h
|
| 75 |
+
padw = img_size - w
|
| 76 |
+
x = F.pad(x, (0, padw, 0, padh))
|
| 77 |
+
return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def sam_preprocess(x: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
"""Normalize pixel values and pad to a square input."""
|
| 82 |
+
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
|
| 83 |
+
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
|
| 84 |
+
img_size = 1024
|
| 85 |
+
|
| 86 |
+
# Normalize colors
|
| 87 |
+
x = (x - pixel_mean) / pixel_std
|
| 88 |
+
|
| 89 |
+
# Pad
|
| 90 |
+
h, w = x.shape[-2:]
|
| 91 |
+
padh = img_size - h
|
| 92 |
+
padw = img_size - w
|
| 93 |
+
x = F.pad(x, (0, padw, 0, padh))
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def reshape_images_to_raw_grid(mm_features_raw, grid_thws):
|
| 98 |
+
start_idx=0
|
| 99 |
+
reshaped_features = []
|
| 100 |
+
# for thw_group in grid_thws:
|
| 101 |
+
for tensor_thw in grid_thws:
|
| 102 |
+
# for tensor_thw in thw_group:
|
| 103 |
+
t, H, W = tensor_thw.squeeze().tolist()
|
| 104 |
+
num_elements = H * W
|
| 105 |
+
for i in range(t):
|
| 106 |
+
split_tensor = mm_features_raw[start_idx:start_idx + num_elements].view(H, W, -1)
|
| 107 |
+
reshaped_features.append(split_tensor)
|
| 108 |
+
|
| 109 |
+
start_idx += num_elements
|
| 110 |
+
|
| 111 |
+
assert len(mm_features_raw)==start_idx
|
| 112 |
+
return reshaped_features
|
| 113 |
+
|
| 114 |
+
def annToMask(mask_ann, h=None, w=None):
|
| 115 |
+
if isinstance(mask_ann, list):
|
| 116 |
+
rles = maskUtils.frPyObjects(mask_ann, h, w)
|
| 117 |
+
rle = maskUtils.merge(rles)
|
| 118 |
+
elif isinstance(mask_ann['counts'], list):
|
| 119 |
+
# uncompressed RLE
|
| 120 |
+
rle = maskUtils.frPyObjects(mask_ann, h, w)
|
| 121 |
+
else:
|
| 122 |
+
# rle
|
| 123 |
+
rle = mask_ann
|
| 124 |
+
mask = maskUtils.decode(rle)
|
| 125 |
+
return mask
|
| 126 |
+
|
| 127 |
+
def chunk_list(input_list, chunk_size):
|
| 128 |
+
return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def load_image_from_base64(image):
|
| 132 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def expand2square(pil_img, background_color):
|
| 136 |
+
width, height = pil_img.size
|
| 137 |
+
if width == height:
|
| 138 |
+
return pil_img
|
| 139 |
+
elif width > height:
|
| 140 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 141 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 142 |
+
return result
|
| 143 |
+
else:
|
| 144 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 145 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 146 |
+
return result
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def grid_divide(image, cell_size):
|
| 150 |
+
"""
|
| 151 |
+
Divides an image into grid of a specified size.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
image (PIL.Image.Image): The input image.
|
| 155 |
+
cell_size (int): The size of each cell.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
| 159 |
+
"""
|
| 160 |
+
grid = []
|
| 161 |
+
width, height = image.size
|
| 162 |
+
for i in range(0, height, cell_size):
|
| 163 |
+
row = []
|
| 164 |
+
for j in range(0, width, cell_size):
|
| 165 |
+
box = (j, i, j + cell_size, i + cell_size)
|
| 166 |
+
row.append(image.crop(box))
|
| 167 |
+
grid.append(row)
|
| 168 |
+
|
| 169 |
+
return grid
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def load_images(image_path):
|
| 173 |
+
if isinstance(image_path, str) and os.path.isfile(image_path):
|
| 174 |
+
# images = [cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)]
|
| 175 |
+
images = [Image.open(image_path).convert('RGB')]
|
| 176 |
+
elif isinstance(image_path, str) and os.path.isdir(image_path):
|
| 177 |
+
# images = [cv2.cvtColor(cv2.imread(os.path.join(image_path, f)), cv2.COLOR_BGR2RGB) for f in sorted(os.listdir(image_path))]
|
| 178 |
+
images = [Image.open(os.path.join(image_path, f)).convert('RGB') for f in sorted(os.listdir(image_path))]
|
| 179 |
+
elif isinstance(image_path, list) and isinstance(image_path[0], str):
|
| 180 |
+
# images = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in image_path]
|
| 181 |
+
images = [Image.open(f).convert('RGB') for f in image_path]
|
| 182 |
+
elif isinstance(image_path, list) and isinstance(image_path[0], Image.Image):
|
| 183 |
+
images = image_path
|
| 184 |
+
elif isinstance(image_path, Image.Image):
|
| 185 |
+
images = [image_path]
|
| 186 |
+
else:
|
| 187 |
+
raise ValueError(f"Unsupported image path type: {image_path}")
|
| 188 |
+
|
| 189 |
+
return images
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def process_pad_image(image, padding_value=(0, 0, 0)):
|
| 193 |
+
image = expand2square(image, padding_value)
|
| 194 |
+
|
| 195 |
+
return [image]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def find_closest_aspect_ratio(src_ratio, tgt_ratios, ori_size, tgt_size):
|
| 199 |
+
best_ratio_diff = float('inf')
|
| 200 |
+
best_ratio = (1, 1)
|
| 201 |
+
area = ori_size[0] * ori_size[1]
|
| 202 |
+
for ratio in tgt_ratios:
|
| 203 |
+
tgt_ratio = ratio[0] / ratio[1]
|
| 204 |
+
ratio_diff = abs(src_ratio - tgt_ratio)
|
| 205 |
+
if ratio_diff < best_ratio_diff:
|
| 206 |
+
best_ratio_diff = ratio_diff
|
| 207 |
+
best_ratio = ratio
|
| 208 |
+
elif ratio_diff == best_ratio_diff:
|
| 209 |
+
if area > 0.5 * tgt_size[0] * tgt_size[1] * ratio[0] * ratio[1]:
|
| 210 |
+
best_ratio = ratio
|
| 211 |
+
|
| 212 |
+
return best_ratio
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def process_dynamic_image(image, image_size=384, use_thumbnail=True):
|
| 216 |
+
# Grid Params:
|
| 217 |
+
min_num = 1
|
| 218 |
+
max_num = 12
|
| 219 |
+
|
| 220 |
+
if isinstance(image_size, int):
|
| 221 |
+
image_size = (image_size, image_size)
|
| 222 |
+
|
| 223 |
+
ori_size = image.size
|
| 224 |
+
aspect_ratio = ori_size[0] / ori_size[1]
|
| 225 |
+
|
| 226 |
+
# calculate the existing image aspect ratio
|
| 227 |
+
tgt_ratios = []
|
| 228 |
+
for n in range(min_num, max_num + 1):
|
| 229 |
+
tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
|
| 230 |
+
tgt_ratios = set(tgt_ratios)
|
| 231 |
+
tgt_ratios = sorted(tgt_ratios, key=lambda x: x[0] * x[1])
|
| 232 |
+
|
| 233 |
+
# find the closest aspect ratio to the target
|
| 234 |
+
tgt_ratio = find_closest_aspect_ratio(aspect_ratio, tgt_ratios, ori_size, image_size)
|
| 235 |
+
|
| 236 |
+
# resize the image to the target size
|
| 237 |
+
tgt_width = image_size[0] * tgt_ratio[0]
|
| 238 |
+
tgt_height = image_size[1] * tgt_ratio[1]
|
| 239 |
+
resized_img = image.resize((tgt_width, tgt_height))
|
| 240 |
+
|
| 241 |
+
# NOTE: internvl2 style split the image into one column grids
|
| 242 |
+
# num_grids = tgt_ratio[0] * tgt_ratio[1]
|
| 243 |
+
# grid_images = []
|
| 244 |
+
# for i in range(num_grids):
|
| 245 |
+
# box = (
|
| 246 |
+
# (i % tgt_ratio[0]) * image_size[0],
|
| 247 |
+
# (i // tgt_ratio[0]) * image_size[1],
|
| 248 |
+
# (i % tgt_ratio[0] + 1) * image_size[0],
|
| 249 |
+
# (i // tgt_ratio[0] + 1) * image_size[1],
|
| 250 |
+
# )
|
| 251 |
+
# # crop out the grid image
|
| 252 |
+
# grid_images.append(resized_img.crop(box))
|
| 253 |
+
# assert len(grid_images) == num_grids
|
| 254 |
+
# grid_images = [grid_images]
|
| 255 |
+
|
| 256 |
+
# NOTE: eager implementation
|
| 257 |
+
# num_grids = tgt_ratio[0] * tgt_ratio[1]
|
| 258 |
+
# sub_grid_images = []
|
| 259 |
+
# tmp_grid_images = []
|
| 260 |
+
# for i in range(num_grids):
|
| 261 |
+
# box = (
|
| 262 |
+
# (i % tgt_ratio[0]) * image_size[0],
|
| 263 |
+
# (i // tgt_ratio[0]) * image_size[1],
|
| 264 |
+
# (i % tgt_ratio[0] + 1) * image_size[0],
|
| 265 |
+
# (i // tgt_ratio[0] + 1) * image_size[1],
|
| 266 |
+
# )
|
| 267 |
+
# tmp_grid_images.append(resized_img.crop(box))
|
| 268 |
+
|
| 269 |
+
# if (i + 1) % tgt_ratio[0] == 0:
|
| 270 |
+
# sub_grid_images.append(tmp_grid_images)
|
| 271 |
+
# tmp_grid_images = []
|
| 272 |
+
|
| 273 |
+
image_grid = grid_divide(resized_img, image_size[0])
|
| 274 |
+
|
| 275 |
+
if use_thumbnail:
|
| 276 |
+
thumbnail_img = image.resize((image_size[0], image_size[1]))
|
| 277 |
+
image_grid = [[thumbnail_img]] + image_grid
|
| 278 |
+
|
| 279 |
+
return image_grid
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def process_highres_image(image_path, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
|
| 283 |
+
# Grid Params:
|
| 284 |
+
grid_width = [1, 2, 3]
|
| 285 |
+
grid_width_real = [x * image_size for x in grid_width]
|
| 286 |
+
|
| 287 |
+
longest_side = max(image.size)
|
| 288 |
+
fit_grid_width_real = [x for x in grid_width_real if x >= longest_side]
|
| 289 |
+
if len(fit_grid_width_real) == 0:
|
| 290 |
+
select_size = max(grid_width_real)
|
| 291 |
+
else:
|
| 292 |
+
select_size = min(fit_grid_width_real)
|
| 293 |
+
|
| 294 |
+
image_padded = expand2square(image, padding_value)
|
| 295 |
+
image_padded = image_padded.resize((select_size, select_size))
|
| 296 |
+
image_grid = grid_divide(image_padded, image_size)
|
| 297 |
+
|
| 298 |
+
if use_thumbnail:
|
| 299 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 300 |
+
image_grid = [[thumbnail_img]] + image_grid
|
| 301 |
+
|
| 302 |
+
return image_grid
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def select_best_resolution(original_size, possible_resolutions):
|
| 306 |
+
"""
|
| 307 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
| 311 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
tuple: The best fit resolution in the format (width, height).
|
| 315 |
+
"""
|
| 316 |
+
original_width, original_height = original_size
|
| 317 |
+
best_fit = None
|
| 318 |
+
max_effective_resolution = 0
|
| 319 |
+
min_wasted_resolution = float('inf')
|
| 320 |
+
|
| 321 |
+
for width, height in possible_resolutions:
|
| 322 |
+
scale = min(width / original_width, height / original_height)
|
| 323 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
| 324 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
| 325 |
+
wasted_resolution = (width * height) - effective_resolution
|
| 326 |
+
|
| 327 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
| 328 |
+
max_effective_resolution = effective_resolution
|
| 329 |
+
min_wasted_resolution = wasted_resolution
|
| 330 |
+
best_fit = (width, height)
|
| 331 |
+
|
| 332 |
+
return best_fit
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def process_anyres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
|
| 336 |
+
"""
|
| 337 |
+
Process an image with variable resolutions.
|
| 338 |
+
|
| 339 |
+
Args:
|
| 340 |
+
image (PIL.Image.Image): The input image to be processed.
|
| 341 |
+
processor: The image processor object.
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
torch.Tensor: A tensor containing the processed image patches.
|
| 345 |
+
"""
|
| 346 |
+
# Grid Params:
|
| 347 |
+
possible_grids = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
|
| 348 |
+
possible_resolutions = [(x * image_size, y * image_size) for x, y in possible_grids]
|
| 349 |
+
|
| 350 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
| 351 |
+
|
| 352 |
+
# resize and padding image
|
| 353 |
+
nw, nh = best_resolution
|
| 354 |
+
ow, oh = image.size
|
| 355 |
+
|
| 356 |
+
scale_factor = min(nw / ow, nh / oh)
|
| 357 |
+
new_size = (int(ow * scale_factor), int(oh * scale_factor))
|
| 358 |
+
|
| 359 |
+
image_padded = Image.new("RGB", (nw, nh), padding_value)
|
| 360 |
+
image_padded.paste(image.resize(new_size), ((nw - new_size[0]) // 2, (nh - new_size[1]) // 2))
|
| 361 |
+
|
| 362 |
+
image_grid = grid_divide(image_padded, image_size)
|
| 363 |
+
|
| 364 |
+
if use_thumbnail:
|
| 365 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 366 |
+
image_grid = [[thumbnail_img]] + image_grid
|
| 367 |
+
|
| 368 |
+
return image_grid
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def process_adares_image(image_path, image_size=384, use_thumbnail=True):
|
| 372 |
+
# Grid Params:
|
| 373 |
+
min_num = 1
|
| 374 |
+
max_num = 12
|
| 375 |
+
|
| 376 |
+
if isinstance(image_size, int):
|
| 377 |
+
image_size = (image_size, image_size)
|
| 378 |
+
|
| 379 |
+
ori_size = image.size
|
| 380 |
+
aspect_ratio = ori_size[0] / ori_size[1]
|
| 381 |
+
|
| 382 |
+
# calculate the existing image aspect ratio
|
| 383 |
+
tgt_ratios = []
|
| 384 |
+
for n in range(min_num, max_num + 1):
|
| 385 |
+
tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
|
| 386 |
+
tgt_ratios = set(tgt_ratios)
|
| 387 |
+
possible_resolutions = [(x * image_size[0], y * image_size[1]) for x, y in tgt_ratios]
|
| 388 |
+
|
| 389 |
+
# find the most possible resolution
|
| 390 |
+
best_resolution = select_best_resolution(ori_size, possible_resolutions)
|
| 391 |
+
|
| 392 |
+
# resize the image to the target size
|
| 393 |
+
resized_img = image.resize((best_resolution[0], best_resolution[1]))
|
| 394 |
+
|
| 395 |
+
image_grid = grid_divide(resized_img, image_size[0])
|
| 396 |
+
|
| 397 |
+
if use_thumbnail:
|
| 398 |
+
thumbnail_img = image.resize((image_size[0], image_size[1]))
|
| 399 |
+
image_grid = [[thumbnail_img]] + image_grid
|
| 400 |
+
|
| 401 |
+
return image_grid
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def process_images(image_path, processor, aspect_ratio='pad', image_size=384, use_thumbnail=True):
|
| 405 |
+
images = load_images(image_path)
|
| 406 |
+
|
| 407 |
+
padding_value = tuple(int(x*255) for x in processor.image_mean)
|
| 408 |
+
|
| 409 |
+
image_grids = []
|
| 410 |
+
for image in images:
|
| 411 |
+
if aspect_ratio == 'pad':
|
| 412 |
+
image_grid = process_pad_image(image, padding_value=padding_value)
|
| 413 |
+
elif aspect_ratio == 'dynamic':
|
| 414 |
+
image_grid = process_dynamic_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
|
| 415 |
+
elif aspect_ratio == 'highres':
|
| 416 |
+
image_grid = process_highres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
|
| 417 |
+
elif aspect_ratio == 'anyres':
|
| 418 |
+
image_grid = process_anyres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
|
| 419 |
+
elif aspect_ratio == 'adares':
|
| 420 |
+
image_grid = process_adares_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
|
| 421 |
+
else:
|
| 422 |
+
image_grid = [image]
|
| 423 |
+
|
| 424 |
+
image_grid = [processor.preprocess(image_row, return_tensors='pt', num_images=len(images)) for image_row in image_grid]
|
| 425 |
+
image_grids.append(image_grid)
|
| 426 |
+
|
| 427 |
+
return image_grids
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None, must_sample_frames=None):
|
| 431 |
+
mask_ids = []
|
| 432 |
+
if mode == 'uniform':
|
| 433 |
+
assert num_frames is not None, "Number of frames must be provided for uniform sampling."
|
| 434 |
+
if duration <= num_frames:
|
| 435 |
+
video_ids = np.arange(duration).astype(int)
|
| 436 |
+
video_ids_list = video_ids.tolist()
|
| 437 |
+
for msf in must_sample_frames:
|
| 438 |
+
if msf not in video_ids_list:
|
| 439 |
+
video_ids_list.append(msf)
|
| 440 |
+
video_ids_list.sort()
|
| 441 |
+
for msf in must_sample_frames:
|
| 442 |
+
mask_ids.append(video_ids_list.index(msf))
|
| 443 |
+
return np.array(video_ids_list), mask_ids
|
| 444 |
+
video_ids = np.linspace(0, duration-1, num_frames, dtype=int)
|
| 445 |
+
video_ids_list = video_ids.tolist()
|
| 446 |
+
if must_sample_frames is not None:
|
| 447 |
+
for msf in must_sample_frames:
|
| 448 |
+
if msf not in video_ids_list:
|
| 449 |
+
video_ids_list.append(msf)
|
| 450 |
+
video_ids_list.sort()
|
| 451 |
+
for msf in must_sample_frames:
|
| 452 |
+
mask_ids.append(video_ids_list.index(msf))
|
| 453 |
+
return np.array(video_ids_list), mask_ids
|
| 454 |
+
elif mode == 'fps':
|
| 455 |
+
assert vid_fps is not None, "FPS must be provided for FPS sampling."
|
| 456 |
+
fps = fps if fps is not None else NUM_FRAMES_PER_SECOND
|
| 457 |
+
segment_len = min(vid_fps // fps, duration)
|
| 458 |
+
video_ids = np.arange(segment_len // 2, duration, segment_len, dtype=int)
|
| 459 |
+
video_ids_list = video_ids.tolist()
|
| 460 |
+
if must_sample_frames is not None:
|
| 461 |
+
for msf in must_sample_frames:
|
| 462 |
+
if msf not in video_ids_list:
|
| 463 |
+
video_ids_list.append(msf)
|
| 464 |
+
video_ids_list.sort()
|
| 465 |
+
for msf in must_sample_frames:
|
| 466 |
+
mask_ids.append(video_ids_list.index(msf))
|
| 467 |
+
return np.array(video_ids_list), mask_ids
|
| 468 |
+
|
| 469 |
+
else:
|
| 470 |
+
raise ImportError(f'Unsupported frame sampling mode: {mode}')
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=None, temporal_factor=1, must_sample_frames=None):
|
| 474 |
+
if s is not None and e is not None:
|
| 475 |
+
s = s if s >= 0. else 0.
|
| 476 |
+
e = e if e >= 0. else 0.
|
| 477 |
+
if s > e:
|
| 478 |
+
s, e = e, s
|
| 479 |
+
elif s == e:
|
| 480 |
+
e = s + 1
|
| 481 |
+
|
| 482 |
+
# 1. Loading Video
|
| 483 |
+
if os.path.isdir(video_path):
|
| 484 |
+
frame_files = sorted(os.listdir(video_path))
|
| 485 |
+
|
| 486 |
+
vid_fps = 3
|
| 487 |
+
num_frames_of_video = len(frame_files)
|
| 488 |
+
elif video_path.endswith('.gif'):
|
| 489 |
+
gif_reader = imageio.get_reader(video_path)
|
| 490 |
+
|
| 491 |
+
vid_fps = 25
|
| 492 |
+
num_frames_of_video = len(gif_reader)
|
| 493 |
+
else:
|
| 494 |
+
vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)
|
| 495 |
+
# vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
| 496 |
+
|
| 497 |
+
vid_fps = vreader.get_avg_fps()
|
| 498 |
+
num_frames_of_video = len(vreader)
|
| 499 |
+
|
| 500 |
+
# 2. Determine frame range & Calculate frame indices
|
| 501 |
+
f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
|
| 502 |
+
f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
|
| 503 |
+
frame_indices = list(range(f_start, f_end + 1))
|
| 504 |
+
|
| 505 |
+
duration = len(frame_indices)
|
| 506 |
+
# 3. Sampling frame indices
|
| 507 |
+
max_frames = max_frames if max_frames is not None else MAX_FRAMES
|
| 508 |
+
if fps is not None and duration / vid_fps < max_frames:
|
| 509 |
+
sampled_ids, mask_ids = frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps, must_sample_frames=must_sample_frames)
|
| 510 |
+
sampled_frame_indices = [frame_indices[i] for i in sampled_ids]
|
| 511 |
+
else:
|
| 512 |
+
sampled_ids, mask_ids = frame_sample(duration, mode='uniform', num_frames=max_frames, must_sample_frames=must_sample_frames)
|
| 513 |
+
sampled_frame_indices = [frame_indices[i] for i in sampled_ids]
|
| 514 |
+
|
| 515 |
+
# 4. Acquire frame data
|
| 516 |
+
if os.path.isdir(video_path):
|
| 517 |
+
frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices]
|
| 518 |
+
elif video_path.endswith('.gif'):
|
| 519 |
+
frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices]
|
| 520 |
+
else:
|
| 521 |
+
frames = vreader.get_batch(sampled_frame_indices).asnumpy()
|
| 522 |
+
|
| 523 |
+
# frames = frames.transpose(0, 3, 1, 2)
|
| 524 |
+
timestamps = [x / vid_fps for x in sampled_frame_indices]
|
| 525 |
+
|
| 526 |
+
if temporal_factor > 1:
|
| 527 |
+
pad_length = temporal_factor - len(frames) % temporal_factor
|
| 528 |
+
frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
|
| 529 |
+
[timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]
|
| 530 |
+
|
| 531 |
+
# NOTE: pad the video with black frames
|
| 532 |
+
# while num_frames is not None and len(video_data) < num_frames:
|
| 533 |
+
# video_data.append(Image.fromarray(np.zeros((*video_data[-1].size, 3), dtype=np.uint8)))
|
| 534 |
+
|
| 535 |
+
return frames, timestamps, mask_ids
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def load_video(
|
| 539 |
+
video_path: str,
|
| 540 |
+
start_time: Optional[float] = None,
|
| 541 |
+
end_time: Optional[float] = None,
|
| 542 |
+
fps: Optional[float] = None,
|
| 543 |
+
max_frames: Optional[float] = None,
|
| 544 |
+
size: Optional[int] = None,
|
| 545 |
+
size_divisible: int = 1,
|
| 546 |
+
precise_time: bool = False,
|
| 547 |
+
verbose: bool = False,
|
| 548 |
+
temporal_factor: int = 1
|
| 549 |
+
):
|
| 550 |
+
"""
|
| 551 |
+
Load and process a video file and return the frames and the timestamps of each frame.
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
video_path (str): Path to the video file.
|
| 555 |
+
start_time (float, optional): Start time in seconds. Defaults to None.
|
| 556 |
+
end_time (float, optional): End time in seconds. Defaults to None.
|
| 557 |
+
fps (float, optional): Frames per second. Defaults to None.
|
| 558 |
+
num_frames (float, optional): Number of frames to sample. Defaults to None.
|
| 559 |
+
size (int, optional): Size of the shortest side. Defaults to None.
|
| 560 |
+
size_divisible (int, optional): Size divisible by this number. Defaults to 1.
|
| 561 |
+
precise_time (bool, optional): Whether to use precise time. Defaults to False.
|
| 562 |
+
verbose (bool, optional): Print ffmpeg output. Defaults to False.
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
frames (List[PIL.Image]): List of frames.
|
| 566 |
+
timestamps (List[float]): List of timestamps.
|
| 567 |
+
"""
|
| 568 |
+
if start_time is not None and end_time is not None and end_time - start_time < 1:
|
| 569 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
|
| 570 |
+
if os.path.isdir(video_path):
|
| 571 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
|
| 572 |
+
if video_path.endswith('.gif'):
|
| 573 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames)
|
| 574 |
+
probe = ffmpeg.probe(video_path)
|
| 575 |
+
duration = float(probe['format']['duration'])
|
| 576 |
+
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
|
| 577 |
+
w, h = int(video_stream['width']), int(video_stream['height'])
|
| 578 |
+
|
| 579 |
+
kwargs, input_kwargs, output_kwargs = {}, {}, {}
|
| 580 |
+
do_trim = start_time is not None or end_time is not None
|
| 581 |
+
if start_time is not None:
|
| 582 |
+
new_start_time = max(float(video_stream['start_time']), start_time)
|
| 583 |
+
duration -= new_start_time - start_time
|
| 584 |
+
start_time = new_start_time
|
| 585 |
+
else:
|
| 586 |
+
start_time = float(video_stream['start_time'])
|
| 587 |
+
if end_time is not None:
|
| 588 |
+
duration = min(duration, end_time - start_time)
|
| 589 |
+
else:
|
| 590 |
+
duration = duration
|
| 591 |
+
if do_trim:
|
| 592 |
+
kwargs = {'ss': start_time, 't': duration}
|
| 593 |
+
if precise_time:
|
| 594 |
+
output_kwargs.update(kwargs)
|
| 595 |
+
else:
|
| 596 |
+
input_kwargs.update(kwargs)
|
| 597 |
+
|
| 598 |
+
if size is not None:
|
| 599 |
+
scale_factor = size / min(w, h)
|
| 600 |
+
new_w, new_h = round(w * scale_factor), round(h * scale_factor)
|
| 601 |
+
else:
|
| 602 |
+
new_w, new_h = w, h
|
| 603 |
+
new_w = new_w // size_divisible * size_divisible
|
| 604 |
+
new_h = new_h // size_divisible * size_divisible
|
| 605 |
+
|
| 606 |
+
# NOTE: It may result in unexpected number of frames in ffmpeg
|
| 607 |
+
# if calculate the fps directly according to max_frames
|
| 608 |
+
# NOTE: the below lines may hurt the performance
|
| 609 |
+
# if max_frames is not None and (fps is None or duration * fps > 2 * max_frames):
|
| 610 |
+
# fps = max_frames / duration * 2
|
| 611 |
+
|
| 612 |
+
stream = ffmpeg.input(video_path, **input_kwargs)
|
| 613 |
+
if fps is not None:
|
| 614 |
+
stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
|
| 615 |
+
if new_w != w or new_h != h:
|
| 616 |
+
stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
|
| 617 |
+
stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
|
| 618 |
+
out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)
|
| 619 |
+
|
| 620 |
+
frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])
|
| 621 |
+
|
| 622 |
+
if fps is not None:
|
| 623 |
+
timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
|
| 624 |
+
else:
|
| 625 |
+
timestamps = np.linspace(start_time, start_time + duration, len(frames))
|
| 626 |
+
|
| 627 |
+
max_frames = max_frames if max_frames is not None else MAX_FRAMES
|
| 628 |
+
if max_frames is not None and len(frames) > max_frames:
|
| 629 |
+
indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
|
| 630 |
+
frames = frames[indices]
|
| 631 |
+
timestamps = [timestamps[i] for i in indices]
|
| 632 |
+
|
| 633 |
+
if temporal_factor > 1:
|
| 634 |
+
pad_length = temporal_factor - len(frames) % temporal_factor
|
| 635 |
+
frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
|
| 636 |
+
[timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]
|
| 637 |
+
|
| 638 |
+
frames = [frame for frame in frames]
|
| 639 |
+
|
| 640 |
+
return frames, timestamps
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def process_video(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=None):
|
| 644 |
+
fps = 1 if num_frames is None else None
|
| 645 |
+
# FFmpeg
|
| 646 |
+
frames, timestamps = load_video(video_path, s, e, fps=fps, max_frames=num_frames)
|
| 647 |
+
# Decord
|
| 648 |
+
# frames, timestamps = load_video_from_ids(video_path, s, e, fps=fps, max_frames=num_frames)
|
| 649 |
+
|
| 650 |
+
assert len(frames) == len(timestamps), "Number of frames and timestamps must match."
|
| 651 |
+
|
| 652 |
+
if aspect_ratio == 'pad':
|
| 653 |
+
frames = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in frames]
|
| 654 |
+
|
| 655 |
+
if aspect_ratio == 'avt':
|
| 656 |
+
frames = [processor.preprocess(frame, return_tensors='pt', image_num=len(frames)) for frame in frames]
|
| 657 |
+
grid_frames = [frames]
|
| 658 |
+
else:
|
| 659 |
+
frames = processor.preprocess(frames, return_tensors='pt', image_num=len(frames))
|
| 660 |
+
grid_frames = [[frames]]
|
| 661 |
+
|
| 662 |
+
return grid_frames, timestamps
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def tokenizer_multimodal_token(prompt, tokenizer, multimodal_token=DEFAULT_IMAGE_TOKEN, return_tensors=None):
|
| 666 |
+
"""Tokenize text and multimodal tag to input_ids.
|
| 667 |
+
|
| 668 |
+
Args:
|
| 669 |
+
prompt (str): Text prompt (w/ multimodal tag), e.g., '<video>\nDescribe the video.'
|
| 670 |
+
tokenizer (transformers.PreTrainedTokenizer): Tokenizer object.
|
| 671 |
+
multimodal_token (int): Token index corresponding to the multimodal tag.
|
| 672 |
+
"""
|
| 673 |
+
multimodal_token_index = MODAL_INDEX_MAP.get(multimodal_token, None)
|
| 674 |
+
if multimodal_token_index is None:
|
| 675 |
+
input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
|
| 676 |
+
else:
|
| 677 |
+
prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for idx, chunk in enumerate(prompt.split(multimodal_token))]
|
| 678 |
+
|
| 679 |
+
input_ids = []
|
| 680 |
+
for i in range(1, 2 * len(prompt_chunks)):
|
| 681 |
+
if i % 2 == 1:
|
| 682 |
+
input_ids.extend(prompt_chunks[i // 2])
|
| 683 |
+
else:
|
| 684 |
+
input_ids.append(multimodal_token_index)
|
| 685 |
+
|
| 686 |
+
if return_tensors is not None:
|
| 687 |
+
if return_tensors == 'pt':
|
| 688 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 689 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
| 690 |
+
return input_ids
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def get_model_name_from_path(model_path):
|
| 694 |
+
model_path = model_path.strip("/")
|
| 695 |
+
model_paths = model_path.split("/")
|
| 696 |
+
if model_paths[-1].startswith('checkpoint-'):
|
| 697 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
| 698 |
+
else:
|
| 699 |
+
return model_paths[-1]
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 703 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 704 |
+
self.keywords = keywords
|
| 705 |
+
self.keyword_ids = []
|
| 706 |
+
self.max_keyword_len = 0
|
| 707 |
+
for keyword in keywords:
|
| 708 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
| 709 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
| 710 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
| 711 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
| 712 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
| 713 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
| 714 |
+
self.tokenizer = tokenizer
|
| 715 |
+
self.start_len = input_ids.shape[1]
|
| 716 |
+
|
| 717 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 718 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
| 719 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
| 720 |
+
for keyword_id in self.keyword_ids:
|
| 721 |
+
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
| 722 |
+
return True
|
| 723 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
| 724 |
+
for keyword in self.keywords:
|
| 725 |
+
if keyword in outputs:
|
| 726 |
+
return True
|
| 727 |
+
return False
|
| 728 |
+
|
| 729 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 730 |
+
outputs = []
|
| 731 |
+
for i in range(output_ids.shape[0]):
|
| 732 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
| 733 |
+
return all(outputs)
|
RynnEC/rynnec/model/__init__.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
| 2 |
+
# Copyright 2023 Haotian Liu
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import warnings
|
| 19 |
+
import shutil
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, AutoProcessor
|
| 23 |
+
|
| 24 |
+
from .projector import load_mm_projector
|
| 25 |
+
from .videollama3_encoder import Videollama3VisionEncoderModel, Videollama3VisionEncoderConfig
|
| 26 |
+
from .rynnec_qwen2 import RynnecQwen2ForCausalLM, RynnecQwen2Config, Videollama3Qwen2Processor
|
| 27 |
+
|
| 28 |
+
def apply_liger_kernel_to_rynnec():
|
| 29 |
+
from liger_kernel.transformers import (
|
| 30 |
+
apply_liger_kernel_to_mistral,
|
| 31 |
+
apply_liger_kernel_to_qwen2,
|
| 32 |
+
apply_liger_kernel_to_qwen3,
|
| 33 |
+
apply_liger_kernel_to_qwen3_moe,
|
| 34 |
+
)
|
| 35 |
+
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
| 36 |
+
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
| 37 |
+
from .videollama3_encoder import modeling_videollama3_encoder
|
| 38 |
+
|
| 39 |
+
apply_liger_kernel_to_mistral()
|
| 40 |
+
apply_liger_kernel_to_qwen2()
|
| 41 |
+
|
| 42 |
+
modeling_videollama3_encoder.apply_rotary_pos_emb_vision = liger_rotary_pos_emb
|
| 43 |
+
modeling_videollama3_encoder.LayerNorm = LigerLayerNorm
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", **kwargs):
|
| 47 |
+
if 'token' in kwargs:
|
| 48 |
+
token = kwargs['token']
|
| 49 |
+
else:
|
| 50 |
+
token = None
|
| 51 |
+
|
| 52 |
+
# NOTE: auto device_map by default
|
| 53 |
+
# if want to put model into a single device, you can set device_map={"": "cuda:0"}
|
| 54 |
+
kwargs = {"device_map": device_map, **kwargs}
|
| 55 |
+
|
| 56 |
+
config = AutoConfig.from_pretrained(model_path)
|
| 57 |
+
config._attn_implementation = kwargs.pop('attn_implementation', "flash_attention_2") # default to flash_attention_2
|
| 58 |
+
|
| 59 |
+
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else kwargs.pop('torch_dtype', torch.float16)
|
| 60 |
+
|
| 61 |
+
if load_8bit:
|
| 62 |
+
kwargs['load_in_8bit'] = True
|
| 63 |
+
elif load_4bit:
|
| 64 |
+
# NOTE: High-version Transformers will report: """ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time."""
|
| 65 |
+
# kwargs['load_in_4bit'] = True
|
| 66 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
| 67 |
+
load_in_4bit=True,
|
| 68 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
| 69 |
+
bnb_4bit_use_double_quant=True,
|
| 70 |
+
bnb_4bit_quant_type='nf4'
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
kwargs['torch_dtype'] = torch_dtype
|
| 74 |
+
|
| 75 |
+
# judge model type
|
| 76 |
+
model_type = config.model_type if hasattr(config, "model_type") else kwargs.pop('model_type', "rynnec_qwen2")
|
| 77 |
+
|
| 78 |
+
# judge pretrain/finetune
|
| 79 |
+
is_alignment = getattr(config, "tune_mm_mlp_adapter", False) or getattr(config, "is_alignment", False)
|
| 80 |
+
|
| 81 |
+
# NOTE: lora/qlora model loading
|
| 82 |
+
if 'lora' in model_name.lower() or 'qlora' in model_name.lower():
|
| 83 |
+
# if True:
|
| 84 |
+
cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
|
| 85 |
+
# NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
|
| 86 |
+
# cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
|
| 87 |
+
model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
|
| 88 |
+
|
| 89 |
+
# NOTE: remove qlora training quantization config
|
| 90 |
+
if hasattr(cfg_pretrained, 'quantization_config'):
|
| 91 |
+
del cfg_pretrained.quantization_config
|
| 92 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
|
| 93 |
+
print('Loading RynnEC from base model...')
|
| 94 |
+
|
| 95 |
+
config_raw = AutoConfig.from_pretrained(model_base)
|
| 96 |
+
new_vocab_size = config.vocab_size
|
| 97 |
+
if config.vocab_size!=config_raw.vocab_size:
|
| 98 |
+
config.vocab_size = config_raw.vocab_size
|
| 99 |
+
config.training = False
|
| 100 |
+
|
| 101 |
+
if 'qwen2' in model_base.lower():
|
| 102 |
+
model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
| 103 |
+
else:
|
| 104 |
+
model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
| 105 |
+
|
| 106 |
+
model.config.mask_decoder_model = "./checkpoints/sam2_hiera_large.pt"
|
| 107 |
+
|
| 108 |
+
token_num, tokem_dim = new_vocab_size, model.lm_head.in_features
|
| 109 |
+
|
| 110 |
+
if model.lm_head.weight.shape[0] != token_num:
|
| 111 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
| 112 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
| 113 |
+
|
| 114 |
+
print('Loading additional RynnEC weights...')
|
| 115 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
| 116 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
| 117 |
+
else:
|
| 118 |
+
# this is probably from HF Hub
|
| 119 |
+
from huggingface_hub import hf_hub_download
|
| 120 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
| 121 |
+
cache_file = hf_hub_download(
|
| 122 |
+
repo_id=repo_id,
|
| 123 |
+
filename=filename,
|
| 124 |
+
subfolder=subfolder)
|
| 125 |
+
return torch.load(cache_file, map_location='cpu')
|
| 126 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
| 127 |
+
|
| 128 |
+
# add
|
| 129 |
+
sam2_model = torch.load(model.config.mask_decoder_model, map_location='cpu')['model']
|
| 130 |
+
prefix = "base_model.model.grounding_encoder.sam2_model."
|
| 131 |
+
for param_name in sam2_model.keys():
|
| 132 |
+
new_param_name = prefix + param_name
|
| 133 |
+
if new_param_name not in non_lora_trainables.keys():
|
| 134 |
+
non_lora_trainables[new_param_name] = sam2_model[param_name]
|
| 135 |
+
|
| 136 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
| 137 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
| 138 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
| 139 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
| 140 |
+
|
| 141 |
+
from peft import PeftModel
|
| 142 |
+
print('Loading LoRA weights...')
|
| 143 |
+
model = PeftModel.from_pretrained(model, model_path)
|
| 144 |
+
print('Merging LoRA weights...')
|
| 145 |
+
model = model.merge_and_unload()
|
| 146 |
+
print('Model is loaded...')
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
elif model_base is not None or '-base' in model_name.lower() or is_alignment:
|
| 150 |
+
# NOTE: Base/Pretrain model loading
|
| 151 |
+
print('Loading RynnEC from base model...')
|
| 152 |
+
cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
|
| 153 |
+
# NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
|
| 154 |
+
# cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
|
| 155 |
+
model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
|
| 156 |
+
|
| 157 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
|
| 158 |
+
|
| 159 |
+
if model_type in ['rynnec', 'rynnec_qwen2']:
|
| 160 |
+
model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
| 161 |
+
else:
|
| 162 |
+
model = RynnecQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
| 163 |
+
|
| 164 |
+
# NOTE; loading vision-language projector
|
| 165 |
+
# * old codes for loading local mm_projector.bin
|
| 166 |
+
# mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
| 167 |
+
# mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
| 168 |
+
# model.load_state_dict(mm_projector_weights, strict=False)
|
| 169 |
+
# * new codes which supports loading mm_projector.bin both offline and online
|
| 170 |
+
mm_projector_weights = load_mm_projector(model_path, token=token)
|
| 171 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
| 172 |
+
elif 'rynnec' in model_type:
|
| 173 |
+
# NOTE: SFT model loading
|
| 174 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
|
| 175 |
+
|
| 176 |
+
if model_type in ['rynnec_qwen2']:
|
| 177 |
+
model = RynnecQwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
|
| 178 |
+
else:
|
| 179 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
|
| 180 |
+
else:
|
| 181 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, token=token)
|
| 182 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, config=config, **kwargs)
|
| 183 |
+
|
| 184 |
+
processor = None
|
| 185 |
+
|
| 186 |
+
# if "videollama" in model_type:
|
| 187 |
+
if True:
|
| 188 |
+
vision_encoder = model.get_vision_encoder()
|
| 189 |
+
processor = vision_encoder.image_processor
|
| 190 |
+
|
| 191 |
+
if hasattr(model.config, "max_sequence_length"):
|
| 192 |
+
context_len = model.config.max_sequence_length
|
| 193 |
+
else:
|
| 194 |
+
context_len = 2048
|
| 195 |
+
|
| 196 |
+
return tokenizer, model, processor, context_len
|
RynnEC/rynnec/model/encoder.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import (CLIPImageProcessor, CLIPVisionConfig,
|
| 6 |
+
CLIPVisionModel, SiglipImageProcessor,
|
| 7 |
+
SiglipVisionConfig, SiglipVisionModel)
|
| 8 |
+
|
| 9 |
+
from .videollama3_encoder import (Videollama3VisionEncoderConfig,
|
| 10 |
+
Videollama3VisionEncoderModel, Videollama3ImageProcessor)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CLIPVisionEncoder(nn.Module):
|
| 14 |
+
|
| 15 |
+
def __init__(self, vision_encoder, args, delay_load=False):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
self.is_loaded = False
|
| 19 |
+
|
| 20 |
+
self.vision_encoder_name = vision_encoder
|
| 21 |
+
self.select_layer = args.mm_vision_select_layer
|
| 22 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 23 |
+
|
| 24 |
+
if not delay_load:
|
| 25 |
+
self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
|
| 26 |
+
self.load_model()
|
| 27 |
+
else:
|
| 28 |
+
# uncertain whether flash-attention-2 is supported during inference phase.
|
| 29 |
+
self.attn_implementation = 'sdpa' # 'eager'
|
| 30 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
|
| 31 |
+
|
| 32 |
+
def load_model(self):
|
| 33 |
+
if self.is_loaded:
|
| 34 |
+
print('Vision tower is already loaded, `load model` call again, skipping.')
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
|
| 38 |
+
|
| 39 |
+
self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name,
|
| 40 |
+
attn_implementation=self.attn_implementation)
|
| 41 |
+
|
| 42 |
+
self.is_loaded = True
|
| 43 |
+
|
| 44 |
+
def feature_select(self, image_forward_outs):
|
| 45 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
| 46 |
+
if self.select_feature == 'patch':
|
| 47 |
+
image_features = image_features[:, 1:]
|
| 48 |
+
elif self.select_feature == 'cls_patch':
|
| 49 |
+
image_features = image_features
|
| 50 |
+
else:
|
| 51 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
| 52 |
+
return image_features
|
| 53 |
+
|
| 54 |
+
def forward(self, images, **kwargs):
|
| 55 |
+
images = torch.cat(images)
|
| 56 |
+
if type(images) is list:
|
| 57 |
+
image_features = []
|
| 58 |
+
for image in images:
|
| 59 |
+
image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
| 60 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
| 61 |
+
image_features.append(image_feature)
|
| 62 |
+
else:
|
| 63 |
+
image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
| 64 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 65 |
+
|
| 66 |
+
return image_features
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def dummy_feature(self):
|
| 70 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def dtype(self):
|
| 74 |
+
return self.vision_encoder.dtype
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def device(self):
|
| 78 |
+
return self.vision_encoder.device
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def config(self):
|
| 82 |
+
if self.is_loaded:
|
| 83 |
+
return self.vision_encoder.config
|
| 84 |
+
else:
|
| 85 |
+
return self.cfg_only
|
| 86 |
+
|
| 87 |
+
@property
|
| 88 |
+
def hidden_size(self):
|
| 89 |
+
return self.config.hidden_size
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def num_patches(self):
|
| 93 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
| 94 |
+
|
| 95 |
+
@property
|
| 96 |
+
def num_patches_per_side(self):
|
| 97 |
+
return self.config.image_size // self.config.patch_size
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def image_size(self):
|
| 101 |
+
return self.config.image_size
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class SiglipVisionEncoder(nn.Module):
|
| 105 |
+
|
| 106 |
+
def __init__(self, vision_encoder, args, delay_load=False):
|
| 107 |
+
super().__init__()
|
| 108 |
+
|
| 109 |
+
self.is_loaded = False
|
| 110 |
+
|
| 111 |
+
self.vision_encoder_name = vision_encoder
|
| 112 |
+
self.select_layer = args.mm_vision_select_layer
|
| 113 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
| 114 |
+
|
| 115 |
+
if not delay_load:
|
| 116 |
+
self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
|
| 117 |
+
self.load_model()
|
| 118 |
+
else:
|
| 119 |
+
# uncertain whether flash-attention-2 is supported during inference phase.
|
| 120 |
+
self.attn_implementation = 'sdpa' # 'eager'
|
| 121 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
|
| 122 |
+
|
| 123 |
+
def load_model(self):
|
| 124 |
+
if self.is_loaded:
|
| 125 |
+
print('Vision tower is already loaded, `load model` call again, skipping.')
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_encoder_name)
|
| 129 |
+
|
| 130 |
+
self.vision_encoder = SiglipVisionModel.from_pretrained(self.vision_encoder_name,
|
| 131 |
+
attn_implementation=self.attn_implementation)
|
| 132 |
+
|
| 133 |
+
self.is_loaded = True
|
| 134 |
+
|
| 135 |
+
def feature_select(self, image_forward_outs):
|
| 136 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
| 137 |
+
if self.select_feature == 'patch':
|
| 138 |
+
image_features = image_features
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
| 141 |
+
return image_features
|
| 142 |
+
|
| 143 |
+
def forward(self, images, **kwargs):
|
| 144 |
+
images = torch.cat(images)
|
| 145 |
+
if type(images) is list:
|
| 146 |
+
image_features = []
|
| 147 |
+
for image in images:
|
| 148 |
+
image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
| 149 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
| 150 |
+
image_features.append(image_feature)
|
| 151 |
+
else:
|
| 152 |
+
image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
| 153 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 154 |
+
|
| 155 |
+
return image_features
|
| 156 |
+
|
| 157 |
+
@property
|
| 158 |
+
def dummy_feature(self):
|
| 159 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 160 |
+
|
| 161 |
+
@property
|
| 162 |
+
def dtype(self):
|
| 163 |
+
return self.vision_encoder.dtype
|
| 164 |
+
|
| 165 |
+
@property
|
| 166 |
+
def device(self):
|
| 167 |
+
return self.vision_encoder.device
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def config(self):
|
| 171 |
+
if self.is_loaded:
|
| 172 |
+
return self.vision_encoder.config
|
| 173 |
+
else:
|
| 174 |
+
return self.cfg_only
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def hidden_size(self):
|
| 178 |
+
return self.config.hidden_size
|
| 179 |
+
|
| 180 |
+
@property
|
| 181 |
+
def num_patches(self):
|
| 182 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def num_patches_per_side(self):
|
| 186 |
+
return self.config.image_size // self.config.patch_size
|
| 187 |
+
|
| 188 |
+
@property
|
| 189 |
+
def image_size(self):
|
| 190 |
+
return self.config.image_size
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class Videollama3VisionEncoder(nn.Module):
|
| 194 |
+
|
| 195 |
+
def __init__(self, vision_encoder, args, delay_load=False):
|
| 196 |
+
super().__init__()
|
| 197 |
+
|
| 198 |
+
self.is_loaded = False
|
| 199 |
+
|
| 200 |
+
self.vision_encoder_name = vision_encoder
|
| 201 |
+
self.args = args
|
| 202 |
+
|
| 203 |
+
if not delay_load:
|
| 204 |
+
self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
|
| 205 |
+
self.load_model(self.args)
|
| 206 |
+
else:
|
| 207 |
+
# uncertain whether flash-attention-2 is supported during inference phase.
|
| 208 |
+
self.attn_implementation = 'sdpa' # 'eager'
|
| 209 |
+
self.cfg_only = Videollama3VisionEncoderConfig.from_pretrained(self.vision_encoder_name)
|
| 210 |
+
|
| 211 |
+
def load_model(self, args):
|
| 212 |
+
if self.is_loaded:
|
| 213 |
+
print('Vision tower is already loaded, `load model` call again, skipping.')
|
| 214 |
+
return
|
| 215 |
+
|
| 216 |
+
# merge_size is set to 1 by default, because STAGE1, STAGE1.5, STAGE2 are trained with merge_size=1
|
| 217 |
+
# for stage 3, the merge_size is set to 2 by argments.
|
| 218 |
+
self.image_processor = Videollama3ImageProcessor.from_pretrained(self.vision_encoder_name)
|
| 219 |
+
|
| 220 |
+
# merge_size is fixed to 1 for STAGE1, STAGE1.5, STAGE2, STAGE3 in encoder and can be modified in connector.
|
| 221 |
+
self.cfg_only = Videollama3VisionEncoderConfig.from_pretrained(self.vision_encoder_name)
|
| 222 |
+
|
| 223 |
+
self.vision_encoder = Videollama3VisionEncoderModel.from_pretrained(
|
| 224 |
+
self.vision_encoder_name,
|
| 225 |
+
torch_dtype=args.torch_dtype,
|
| 226 |
+
attn_implementation=self.attn_implementation)
|
| 227 |
+
|
| 228 |
+
self.is_loaded = True
|
| 229 |
+
|
| 230 |
+
def forward(self, pixel_values, grid_sizes, merge_sizes, **kwargs):
|
| 231 |
+
image_features = self.vision_encoder(pixel_values, grid_sizes, merge_sizes)
|
| 232 |
+
return image_features
|
| 233 |
+
|
| 234 |
+
@property
|
| 235 |
+
def dummy_feature(self):
|
| 236 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 237 |
+
|
| 238 |
+
@property
|
| 239 |
+
def dtype(self):
|
| 240 |
+
return self.vision_encoder.dtype
|
| 241 |
+
|
| 242 |
+
@property
|
| 243 |
+
def device(self):
|
| 244 |
+
return self.vision_encoder.device
|
| 245 |
+
|
| 246 |
+
@property
|
| 247 |
+
def config(self):
|
| 248 |
+
if self.is_loaded:
|
| 249 |
+
return self.vision_encoder.config
|
| 250 |
+
else:
|
| 251 |
+
return self.cfg_only
|
| 252 |
+
|
| 253 |
+
@property
|
| 254 |
+
def hidden_size(self):
|
| 255 |
+
return self.config.hidden_size
|
| 256 |
+
|
| 257 |
+
@property
|
| 258 |
+
def num_patches(self):
|
| 259 |
+
return -1
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def num_patches_per_side(self):
|
| 263 |
+
return -1
|
| 264 |
+
|
| 265 |
+
@property
|
| 266 |
+
def image_size(self):
|
| 267 |
+
return -1
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def build_vision_encoder(vision_encoder_cfg, **kwargs):
|
| 271 |
+
vision_encoder = getattr(vision_encoder_cfg, 'mm_vision_encoder', getattr(vision_encoder_cfg, 'vision_encoder', None))
|
| 272 |
+
|
| 273 |
+
if 'clip' in vision_encoder:
|
| 274 |
+
vision_encoder = CLIPVisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
|
| 275 |
+
elif 'navit' in vision_encoder.lower() or 'damovl' in vision_encoder:
|
| 276 |
+
vision_encoder = Videollama3VisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
|
| 277 |
+
elif 'siglip' in vision_encoder:
|
| 278 |
+
vision_encoder = SiglipVisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
|
| 279 |
+
else:
|
| 280 |
+
raise ValueError(f'Unknown vision encoder: {vision_encoder}')
|
| 281 |
+
|
| 282 |
+
return vision_encoder
|
RynnEC/rynnec/model/extension/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sam2_base import SAM2Base
|
RynnEC/rynnec/model/extension/sam2_base.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/extension/sam2_base.py.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# coding=utf-8
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
from third_parts.sam2.modeling.sam2_base import SAM2Base as _SAM2Base
|
| 21 |
+
from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SAM2Base(_SAM2Base):
|
| 25 |
+
|
| 26 |
+
def track_step(
|
| 27 |
+
self,
|
| 28 |
+
frame_idx,
|
| 29 |
+
is_init_cond_frame,
|
| 30 |
+
current_vision_feats,
|
| 31 |
+
current_vision_pos_embeds,
|
| 32 |
+
feat_sizes,
|
| 33 |
+
point_inputs,
|
| 34 |
+
mask_inputs,
|
| 35 |
+
output_dict,
|
| 36 |
+
num_frames,
|
| 37 |
+
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
| 38 |
+
# Whether to run the memory encoder on the predicted masks. Sometimes we might want
|
| 39 |
+
# to skip the memory encoder with `run_mem_encoder=False`. For example,
|
| 40 |
+
# in demo we might call `track_step` multiple times for each user click,
|
| 41 |
+
# and only encode the memory when the user finalizes their clicks. And in ablation
|
| 42 |
+
# settings like SAM training on static images, we don't need the memory encoder.
|
| 43 |
+
run_mem_encoder=True,
|
| 44 |
+
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
|
| 45 |
+
prev_sam_mask_logits=None,
|
| 46 |
+
## Extension: LLM prompt
|
| 47 |
+
language_embd=None,
|
| 48 |
+
):
|
| 49 |
+
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
|
| 50 |
+
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
| 51 |
+
if len(current_vision_feats) > 1:
|
| 52 |
+
high_res_features = [
|
| 53 |
+
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
| 54 |
+
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
| 55 |
+
]
|
| 56 |
+
else:
|
| 57 |
+
high_res_features = None
|
| 58 |
+
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
|
| 59 |
+
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
|
| 60 |
+
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
| 61 |
+
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
| 62 |
+
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
| 63 |
+
sam_outputs = self._use_mask_as_output(
|
| 64 |
+
pix_feat, high_res_features, mask_inputs
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
# fused the visual feature with previous memory features in the memory bank
|
| 68 |
+
pix_feat_with_mem = self._prepare_memory_conditioned_features(
|
| 69 |
+
frame_idx=frame_idx,
|
| 70 |
+
is_init_cond_frame=is_init_cond_frame,
|
| 71 |
+
current_vision_feats=current_vision_feats[-1:],
|
| 72 |
+
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
|
| 73 |
+
feat_sizes=feat_sizes[-1:],
|
| 74 |
+
output_dict=output_dict,
|
| 75 |
+
num_frames=num_frames,
|
| 76 |
+
track_in_reverse=track_in_reverse,
|
| 77 |
+
)
|
| 78 |
+
# apply SAM-style segmentation head
|
| 79 |
+
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
|
| 80 |
+
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
|
| 81 |
+
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
|
| 82 |
+
if prev_sam_mask_logits is not None:
|
| 83 |
+
assert point_inputs is not None and mask_inputs is None
|
| 84 |
+
mask_inputs = prev_sam_mask_logits
|
| 85 |
+
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
|
| 86 |
+
sam_outputs = self._forward_sam_heads(
|
| 87 |
+
backbone_features=pix_feat_with_mem,
|
| 88 |
+
point_inputs=point_inputs,
|
| 89 |
+
mask_inputs=mask_inputs,
|
| 90 |
+
high_res_features=high_res_features,
|
| 91 |
+
multimask_output=multimask_output,
|
| 92 |
+
# Inject language Embed if possible
|
| 93 |
+
language_embd=language_embd,
|
| 94 |
+
)
|
| 95 |
+
(
|
| 96 |
+
_,
|
| 97 |
+
_,
|
| 98 |
+
_,
|
| 99 |
+
low_res_masks,
|
| 100 |
+
high_res_masks,
|
| 101 |
+
obj_ptr,
|
| 102 |
+
_,
|
| 103 |
+
) = sam_outputs
|
| 104 |
+
|
| 105 |
+
current_out["pred_masks"] = low_res_masks
|
| 106 |
+
current_out["pred_masks_high_res"] = high_res_masks
|
| 107 |
+
current_out["obj_ptr"] = obj_ptr
|
| 108 |
+
|
| 109 |
+
# Finally run the memory encoder on the predicted mask to encode
|
| 110 |
+
# it into a new memory feature (that can be used in future frames)
|
| 111 |
+
if run_mem_encoder and self.num_maskmem > 0:
|
| 112 |
+
high_res_masks_for_mem_enc = high_res_masks
|
| 113 |
+
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
| 114 |
+
current_vision_feats=current_vision_feats,
|
| 115 |
+
feat_sizes=feat_sizes,
|
| 116 |
+
pred_masks_high_res=high_res_masks_for_mem_enc,
|
| 117 |
+
is_mask_from_pts=(point_inputs is not None),
|
| 118 |
+
)
|
| 119 |
+
current_out["maskmem_features"] = maskmem_features
|
| 120 |
+
current_out["maskmem_pos_enc"] = maskmem_pos_enc
|
| 121 |
+
else:
|
| 122 |
+
current_out["maskmem_features"] = None
|
| 123 |
+
current_out["maskmem_pos_enc"] = None
|
| 124 |
+
|
| 125 |
+
return current_out
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _forward_sam_heads(
|
| 129 |
+
self,
|
| 130 |
+
backbone_features,
|
| 131 |
+
point_inputs=None,
|
| 132 |
+
mask_inputs=None,
|
| 133 |
+
high_res_features=None,
|
| 134 |
+
multimask_output=False,
|
| 135 |
+
## Extension: LLM prompt
|
| 136 |
+
language_embd=None,
|
| 137 |
+
):
|
| 138 |
+
"""
|
| 139 |
+
Forward SAM prompt encoders and mask heads.
|
| 140 |
+
|
| 141 |
+
Inputs:
|
| 142 |
+
- backbone_features: image features of [B, C, H, W] shape
|
| 143 |
+
- point_inputs: a dictionary with "point_coords" and "point_labels", where
|
| 144 |
+
1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
|
| 145 |
+
absolute pixel-unit coordinate in (x, y) format of the P input points
|
| 146 |
+
2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
|
| 147 |
+
positive clicks, 0 means negative clicks, and -1 means padding
|
| 148 |
+
- mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
|
| 149 |
+
same spatial size as the image.
|
| 150 |
+
- high_res_features: either 1) None or 2) or a list of length 2 containing
|
| 151 |
+
two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
|
| 152 |
+
which will be used as high-resolution feature maps for SAM decoder.
|
| 153 |
+
- multimask_output: if it's True, we output 3 candidate masks and their 3
|
| 154 |
+
corresponding IoU estimates, and if it's False, we output only 1 mask and
|
| 155 |
+
its corresponding IoU estimate.
|
| 156 |
+
|
| 157 |
+
Outputs:
|
| 158 |
+
- low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
|
| 159 |
+
`multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
|
| 160 |
+
output mask logits (before sigmoid) for the low-resolution masks, with 4x
|
| 161 |
+
the resolution (1/4 stride) of the input backbone_features.
|
| 162 |
+
- high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
|
| 163 |
+
if `multimask_output=True` and M = 1 if `multimask_output=False`),
|
| 164 |
+
upsampled from the low-resolution masks, with shape size as the image
|
| 165 |
+
(stride is 1 pixel).
|
| 166 |
+
- ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
|
| 167 |
+
if `multimask_output=False`), the estimated IoU of each output mask.
|
| 168 |
+
- low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
|
| 169 |
+
If `multimask_output=True`, it's the mask with the highest IoU estimate.
|
| 170 |
+
If `multimask_output=False`, it's the same as `low_res_multimasks`.
|
| 171 |
+
- high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
|
| 172 |
+
If `multimask_output=True`, it's the mask with the highest IoU estimate.
|
| 173 |
+
If `multimask_output=False`, it's the same as `high_res_multimasks`.
|
| 174 |
+
- obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
|
| 175 |
+
based on the output token from the SAM mask decoder.
|
| 176 |
+
"""
|
| 177 |
+
B = backbone_features.size(0)
|
| 178 |
+
device = backbone_features.device
|
| 179 |
+
assert backbone_features.size(1) == self.sam_prompt_embed_dim
|
| 180 |
+
assert backbone_features.size(2) == self.sam_image_embedding_size
|
| 181 |
+
assert backbone_features.size(3) == self.sam_image_embedding_size
|
| 182 |
+
|
| 183 |
+
# a) Handle point prompts
|
| 184 |
+
if point_inputs is not None:
|
| 185 |
+
sam_point_coords = point_inputs["point_coords"]
|
| 186 |
+
sam_point_labels = point_inputs["point_labels"]
|
| 187 |
+
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
|
| 188 |
+
else:
|
| 189 |
+
# If no points are provide, pad with an empty point (with label -1)
|
| 190 |
+
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
| 191 |
+
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
| 192 |
+
|
| 193 |
+
# b) Handle mask prompts
|
| 194 |
+
if mask_inputs is not None:
|
| 195 |
+
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
| 196 |
+
# and feed it as a dense mask prompt into the SAM mask encoder
|
| 197 |
+
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
|
| 198 |
+
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
|
| 199 |
+
sam_mask_prompt = F.interpolate(
|
| 200 |
+
mask_inputs.float(),
|
| 201 |
+
size=self.sam_prompt_encoder.mask_input_size,
|
| 202 |
+
align_corners=False,
|
| 203 |
+
mode="bilinear",
|
| 204 |
+
antialias=True, # use antialias for downsampling
|
| 205 |
+
)
|
| 206 |
+
else:
|
| 207 |
+
sam_mask_prompt = mask_inputs
|
| 208 |
+
else:
|
| 209 |
+
# Otherwise, simply feed None (and SAM's prompt encoder will add
|
| 210 |
+
# a learned `no_mask_embed` to indicate no mask input in this case).
|
| 211 |
+
sam_mask_prompt = None
|
| 212 |
+
|
| 213 |
+
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
| 214 |
+
points=(sam_point_coords, sam_point_labels),
|
| 215 |
+
boxes=None,
|
| 216 |
+
masks=sam_mask_prompt,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
## Extension: LLM prompt
|
| 220 |
+
if language_embd is not None:
|
| 221 |
+
# B N C
|
| 222 |
+
# print('sparse_embeddings ', sparse_embeddings.shape, 'language_embd ', language_embd.shape)
|
| 223 |
+
assert sparse_embeddings.size(0) == language_embd.size(0)
|
| 224 |
+
assert sparse_embeddings.size(2) == language_embd.size(2)
|
| 225 |
+
sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
|
| 226 |
+
|
| 227 |
+
(
|
| 228 |
+
low_res_multimasks,
|
| 229 |
+
ious,
|
| 230 |
+
sam_output_tokens,
|
| 231 |
+
object_score_logits,
|
| 232 |
+
) = self.sam_mask_decoder(
|
| 233 |
+
image_embeddings=backbone_features,
|
| 234 |
+
image_pe=self.sam_prompt_encoder.get_dense_pe(),
|
| 235 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 236 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 237 |
+
multimask_output=multimask_output,
|
| 238 |
+
repeat_image=False, # the image is already batched
|
| 239 |
+
high_res_features=high_res_features,
|
| 240 |
+
)
|
| 241 |
+
if self.pred_obj_scores:
|
| 242 |
+
is_obj_appearing = object_score_logits > 0
|
| 243 |
+
|
| 244 |
+
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
| 245 |
+
# consistent with the actual mask prediction
|
| 246 |
+
# print('Do torch.where !!!')
|
| 247 |
+
# low_res_multimasks = torch.where(
|
| 248 |
+
# is_obj_appearing[:, None, None],
|
| 249 |
+
# low_res_multimasks,
|
| 250 |
+
# NO_OBJ_SCORE,
|
| 251 |
+
# )
|
| 252 |
+
|
| 253 |
+
# convert masks from possibly bfloat16 (or float16) to float32
|
| 254 |
+
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
|
| 255 |
+
low_res_multimasks = low_res_multimasks.float()
|
| 256 |
+
high_res_multimasks = F.interpolate(
|
| 257 |
+
low_res_multimasks,
|
| 258 |
+
size=(self.image_size, self.image_size),
|
| 259 |
+
mode="bilinear",
|
| 260 |
+
align_corners=False,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
sam_output_token = sam_output_tokens[:, 0]
|
| 264 |
+
if multimask_output:
|
| 265 |
+
# take the best mask prediction (with the highest IoU estimation)
|
| 266 |
+
best_iou_inds = torch.argmax(ious, dim=-1)
|
| 267 |
+
batch_inds = torch.arange(B, device=device)
|
| 268 |
+
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 269 |
+
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 270 |
+
if sam_output_tokens.size(1) > 1:
|
| 271 |
+
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 272 |
+
else:
|
| 273 |
+
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
| 274 |
+
|
| 275 |
+
# Extract object pointer from the SAM output token (with occlusion handling)
|
| 276 |
+
obj_ptr = self.obj_ptr_proj(sam_output_token)
|
| 277 |
+
if self.pred_obj_scores:
|
| 278 |
+
# Allow *soft* no obj ptr, unlike for masks
|
| 279 |
+
if self.soft_no_obj_ptr:
|
| 280 |
+
# Only hard possible with gt
|
| 281 |
+
assert not self.teacher_force_obj_scores_for_mem
|
| 282 |
+
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
| 283 |
+
else:
|
| 284 |
+
lambda_is_obj_appearing = is_obj_appearing.float()
|
| 285 |
+
|
| 286 |
+
if self.fixed_no_obj_ptr:
|
| 287 |
+
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
| 288 |
+
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
| 289 |
+
|
| 290 |
+
return (
|
| 291 |
+
low_res_multimasks,
|
| 292 |
+
high_res_multimasks,
|
| 293 |
+
ious,
|
| 294 |
+
low_res_masks,
|
| 295 |
+
high_res_masks,
|
| 296 |
+
obj_ptr,
|
| 297 |
+
object_score_logits,
|
| 298 |
+
)
|
RynnEC/rynnec/model/loss.py
ADDED
|
@@ -0,0 +1,597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/magic-research/Sa2VA.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# coding=utf-8
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
linear_cross_entropy = None
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from rynnec.constants import IGNORE_INDEX
|
| 22 |
+
from torch import Tensor
|
| 23 |
+
import logging
|
| 24 |
+
from huggingface_hub import hf_hub_download
|
| 25 |
+
import functools
|
| 26 |
+
from typing import Callable, Optional
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def reduce_loss(loss: Tensor, reduction: str) -> Tensor:
|
| 30 |
+
"""Reduce loss as specified.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
loss (Tensor): Elementwise loss tensor.
|
| 34 |
+
reduction (str): Options are "none", "mean" and "sum".
|
| 35 |
+
|
| 36 |
+
Return:
|
| 37 |
+
Tensor: Reduced loss tensor.
|
| 38 |
+
"""
|
| 39 |
+
reduction_enum = F._Reduction.get_enum(reduction)
|
| 40 |
+
# none: 0, elementwise_mean:1, sum: 2
|
| 41 |
+
if reduction_enum == 0:
|
| 42 |
+
return loss
|
| 43 |
+
elif reduction_enum == 1:
|
| 44 |
+
return loss.mean()
|
| 45 |
+
elif reduction_enum == 2:
|
| 46 |
+
return loss.sum()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def weight_reduce_loss(loss: Tensor,
|
| 50 |
+
weight: Optional[Tensor] = None,
|
| 51 |
+
reduction: str = 'mean',
|
| 52 |
+
avg_factor: Optional[float] = None) -> Tensor:
|
| 53 |
+
"""Apply element-wise weight and reduce loss.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
loss (Tensor): Element-wise loss.
|
| 57 |
+
weight (Optional[Tensor], optional): Element-wise weights.
|
| 58 |
+
Defaults to None.
|
| 59 |
+
reduction (str, optional): Same as built-in losses of PyTorch.
|
| 60 |
+
Defaults to 'mean'.
|
| 61 |
+
avg_factor (Optional[float], optional): Average factor when
|
| 62 |
+
computing the mean of losses. Defaults to None.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Tensor: Processed loss values.
|
| 66 |
+
"""
|
| 67 |
+
# if weight is specified, apply element-wise weight
|
| 68 |
+
if weight is not None:
|
| 69 |
+
loss = loss * weight
|
| 70 |
+
|
| 71 |
+
# if avg_factor is not specified, just reduce the loss
|
| 72 |
+
if avg_factor is None:
|
| 73 |
+
loss = reduce_loss(loss, reduction)
|
| 74 |
+
else:
|
| 75 |
+
# if reduction is mean, then average the loss by avg_factor
|
| 76 |
+
if reduction == 'mean':
|
| 77 |
+
# Avoid causing ZeroDivisionError when avg_factor is 0.0,
|
| 78 |
+
# i.e., all labels of an image belong to ignore index.
|
| 79 |
+
eps = torch.finfo(torch.float32).eps
|
| 80 |
+
loss = loss.sum() / (avg_factor + eps)
|
| 81 |
+
# if reduction is 'none', then do nothing, otherwise raise an error
|
| 82 |
+
elif reduction != 'none':
|
| 83 |
+
raise ValueError('avg_factor can not be used with reduction="sum"')
|
| 84 |
+
return loss
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def dice_loss(pred,
|
| 88 |
+
target,
|
| 89 |
+
weight=None,
|
| 90 |
+
eps=1e-3,
|
| 91 |
+
reduction='mean',
|
| 92 |
+
naive_dice=False,
|
| 93 |
+
avg_factor=None):
|
| 94 |
+
"""Calculate dice loss, there are two forms of dice loss is supported:
|
| 95 |
+
|
| 96 |
+
- the one proposed in `V-Net: Fully Convolutional Neural
|
| 97 |
+
Networks for Volumetric Medical Image Segmentation
|
| 98 |
+
<https://arxiv.org/abs/1606.04797>`_.
|
| 99 |
+
- the dice loss in which the power of the number in the
|
| 100 |
+
denominator is the first power instead of the second
|
| 101 |
+
power.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
pred (torch.Tensor): The prediction, has a shape (n, *)
|
| 105 |
+
target (torch.Tensor): The learning label of the prediction,
|
| 106 |
+
shape (n, *), same shape of pred.
|
| 107 |
+
weight (torch.Tensor, optional): The weight of loss for each
|
| 108 |
+
prediction, has a shape (n,). Defaults to None.
|
| 109 |
+
eps (float): Avoid dividing by zero. Default: 1e-3.
|
| 110 |
+
reduction (str, optional): The method used to reduce the loss into
|
| 111 |
+
a scalar. Defaults to 'mean'.
|
| 112 |
+
Options are "none", "mean" and "sum".
|
| 113 |
+
naive_dice (bool, optional): If false, use the dice
|
| 114 |
+
loss defined in the V-Net paper, otherwise, use the
|
| 115 |
+
naive dice loss in which the power of the number in the
|
| 116 |
+
denominator is the first power instead of the second
|
| 117 |
+
power.Defaults to False.
|
| 118 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 119 |
+
the loss. Defaults to None.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
input = pred.flatten(1)
|
| 123 |
+
target = target.flatten(1).float()
|
| 124 |
+
|
| 125 |
+
a = torch.sum(input * target, 1)
|
| 126 |
+
if naive_dice:
|
| 127 |
+
b = torch.sum(input, 1)
|
| 128 |
+
c = torch.sum(target, 1)
|
| 129 |
+
d = (2 * a + eps) / (b + c + eps)
|
| 130 |
+
else:
|
| 131 |
+
b = torch.sum(input * input, 1) + eps
|
| 132 |
+
c = torch.sum(target * target, 1) + eps
|
| 133 |
+
d = (2 * a) / (b + c)
|
| 134 |
+
|
| 135 |
+
loss = 1 - d
|
| 136 |
+
if weight is not None:
|
| 137 |
+
assert weight.ndim == loss.ndim
|
| 138 |
+
assert len(weight) == len(pred)
|
| 139 |
+
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
|
| 140 |
+
return loss
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class DiceLoss(nn.Module):
|
| 145 |
+
|
| 146 |
+
def __init__(self,
|
| 147 |
+
use_sigmoid=True,
|
| 148 |
+
activate=True,
|
| 149 |
+
reduction='mean',
|
| 150 |
+
naive_dice=False,
|
| 151 |
+
loss_weight=1.0,
|
| 152 |
+
eps=1e-3):
|
| 153 |
+
"""Compute dice loss.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
use_sigmoid (bool, optional): Whether to the prediction is
|
| 157 |
+
used for sigmoid or softmax. Defaults to True.
|
| 158 |
+
activate (bool): Whether to activate the predictions inside,
|
| 159 |
+
this will disable the inside sigmoid operation.
|
| 160 |
+
Defaults to True.
|
| 161 |
+
reduction (str, optional): The method used
|
| 162 |
+
to reduce the loss. Options are "none",
|
| 163 |
+
"mean" and "sum". Defaults to 'mean'.
|
| 164 |
+
naive_dice (bool, optional): If false, use the dice
|
| 165 |
+
loss defined in the V-Net paper, otherwise, use the
|
| 166 |
+
naive dice loss in which the power of the number in the
|
| 167 |
+
denominator is the first power instead of the second
|
| 168 |
+
power. Defaults to False.
|
| 169 |
+
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
|
| 170 |
+
eps (float): Avoid dividing by zero. Defaults to 1e-3.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
super(DiceLoss, self).__init__()
|
| 174 |
+
self.use_sigmoid = use_sigmoid
|
| 175 |
+
self.reduction = reduction
|
| 176 |
+
self.naive_dice = naive_dice
|
| 177 |
+
self.loss_weight = loss_weight
|
| 178 |
+
self.eps = eps
|
| 179 |
+
self.activate = activate
|
| 180 |
+
|
| 181 |
+
def forward(self,
|
| 182 |
+
pred,
|
| 183 |
+
target,
|
| 184 |
+
weight=None,
|
| 185 |
+
reduction_override=None,
|
| 186 |
+
avg_factor=None):
|
| 187 |
+
"""Forward function.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
pred (torch.Tensor): The prediction, has a shape (n, *).
|
| 191 |
+
target (torch.Tensor): The label of the prediction,
|
| 192 |
+
shape (n, *), same shape of pred.
|
| 193 |
+
weight (torch.Tensor, optional): The weight of loss for each
|
| 194 |
+
prediction, has a shape (n,). Defaults to None.
|
| 195 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 196 |
+
the loss. Defaults to None.
|
| 197 |
+
reduction_override (str, optional): The reduction method used to
|
| 198 |
+
override the original reduction method of the loss.
|
| 199 |
+
Options are "none", "mean" and "sum".
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
torch.Tensor: The calculated loss
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
assert reduction_override in (None, 'none', 'mean', 'sum')
|
| 206 |
+
reduction = (
|
| 207 |
+
reduction_override if reduction_override else self.reduction)
|
| 208 |
+
|
| 209 |
+
if self.activate:
|
| 210 |
+
if self.use_sigmoid:
|
| 211 |
+
pred = pred.sigmoid()
|
| 212 |
+
else:
|
| 213 |
+
raise NotImplementedError
|
| 214 |
+
|
| 215 |
+
loss = self.loss_weight * dice_loss(
|
| 216 |
+
pred,
|
| 217 |
+
target,
|
| 218 |
+
weight,
|
| 219 |
+
eps=self.eps,
|
| 220 |
+
reduction=reduction,
|
| 221 |
+
naive_dice=self.naive_dice,
|
| 222 |
+
avg_factor=avg_factor)
|
| 223 |
+
|
| 224 |
+
return loss
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def cross_entropy_loss(
|
| 228 |
+
hidden_states,
|
| 229 |
+
lm_head,
|
| 230 |
+
position_ids,
|
| 231 |
+
labels,
|
| 232 |
+
reduction_scope="sequence",
|
| 233 |
+
**loss_kwargs
|
| 234 |
+
):
|
| 235 |
+
batch_size = hidden_states.size(0)
|
| 236 |
+
|
| 237 |
+
shift_hidden_states = hidden_states[..., :-1, :]
|
| 238 |
+
shift_labels = labels[..., 1:]
|
| 239 |
+
mask = shift_labels != IGNORE_INDEX
|
| 240 |
+
shift_hidden_states = shift_hidden_states[mask].contiguous()
|
| 241 |
+
shift_labels = shift_labels[mask].contiguous()
|
| 242 |
+
|
| 243 |
+
if mask.sum() == 0:
|
| 244 |
+
print(f"Get labels={labels}. Found no sample to calculate loss!")
|
| 245 |
+
pseudo_logits = lm_head(hidden_states[:, 0:1])
|
| 246 |
+
loss = 0.0 * pseudo_logits.mean()
|
| 247 |
+
return loss
|
| 248 |
+
|
| 249 |
+
if "num_items_in_batch" not in loss_kwargs:
|
| 250 |
+
reduction = "mean"
|
| 251 |
+
denominator = None
|
| 252 |
+
|
| 253 |
+
elif reduction_scope == "batch":
|
| 254 |
+
reduction = "sum"
|
| 255 |
+
denominator = loss_kwargs["num_items_in_batch"]
|
| 256 |
+
|
| 257 |
+
elif reduction_scope == "sequence":
|
| 258 |
+
reduction = "none"
|
| 259 |
+
|
| 260 |
+
if batch_size == 1:
|
| 261 |
+
# NOTE: packed sequence
|
| 262 |
+
start_indices = torch.nonzero(position_ids[0] == 0)[:, 0]
|
| 263 |
+
end_indices = F.pad(start_indices[1:], (0, 1), value=position_ids.size(1))
|
| 264 |
+
batch_indices = torch.cat(
|
| 265 |
+
[
|
| 266 |
+
torch.full((e - s,), fill_value=i, device=position_ids.device, dtype=torch.long)
|
| 267 |
+
for i, (s, e) in enumerate(zip(start_indices, end_indices))
|
| 268 |
+
],
|
| 269 |
+
).unsqueeze(0)
|
| 270 |
+
else:
|
| 271 |
+
batch_indices = torch.arange(batch_size, device=position_ids.device)
|
| 272 |
+
batch_indices = batch_indices.unsqueeze(1).expand(-1, hidden_states.size(1))
|
| 273 |
+
|
| 274 |
+
shift_batch_indices = batch_indices[..., :-1]
|
| 275 |
+
shift_batch_indices = shift_batch_indices[mask].contiguous()
|
| 276 |
+
num_tokens = F.one_hot(shift_batch_indices).sum(dim=0)
|
| 277 |
+
denominator = num_tokens[shift_batch_indices] * loss_kwargs["num_items_in_batch"]
|
| 278 |
+
|
| 279 |
+
else:
|
| 280 |
+
raise ValueError(f"Unknown reduction scope: {reduction_scope}")
|
| 281 |
+
|
| 282 |
+
if linear_cross_entropy is None:
|
| 283 |
+
shift_logits = lm_head(shift_hidden_states)
|
| 284 |
+
loss = torch.nn.functional.cross_entropy(
|
| 285 |
+
shift_logits,
|
| 286 |
+
shift_labels,
|
| 287 |
+
reduction=reduction,
|
| 288 |
+
)
|
| 289 |
+
else:
|
| 290 |
+
loss = linear_cross_entropy(
|
| 291 |
+
shift_hidden_states,
|
| 292 |
+
lm_head.weight,
|
| 293 |
+
shift_labels,
|
| 294 |
+
bias=lm_head.bias,
|
| 295 |
+
reduction=reduction,
|
| 296 |
+
accum_e_fp32=True,
|
| 297 |
+
accum_c_fp32=True,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
if denominator is not None:
|
| 301 |
+
loss = loss / denominator
|
| 302 |
+
if loss.ndim > 0:
|
| 303 |
+
loss = loss.sum()
|
| 304 |
+
|
| 305 |
+
return loss
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def cross_entropy(pred,
|
| 310 |
+
label,
|
| 311 |
+
weight=None,
|
| 312 |
+
reduction='mean',
|
| 313 |
+
avg_factor=None,
|
| 314 |
+
class_weight=None,
|
| 315 |
+
ignore_index=-100,
|
| 316 |
+
avg_non_ignore=False):
|
| 317 |
+
"""Calculate the CrossEntropy loss.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
| 321 |
+
of classes.
|
| 322 |
+
label (torch.Tensor): The learning label of the prediction.
|
| 323 |
+
weight (torch.Tensor, optional): Sample-wise loss weight.
|
| 324 |
+
reduction (str, optional): The method used to reduce the loss.
|
| 325 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 326 |
+
the loss. Defaults to None.
|
| 327 |
+
class_weight (list[float], optional): The weight for each class.
|
| 328 |
+
ignore_index (int | None): The label index to be ignored.
|
| 329 |
+
If None, it will be set to default value. Default: -100.
|
| 330 |
+
avg_non_ignore (bool): The flag decides to whether the loss is
|
| 331 |
+
only averaged over non-ignored targets. Default: False.
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
torch.Tensor: The calculated loss
|
| 335 |
+
"""
|
| 336 |
+
# The default value of ignore_index is the same as F.cross_entropy
|
| 337 |
+
ignore_index = -100 if ignore_index is None else ignore_index
|
| 338 |
+
# element-wise losses
|
| 339 |
+
loss = F.cross_entropy(
|
| 340 |
+
pred,
|
| 341 |
+
label,
|
| 342 |
+
weight=class_weight,
|
| 343 |
+
reduction='none',
|
| 344 |
+
ignore_index=ignore_index)
|
| 345 |
+
|
| 346 |
+
# average loss over non-ignored elements
|
| 347 |
+
# pytorch's official cross_entropy average loss over non-ignored elements
|
| 348 |
+
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
|
| 349 |
+
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
|
| 350 |
+
avg_factor = label.numel() - (label == ignore_index).sum().item()
|
| 351 |
+
|
| 352 |
+
# apply weights and do the reduction
|
| 353 |
+
if weight is not None:
|
| 354 |
+
weight = weight.float()
|
| 355 |
+
loss = weight_reduce_loss(
|
| 356 |
+
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
|
| 357 |
+
|
| 358 |
+
return loss
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
|
| 362 |
+
"""Expand onehot labels to match the size of prediction."""
|
| 363 |
+
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
|
| 364 |
+
valid_mask = (labels >= 0) & (labels != ignore_index)
|
| 365 |
+
inds = torch.nonzero(
|
| 366 |
+
valid_mask & (labels < label_channels), as_tuple=False)
|
| 367 |
+
|
| 368 |
+
if inds.numel() > 0:
|
| 369 |
+
bin_labels[inds, labels[inds]] = 1
|
| 370 |
+
|
| 371 |
+
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
|
| 372 |
+
label_channels).float()
|
| 373 |
+
if label_weights is None:
|
| 374 |
+
bin_label_weights = valid_mask
|
| 375 |
+
else:
|
| 376 |
+
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
|
| 377 |
+
bin_label_weights *= valid_mask
|
| 378 |
+
|
| 379 |
+
return bin_labels, bin_label_weights, valid_mask
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def binary_cross_entropy(pred,
|
| 383 |
+
label,
|
| 384 |
+
weight=None,
|
| 385 |
+
reduction='mean',
|
| 386 |
+
avg_factor=None,
|
| 387 |
+
class_weight=None,
|
| 388 |
+
ignore_index=-100,
|
| 389 |
+
avg_non_ignore=False):
|
| 390 |
+
"""Calculate the binary CrossEntropy loss.
|
| 391 |
+
|
| 392 |
+
Args:
|
| 393 |
+
pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
|
| 394 |
+
When the shape of pred is (N, 1), label will be expanded to
|
| 395 |
+
one-hot format, and when the shape of pred is (N, ), label
|
| 396 |
+
will not be expanded to one-hot format.
|
| 397 |
+
label (torch.Tensor): The learning label of the prediction,
|
| 398 |
+
with shape (N, ).
|
| 399 |
+
weight (torch.Tensor, optional): Sample-wise loss weight.
|
| 400 |
+
reduction (str, optional): The method used to reduce the loss.
|
| 401 |
+
Options are "none", "mean" and "sum".
|
| 402 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 403 |
+
the loss. Defaults to None.
|
| 404 |
+
class_weight (list[float], optional): The weight for each class.
|
| 405 |
+
ignore_index (int | None): The label index to be ignored.
|
| 406 |
+
If None, it will be set to default value. Default: -100.
|
| 407 |
+
avg_non_ignore (bool): The flag decides to whether the loss is
|
| 408 |
+
only averaged over non-ignored targets. Default: False.
|
| 409 |
+
|
| 410 |
+
Returns:
|
| 411 |
+
torch.Tensor: The calculated loss.
|
| 412 |
+
"""
|
| 413 |
+
# The default value of ignore_index is the same as F.cross_entropy
|
| 414 |
+
ignore_index = -100 if ignore_index is None else ignore_index
|
| 415 |
+
|
| 416 |
+
if pred.dim() != label.dim():
|
| 417 |
+
label, weight, valid_mask = _expand_onehot_labels(
|
| 418 |
+
label, weight, pred.size(-1), ignore_index)
|
| 419 |
+
else:
|
| 420 |
+
# should mask out the ignored elements
|
| 421 |
+
valid_mask = ((label >= 0) & (label != ignore_index)).float()
|
| 422 |
+
if weight is not None:
|
| 423 |
+
# The inplace writing method will have a mismatched broadcast
|
| 424 |
+
# shape error if the weight and valid_mask dimensions
|
| 425 |
+
# are inconsistent such as (B,N,1) and (B,N,C).
|
| 426 |
+
weight = weight * valid_mask
|
| 427 |
+
else:
|
| 428 |
+
weight = valid_mask
|
| 429 |
+
|
| 430 |
+
# average loss over non-ignored elements
|
| 431 |
+
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
|
| 432 |
+
avg_factor = valid_mask.sum().item()
|
| 433 |
+
|
| 434 |
+
# weighted element-wise losses
|
| 435 |
+
weight = weight.float()
|
| 436 |
+
loss = F.binary_cross_entropy_with_logits(
|
| 437 |
+
pred, label.float(), pos_weight=class_weight, reduction='none')
|
| 438 |
+
# do the reduction for the weighted loss
|
| 439 |
+
loss = weight_reduce_loss(
|
| 440 |
+
loss, weight, reduction=reduction, avg_factor=avg_factor)
|
| 441 |
+
|
| 442 |
+
return loss
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def mask_cross_entropy(pred,
|
| 446 |
+
target,
|
| 447 |
+
label,
|
| 448 |
+
reduction='mean',
|
| 449 |
+
avg_factor=None,
|
| 450 |
+
class_weight=None,
|
| 451 |
+
ignore_index=None,
|
| 452 |
+
**kwargs):
|
| 453 |
+
"""Calculate the CrossEntropy loss for masks.
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
pred (torch.Tensor): The prediction with shape (N, C, *), C is the
|
| 457 |
+
number of classes. The trailing * indicates arbitrary shape.
|
| 458 |
+
target (torch.Tensor): The learning label of the prediction.
|
| 459 |
+
label (torch.Tensor): ``label`` indicates the class label of the mask
|
| 460 |
+
corresponding object. This will be used to select the mask in the
|
| 461 |
+
of the class which the object belongs to when the mask prediction
|
| 462 |
+
if not class-agnostic.
|
| 463 |
+
reduction (str, optional): The method used to reduce the loss.
|
| 464 |
+
Options are "none", "mean" and "sum".
|
| 465 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 466 |
+
the loss. Defaults to None.
|
| 467 |
+
class_weight (list[float], optional): The weight for each class.
|
| 468 |
+
ignore_index (None): Placeholder, to be consistent with other loss.
|
| 469 |
+
Default: None.
|
| 470 |
+
|
| 471 |
+
Returns:
|
| 472 |
+
torch.Tensor: The calculated loss
|
| 473 |
+
|
| 474 |
+
Example:
|
| 475 |
+
>>> N, C = 3, 11
|
| 476 |
+
>>> H, W = 2, 2
|
| 477 |
+
>>> pred = torch.randn(N, C, H, W) * 1000
|
| 478 |
+
>>> target = torch.rand(N, H, W)
|
| 479 |
+
>>> label = torch.randint(0, C, size=(N,))
|
| 480 |
+
>>> reduction = 'mean'
|
| 481 |
+
>>> avg_factor = None
|
| 482 |
+
>>> class_weights = None
|
| 483 |
+
>>> loss = mask_cross_entropy(pred, target, label, reduction,
|
| 484 |
+
>>> avg_factor, class_weights)
|
| 485 |
+
>>> assert loss.shape == (1,)
|
| 486 |
+
"""
|
| 487 |
+
assert ignore_index is None, 'BCE loss does not support ignore_index'
|
| 488 |
+
# TODO: handle these two reserved arguments
|
| 489 |
+
assert reduction == 'mean' and avg_factor is None
|
| 490 |
+
num_rois = pred.size()[0]
|
| 491 |
+
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
|
| 492 |
+
pred_slice = pred[inds, label].squeeze(1)
|
| 493 |
+
return F.binary_cross_entropy_with_logits(
|
| 494 |
+
pred_slice, target, weight=class_weight, reduction='mean')[None]
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class CrossEntropyLoss(nn.Module):
|
| 498 |
+
|
| 499 |
+
def __init__(self,
|
| 500 |
+
use_sigmoid=False,
|
| 501 |
+
use_mask=False,
|
| 502 |
+
reduction='mean',
|
| 503 |
+
class_weight=None,
|
| 504 |
+
ignore_index=None,
|
| 505 |
+
loss_weight=1.0,
|
| 506 |
+
avg_non_ignore=False):
|
| 507 |
+
"""CrossEntropyLoss.
|
| 508 |
+
|
| 509 |
+
Args:
|
| 510 |
+
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
|
| 511 |
+
of softmax. Defaults to False.
|
| 512 |
+
use_mask (bool, optional): Whether to use mask cross entropy loss.
|
| 513 |
+
Defaults to False.
|
| 514 |
+
reduction (str, optional): . Defaults to 'mean'.
|
| 515 |
+
Options are "none", "mean" and "sum".
|
| 516 |
+
class_weight (list[float], optional): Weight of each class.
|
| 517 |
+
Defaults to None.
|
| 518 |
+
ignore_index (int | None): The label index to be ignored.
|
| 519 |
+
Defaults to None.
|
| 520 |
+
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
|
| 521 |
+
avg_non_ignore (bool): The flag decides to whether the loss is
|
| 522 |
+
only averaged over non-ignored targets. Default: False.
|
| 523 |
+
"""
|
| 524 |
+
super(CrossEntropyLoss, self).__init__()
|
| 525 |
+
assert (use_sigmoid is False) or (use_mask is False)
|
| 526 |
+
self.use_sigmoid = use_sigmoid
|
| 527 |
+
self.use_mask = use_mask
|
| 528 |
+
self.reduction = reduction
|
| 529 |
+
self.loss_weight = loss_weight
|
| 530 |
+
self.class_weight = class_weight
|
| 531 |
+
self.ignore_index = ignore_index
|
| 532 |
+
self.avg_non_ignore = avg_non_ignore
|
| 533 |
+
if ((ignore_index is not None) and not self.avg_non_ignore
|
| 534 |
+
and self.reduction == 'mean'):
|
| 535 |
+
warnings.warn(
|
| 536 |
+
'Default ``avg_non_ignore`` is False, if you would like to '
|
| 537 |
+
'ignore the certain label and average loss over non-ignore '
|
| 538 |
+
'labels, which is the same with PyTorch official '
|
| 539 |
+
'cross_entropy, set ``avg_non_ignore=True``.')
|
| 540 |
+
|
| 541 |
+
if self.use_sigmoid:
|
| 542 |
+
self.cls_criterion = binary_cross_entropy
|
| 543 |
+
elif self.use_mask:
|
| 544 |
+
self.cls_criterion = mask_cross_entropy
|
| 545 |
+
else:
|
| 546 |
+
self.cls_criterion = cross_entropy
|
| 547 |
+
|
| 548 |
+
def extra_repr(self):
|
| 549 |
+
"""Extra repr."""
|
| 550 |
+
s = f'avg_non_ignore={self.avg_non_ignore}'
|
| 551 |
+
return s
|
| 552 |
+
|
| 553 |
+
def forward(self,
|
| 554 |
+
cls_score,
|
| 555 |
+
label,
|
| 556 |
+
weight=None,
|
| 557 |
+
avg_factor=None,
|
| 558 |
+
reduction_override=None,
|
| 559 |
+
ignore_index=None,
|
| 560 |
+
**kwargs):
|
| 561 |
+
"""Forward function.
|
| 562 |
+
|
| 563 |
+
Args:
|
| 564 |
+
cls_score (torch.Tensor): The prediction.
|
| 565 |
+
label (torch.Tensor): The learning label of the prediction.
|
| 566 |
+
weight (torch.Tensor, optional): Sample-wise loss weight.
|
| 567 |
+
avg_factor (int, optional): Average factor that is used to average
|
| 568 |
+
the loss. Defaults to None.
|
| 569 |
+
reduction_override (str, optional): The method used to reduce the
|
| 570 |
+
loss. Options are "none", "mean" and "sum".
|
| 571 |
+
ignore_index (int | None): The label index to be ignored.
|
| 572 |
+
If not None, it will override the default value. Default: None.
|
| 573 |
+
Returns:
|
| 574 |
+
torch.Tensor: The calculated loss.
|
| 575 |
+
"""
|
| 576 |
+
assert reduction_override in (None, 'none', 'mean', 'sum')
|
| 577 |
+
reduction = (
|
| 578 |
+
reduction_override if reduction_override else self.reduction)
|
| 579 |
+
if ignore_index is None:
|
| 580 |
+
ignore_index = self.ignore_index
|
| 581 |
+
|
| 582 |
+
if self.class_weight is not None:
|
| 583 |
+
class_weight = cls_score.new_tensor(
|
| 584 |
+
self.class_weight, device=cls_score.device)
|
| 585 |
+
else:
|
| 586 |
+
class_weight = None
|
| 587 |
+
loss_cls = self.loss_weight * self.cls_criterion(
|
| 588 |
+
cls_score,
|
| 589 |
+
label,
|
| 590 |
+
weight,
|
| 591 |
+
class_weight=class_weight,
|
| 592 |
+
reduction=reduction,
|
| 593 |
+
avg_factor=avg_factor,
|
| 594 |
+
ignore_index=ignore_index,
|
| 595 |
+
avg_non_ignore=self.avg_non_ignore,
|
| 596 |
+
**kwargs)
|
| 597 |
+
return loss_cls
|
RynnEC/rynnec/model/predictor/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sam2_predictor import SAM2VideoPredictor
|
RynnEC/rynnec/model/predictor/sam2_predictor.py
ADDED
|
@@ -0,0 +1,724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/predictor/sam2_predictor.py.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# coding=utf-8
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
from ..extension import SAM2Base
|
| 23 |
+
from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
|
| 24 |
+
from third_parts.sam2.utils.misc import fill_holes_in_mask_scores
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _obj_id_to_idx(inference_state, obj_id):
|
| 28 |
+
"""Map client-side object id to model-side object index."""
|
| 29 |
+
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
| 30 |
+
if obj_idx is not None:
|
| 31 |
+
return obj_idx
|
| 32 |
+
|
| 33 |
+
# This is a new object id not sent to the server before. We only allow adding
|
| 34 |
+
# new objects *before* the tracking starts.
|
| 35 |
+
allow_new_object = not inference_state["tracking_has_started"]
|
| 36 |
+
if allow_new_object:
|
| 37 |
+
# get the next object slot
|
| 38 |
+
obj_idx = len(inference_state["obj_id_to_idx"])
|
| 39 |
+
inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
| 40 |
+
inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
| 41 |
+
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
|
| 42 |
+
# set up input and output structures for this object
|
| 43 |
+
inference_state["point_inputs_per_obj"][obj_idx] = {}
|
| 44 |
+
inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
| 45 |
+
inference_state["output_dict_per_obj"][obj_idx] = {
|
| 46 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 47 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 48 |
+
}
|
| 49 |
+
inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
| 50 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 51 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 52 |
+
}
|
| 53 |
+
return obj_idx
|
| 54 |
+
else:
|
| 55 |
+
raise RuntimeError(
|
| 56 |
+
f"Cannot add new object id {obj_id} after tracking starts. "
|
| 57 |
+
f"All existing object ids: {inference_state['obj_ids']}. "
|
| 58 |
+
f"Please call 'reset_state' to restart from scratch."
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _get_maskmem_pos_enc(inference_state, current_out):
|
| 63 |
+
"""
|
| 64 |
+
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
| 65 |
+
a constant in the inference session to reduce session storage size.
|
| 66 |
+
"""
|
| 67 |
+
model_constants = inference_state["constants"]
|
| 68 |
+
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
| 69 |
+
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
| 70 |
+
if out_maskmem_pos_enc is not None:
|
| 71 |
+
if "maskmem_pos_enc" not in model_constants:
|
| 72 |
+
assert isinstance(out_maskmem_pos_enc, list)
|
| 73 |
+
# only take the slice for one object, since it's same across objects
|
| 74 |
+
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
| 75 |
+
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
| 76 |
+
else:
|
| 77 |
+
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
| 78 |
+
# expand the cached maskmem_pos_enc to the actual batch size
|
| 79 |
+
batch_size = out_maskmem_pos_enc[0].size(0)
|
| 80 |
+
expanded_maskmem_pos_enc = [
|
| 81 |
+
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
|
| 82 |
+
]
|
| 83 |
+
else:
|
| 84 |
+
expanded_maskmem_pos_enc = None
|
| 85 |
+
return expanded_maskmem_pos_enc
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _obj_idx_to_id(inference_state, obj_idx):
|
| 89 |
+
"""Map model-side object index to client-side object id."""
|
| 90 |
+
return inference_state["obj_idx_to_id"][obj_idx]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _get_obj_num(inference_state):
|
| 94 |
+
"""Get the total number of unique object ids received so far in this session."""
|
| 95 |
+
return len(inference_state["obj_idx_to_id"])
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SAM2VideoPredictor(SAM2Base):
|
| 99 |
+
"""The predictor class to handle user interactions and manage inference states."""
|
| 100 |
+
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
fill_hole_area=0,
|
| 104 |
+
# whether to apply non-overlapping constraints on the output object masks
|
| 105 |
+
non_overlap_masks=False,
|
| 106 |
+
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
| 107 |
+
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
| 108 |
+
clear_non_cond_mem_around_input=False,
|
| 109 |
+
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
|
| 110 |
+
clear_non_cond_mem_for_multi_obj=False,
|
| 111 |
+
**kwargs,
|
| 112 |
+
):
|
| 113 |
+
super().__init__(**kwargs)
|
| 114 |
+
self.fill_hole_area = fill_hole_area
|
| 115 |
+
self.non_overlap_masks = non_overlap_masks
|
| 116 |
+
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
| 117 |
+
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
| 118 |
+
|
| 119 |
+
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
| 120 |
+
"""Compute the image features on a given frame."""
|
| 121 |
+
# Look up in the cache first
|
| 122 |
+
image, backbone_out = inference_state["cached_features"].get(
|
| 123 |
+
frame_idx, (None, None)
|
| 124 |
+
)
|
| 125 |
+
if backbone_out is None:
|
| 126 |
+
# Cache miss -- we will run inference on a single image
|
| 127 |
+
image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
|
| 128 |
+
backbone_out = self.forward_image(image)
|
| 129 |
+
# Cache the most recent frame's feature (for repeated interactions with
|
| 130 |
+
# a frame; we can use an LRU cache for more frames in the future).
|
| 131 |
+
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
|
| 132 |
+
|
| 133 |
+
# expand the features to have the same dimension as the number of objects
|
| 134 |
+
expanded_image = image.expand(batch_size, -1, -1, -1)
|
| 135 |
+
expanded_backbone_out = {
|
| 136 |
+
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
|
| 137 |
+
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
|
| 138 |
+
}
|
| 139 |
+
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
| 140 |
+
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
|
| 141 |
+
batch_size, -1, -1, -1
|
| 142 |
+
)
|
| 143 |
+
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
| 144 |
+
pos = pos.expand(batch_size, -1, -1, -1)
|
| 145 |
+
expanded_backbone_out["vision_pos_enc"][i] = pos
|
| 146 |
+
|
| 147 |
+
features = self._prepare_backbone_features(expanded_backbone_out)
|
| 148 |
+
features = (expanded_image,) + features
|
| 149 |
+
return features
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _run_single_frame_inference(
|
| 153 |
+
self,
|
| 154 |
+
inference_state,
|
| 155 |
+
output_dict,
|
| 156 |
+
frame_idx,
|
| 157 |
+
batch_size,
|
| 158 |
+
is_init_cond_frame,
|
| 159 |
+
point_inputs,
|
| 160 |
+
mask_inputs,
|
| 161 |
+
reverse,
|
| 162 |
+
run_mem_encoder,
|
| 163 |
+
prev_sam_mask_logits=None,
|
| 164 |
+
## Extension: LLM prompt
|
| 165 |
+
language_embd=None,
|
| 166 |
+
):
|
| 167 |
+
"""Run tracking on a single frame based on current inputs and previous memory."""
|
| 168 |
+
# Retrieve correct image features
|
| 169 |
+
(
|
| 170 |
+
_,
|
| 171 |
+
_,
|
| 172 |
+
current_vision_feats,
|
| 173 |
+
current_vision_pos_embeds,
|
| 174 |
+
feat_sizes,
|
| 175 |
+
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
| 176 |
+
|
| 177 |
+
# point and mask should not appear as input simultaneously on the same frame
|
| 178 |
+
assert point_inputs is None or mask_inputs is None
|
| 179 |
+
current_out = self.track_step(
|
| 180 |
+
frame_idx=frame_idx,
|
| 181 |
+
is_init_cond_frame=is_init_cond_frame,
|
| 182 |
+
current_vision_feats=current_vision_feats,
|
| 183 |
+
current_vision_pos_embeds=current_vision_pos_embeds,
|
| 184 |
+
feat_sizes=feat_sizes,
|
| 185 |
+
point_inputs=point_inputs,
|
| 186 |
+
mask_inputs=mask_inputs,
|
| 187 |
+
output_dict=output_dict,
|
| 188 |
+
num_frames=inference_state["num_frames"],
|
| 189 |
+
track_in_reverse=reverse,
|
| 190 |
+
run_mem_encoder=run_mem_encoder,
|
| 191 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
| 192 |
+
language_embd=language_embd,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# optionally offload the output to CPU memory to save GPU space
|
| 196 |
+
storage_device = inference_state["storage_device"]
|
| 197 |
+
maskmem_features = current_out["maskmem_features"]
|
| 198 |
+
if maskmem_features is not None:
|
| 199 |
+
maskmem_features = maskmem_features.to(torch.bfloat16)
|
| 200 |
+
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
| 201 |
+
pred_masks_gpu = current_out["pred_masks"]
|
| 202 |
+
# potentially fill holes in the predicted masks
|
| 203 |
+
if self.fill_hole_area > 0:
|
| 204 |
+
pred_masks_gpu = fill_holes_in_mask_scores(
|
| 205 |
+
pred_masks_gpu, self.fill_hole_area
|
| 206 |
+
)
|
| 207 |
+
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
| 208 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 209 |
+
maskmem_pos_enc = _get_maskmem_pos_enc(inference_state, current_out)
|
| 210 |
+
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
| 211 |
+
obj_ptr = current_out["obj_ptr"]
|
| 212 |
+
# make a compact version of this frame's output to reduce the state size
|
| 213 |
+
compact_current_out = {
|
| 214 |
+
"maskmem_features": maskmem_features,
|
| 215 |
+
"maskmem_pos_enc": maskmem_pos_enc,
|
| 216 |
+
"pred_masks": pred_masks,
|
| 217 |
+
"obj_ptr": obj_ptr,
|
| 218 |
+
}
|
| 219 |
+
return compact_current_out, pred_masks_gpu
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _consolidate_temp_output_across_obj(
|
| 223 |
+
self,
|
| 224 |
+
inference_state,
|
| 225 |
+
frame_idx,
|
| 226 |
+
is_cond,
|
| 227 |
+
run_mem_encoder,
|
| 228 |
+
consolidate_at_video_res=False,
|
| 229 |
+
):
|
| 230 |
+
"""
|
| 231 |
+
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
|
| 232 |
+
a frame into a single output for all objects, including
|
| 233 |
+
1) fill any missing objects either from `output_dict_per_obj` (if they exist in
|
| 234 |
+
`output_dict_per_obj` for this frame) or leave them as placeholder values
|
| 235 |
+
(if they don't exist in `output_dict_per_obj` for this frame);
|
| 236 |
+
2) if specified, rerun memory encoder after apply non-overlapping constraints
|
| 237 |
+
on the object scores.
|
| 238 |
+
"""
|
| 239 |
+
batch_size = _get_obj_num(inference_state)
|
| 240 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 241 |
+
# Optionally, we allow consolidating the temporary outputs at the original
|
| 242 |
+
# video resolution (to provide a better editing experience for mask prompts).
|
| 243 |
+
if consolidate_at_video_res:
|
| 244 |
+
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
|
| 245 |
+
consolidated_H = inference_state["video_height"]
|
| 246 |
+
consolidated_W = inference_state["video_width"]
|
| 247 |
+
consolidated_mask_key = "pred_masks_video_res"
|
| 248 |
+
else:
|
| 249 |
+
consolidated_H = consolidated_W = self.image_size // 4
|
| 250 |
+
consolidated_mask_key = "pred_masks"
|
| 251 |
+
|
| 252 |
+
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
|
| 253 |
+
# will be added when rerunning the memory encoder after applying non-overlapping
|
| 254 |
+
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
| 255 |
+
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
| 256 |
+
consolidated_out = {
|
| 257 |
+
"maskmem_features": None,
|
| 258 |
+
"maskmem_pos_enc": None,
|
| 259 |
+
consolidated_mask_key: torch.full(
|
| 260 |
+
size=(batch_size, 1, consolidated_H, consolidated_W),
|
| 261 |
+
fill_value=NO_OBJ_SCORE,
|
| 262 |
+
dtype=torch.float32,
|
| 263 |
+
device=inference_state["storage_device"],
|
| 264 |
+
),
|
| 265 |
+
"obj_ptr": torch.full(
|
| 266 |
+
size=(batch_size, self.hidden_dim),
|
| 267 |
+
fill_value=NO_OBJ_SCORE,
|
| 268 |
+
dtype=torch.float32,
|
| 269 |
+
device=inference_state["device"],
|
| 270 |
+
),
|
| 271 |
+
}
|
| 272 |
+
empty_mask_ptr = None
|
| 273 |
+
for obj_idx in range(batch_size):
|
| 274 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
| 275 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
| 276 |
+
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
|
| 277 |
+
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
|
| 278 |
+
# we fall back and look up its previous output in "output_dict_per_obj".
|
| 279 |
+
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
|
| 280 |
+
# "output_dict_per_obj" to find a previous output for this object.
|
| 281 |
+
if out is None:
|
| 282 |
+
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
|
| 283 |
+
if out is None:
|
| 284 |
+
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
|
| 285 |
+
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
|
| 286 |
+
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
| 287 |
+
# placeholder above) and set its object pointer to be a dummy pointer.
|
| 288 |
+
if out is None:
|
| 289 |
+
# Fill in dummy object pointers for those objects without any inputs or
|
| 290 |
+
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
| 291 |
+
# i.e. when we need to build the memory for tracking).
|
| 292 |
+
if run_mem_encoder:
|
| 293 |
+
if empty_mask_ptr is None:
|
| 294 |
+
empty_mask_ptr = self._get_empty_mask_ptr(
|
| 295 |
+
inference_state, frame_idx
|
| 296 |
+
)
|
| 297 |
+
# fill object pointer with a dummy pointer (based on an empty mask)
|
| 298 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
|
| 299 |
+
continue
|
| 300 |
+
# Add the temporary object output mask to consolidated output mask
|
| 301 |
+
obj_mask = out["pred_masks"]
|
| 302 |
+
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
|
| 303 |
+
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
|
| 304 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
|
| 305 |
+
else:
|
| 306 |
+
# Resize first if temporary object mask has a different resolution
|
| 307 |
+
resized_obj_mask = torch.nn.functional.interpolate(
|
| 308 |
+
obj_mask,
|
| 309 |
+
size=consolidated_pred_masks.shape[-2:],
|
| 310 |
+
mode="bilinear",
|
| 311 |
+
align_corners=False,
|
| 312 |
+
)
|
| 313 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
| 314 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
| 315 |
+
|
| 316 |
+
# Optionally, apply non-overlapping constraints on the consolidated scores
|
| 317 |
+
# and rerun the memory encoder
|
| 318 |
+
if run_mem_encoder:
|
| 319 |
+
device = inference_state["device"]
|
| 320 |
+
high_res_masks = torch.nn.functional.interpolate(
|
| 321 |
+
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
| 322 |
+
size=(self.image_size, self.image_size),
|
| 323 |
+
mode="bilinear",
|
| 324 |
+
align_corners=False,
|
| 325 |
+
)
|
| 326 |
+
if self.non_overlap_masks_for_mem_enc:
|
| 327 |
+
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
| 328 |
+
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
| 329 |
+
inference_state=inference_state,
|
| 330 |
+
frame_idx=frame_idx,
|
| 331 |
+
batch_size=batch_size,
|
| 332 |
+
high_res_masks=high_res_masks,
|
| 333 |
+
is_mask_from_pts=True, # these frames are what the user interacted with
|
| 334 |
+
)
|
| 335 |
+
consolidated_out["maskmem_features"] = maskmem_features
|
| 336 |
+
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
| 337 |
+
|
| 338 |
+
return consolidated_out
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _get_orig_video_res_output(self, inference_state, any_res_masks):
|
| 342 |
+
"""
|
| 343 |
+
Resize the object scores to the original video resolution (video_res_masks)
|
| 344 |
+
and apply non-overlapping constraints for final output.
|
| 345 |
+
"""
|
| 346 |
+
device = inference_state["device"]
|
| 347 |
+
video_H = inference_state["video_height"]
|
| 348 |
+
video_W = inference_state["video_width"]
|
| 349 |
+
any_res_masks = any_res_masks.to(device, non_blocking=True)
|
| 350 |
+
if any_res_masks.shape[-2:] == (video_H, video_W):
|
| 351 |
+
video_res_masks = any_res_masks
|
| 352 |
+
else:
|
| 353 |
+
video_res_masks = torch.nn.functional.interpolate(
|
| 354 |
+
any_res_masks,
|
| 355 |
+
size=(video_H, video_W),
|
| 356 |
+
mode="bilinear",
|
| 357 |
+
align_corners=False,
|
| 358 |
+
)
|
| 359 |
+
if self.non_overlap_masks:
|
| 360 |
+
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
|
| 361 |
+
return any_res_masks, video_res_masks
|
| 362 |
+
|
| 363 |
+
def init_state(
|
| 364 |
+
self,
|
| 365 |
+
images
|
| 366 |
+
):
|
| 367 |
+
"""Initialize a inference state."""
|
| 368 |
+
inference_state = {}
|
| 369 |
+
inference_state["images"] = images
|
| 370 |
+
inference_state["num_frames"] = len(images)
|
| 371 |
+
# whether to offload the video frames to CPU memory
|
| 372 |
+
# turning on this option saves the GPU memory with only a very small overhead
|
| 373 |
+
inference_state["offload_video_to_cpu"] = False
|
| 374 |
+
# whether to offload the inference state to CPU memory
|
| 375 |
+
# turning on this option saves the GPU memory at the cost of a lower tracking fps
|
| 376 |
+
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
|
| 377 |
+
# and from 24 to 21 when tracking two objects)
|
| 378 |
+
inference_state["offload_state_to_cpu"] = False
|
| 379 |
+
# the original video height and width, used for resizing final output scores
|
| 380 |
+
inference_state["video_height"] = self.image_size
|
| 381 |
+
inference_state["video_width"] = self.image_size
|
| 382 |
+
inference_state["device"] = torch.device("cuda")
|
| 383 |
+
inference_state["storage_device"] = torch.device("cuda")
|
| 384 |
+
# inputs on each frame
|
| 385 |
+
inference_state["point_inputs_per_obj"] = {}
|
| 386 |
+
inference_state["mask_inputs_per_obj"] = {}
|
| 387 |
+
# visual features on a small number of recently visited frames for quick interactions
|
| 388 |
+
inference_state["cached_features"] = {}
|
| 389 |
+
# values that don't change across frames (so we only need to hold one copy of them)
|
| 390 |
+
inference_state["constants"] = {}
|
| 391 |
+
# mapping between client-side object id and model-side object index
|
| 392 |
+
inference_state["obj_id_to_idx"] = OrderedDict()
|
| 393 |
+
inference_state["obj_idx_to_id"] = OrderedDict()
|
| 394 |
+
inference_state["obj_ids"] = []
|
| 395 |
+
# A storage to hold the model's tracking results and states on each frame
|
| 396 |
+
inference_state["output_dict"] = {
|
| 397 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 398 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 399 |
+
}
|
| 400 |
+
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
| 401 |
+
inference_state["output_dict_per_obj"] = {}
|
| 402 |
+
# A temporary storage to hold new outputs when user interact with a frame
|
| 403 |
+
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
| 404 |
+
inference_state["temp_output_dict_per_obj"] = {}
|
| 405 |
+
# Frames that already holds consolidated outputs from click or mask inputs
|
| 406 |
+
# (we directly use their consolidated outputs during tracking)
|
| 407 |
+
inference_state["consolidated_frame_inds"] = {
|
| 408 |
+
"cond_frame_outputs": set(), # set containing frame indices
|
| 409 |
+
"non_cond_frame_outputs": set(), # set containing frame indices
|
| 410 |
+
}
|
| 411 |
+
# metadata for each tracking frame (e.g. which direction it's tracked)
|
| 412 |
+
inference_state["tracking_has_started"] = False
|
| 413 |
+
inference_state["frames_already_tracked"] = {}
|
| 414 |
+
return inference_state
|
| 415 |
+
|
| 416 |
+
def add_language_embd(
|
| 417 |
+
self,
|
| 418 |
+
inference_state,
|
| 419 |
+
frame_idx,
|
| 420 |
+
obj_id,
|
| 421 |
+
language_embd,
|
| 422 |
+
inference=False,
|
| 423 |
+
):
|
| 424 |
+
obj_idx = _obj_id_to_idx(inference_state, obj_id)
|
| 425 |
+
|
| 426 |
+
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
| 427 |
+
# whether to track in reverse time order
|
| 428 |
+
if is_init_cond_frame:
|
| 429 |
+
reverse = False
|
| 430 |
+
else:
|
| 431 |
+
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
| 432 |
+
|
| 433 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
| 434 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
| 435 |
+
# Add a frame to conditioning output if it's an initial conditioning frame or
|
| 436 |
+
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
| 437 |
+
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
|
| 438 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 439 |
+
|
| 440 |
+
# Get any previously predicted mask logits on this object and feed it along with
|
| 441 |
+
# the new clicks into the SAM mask decoder.
|
| 442 |
+
prev_sam_mask_logits = None
|
| 443 |
+
# lookup temporary output dict first, which contains the most recent output
|
| 444 |
+
# (if not found, then lookup conditioning and non-conditioning frame output)
|
| 445 |
+
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
|
| 446 |
+
if prev_out is None:
|
| 447 |
+
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
| 448 |
+
if prev_out is None:
|
| 449 |
+
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
| 450 |
+
|
| 451 |
+
if prev_out is not None and prev_out["pred_masks"] is not None:
|
| 452 |
+
prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
|
| 453 |
+
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
| 454 |
+
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
| 455 |
+
|
| 456 |
+
current_out, pred_mask_gpu = self._run_single_frame_inference(
|
| 457 |
+
inference_state=inference_state,
|
| 458 |
+
output_dict=obj_output_dict, # run on the slice of a single object
|
| 459 |
+
frame_idx=frame_idx,
|
| 460 |
+
batch_size=1, # run on the slice of a single object
|
| 461 |
+
is_init_cond_frame=is_init_cond_frame,
|
| 462 |
+
point_inputs=None,
|
| 463 |
+
mask_inputs=None,
|
| 464 |
+
reverse=reverse,
|
| 465 |
+
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
| 466 |
+
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
|
| 467 |
+
# allows us to enforce non-overlapping constraints on all objects before encoding
|
| 468 |
+
# them into memory.
|
| 469 |
+
run_mem_encoder=False,
|
| 470 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
| 471 |
+
## Extension: LLM prompt
|
| 472 |
+
language_embd=language_embd,
|
| 473 |
+
)
|
| 474 |
+
# Add the output to the output dict (to be used as future memory)
|
| 475 |
+
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
| 476 |
+
|
| 477 |
+
# Resize the output mask to the original video resolution
|
| 478 |
+
obj_ids = inference_state["obj_ids"]
|
| 479 |
+
if inference:
|
| 480 |
+
_consolidated_out = self._consolidate_temp_output_across_obj(
|
| 481 |
+
inference_state,
|
| 482 |
+
frame_idx,
|
| 483 |
+
is_cond=is_cond,
|
| 484 |
+
run_mem_encoder=False,
|
| 485 |
+
consolidate_at_video_res=False,
|
| 486 |
+
)
|
| 487 |
+
# _, video_res_masks = self._get_orig_video_res_output(
|
| 488 |
+
# inference_state, consolidated_out["pred_masks_video_res"]
|
| 489 |
+
# )
|
| 490 |
+
return frame_idx, obj_ids, pred_mask_gpu
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
|
| 494 |
+
"""
|
| 495 |
+
Remove the non-conditioning memory around the input frame. When users provide
|
| 496 |
+
correction clicks, the surrounding frames' non-conditioning memories can still
|
| 497 |
+
contain outdated object appearance information and could confuse the model.
|
| 498 |
+
|
| 499 |
+
This method clears those non-conditioning memories surrounding the interacted
|
| 500 |
+
frame to avoid giving the model both old and new information about the object.
|
| 501 |
+
"""
|
| 502 |
+
r = self.memory_temporal_stride_for_eval
|
| 503 |
+
frame_idx_begin = frame_idx - r * self.num_maskmem
|
| 504 |
+
frame_idx_end = frame_idx + r * self.num_maskmem
|
| 505 |
+
output_dict = inference_state["output_dict"]
|
| 506 |
+
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
|
| 507 |
+
for t in range(frame_idx_begin, frame_idx_end + 1):
|
| 508 |
+
non_cond_frame_outputs.pop(t, None)
|
| 509 |
+
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
| 510 |
+
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|
| 511 |
+
|
| 512 |
+
def _run_memory_encoder(
|
| 513 |
+
self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
|
| 514 |
+
):
|
| 515 |
+
"""
|
| 516 |
+
Run the memory encoder on `high_res_masks`. This is usually after applying
|
| 517 |
+
non-overlapping constraints to object scores. Since their scores changed, their
|
| 518 |
+
memory also need to be computed again with the memory encoder.
|
| 519 |
+
"""
|
| 520 |
+
# Retrieve correct image features
|
| 521 |
+
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
|
| 522 |
+
inference_state, frame_idx, batch_size
|
| 523 |
+
)
|
| 524 |
+
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
| 525 |
+
current_vision_feats=current_vision_feats,
|
| 526 |
+
feat_sizes=feat_sizes,
|
| 527 |
+
pred_masks_high_res=high_res_masks,
|
| 528 |
+
is_mask_from_pts=is_mask_from_pts,
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
# optionally offload the output to CPU memory to save GPU space
|
| 532 |
+
storage_device = inference_state["storage_device"]
|
| 533 |
+
maskmem_features = maskmem_features.to(torch.bfloat16)
|
| 534 |
+
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
| 535 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
| 536 |
+
maskmem_pos_enc = _get_maskmem_pos_enc(
|
| 537 |
+
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
|
| 538 |
+
)
|
| 539 |
+
return maskmem_features, maskmem_pos_enc
|
| 540 |
+
|
| 541 |
+
def _add_output_per_object(
|
| 542 |
+
self, inference_state, frame_idx, current_out, storage_key
|
| 543 |
+
):
|
| 544 |
+
"""
|
| 545 |
+
Split a multi-object output into per-object output slices and add them into
|
| 546 |
+
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
| 547 |
+
"""
|
| 548 |
+
maskmem_features = current_out["maskmem_features"]
|
| 549 |
+
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
| 550 |
+
|
| 551 |
+
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
| 552 |
+
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
| 553 |
+
|
| 554 |
+
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
| 555 |
+
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
| 556 |
+
obj_slice = slice(obj_idx, obj_idx + 1)
|
| 557 |
+
obj_out = {
|
| 558 |
+
"maskmem_features": None,
|
| 559 |
+
"maskmem_pos_enc": None,
|
| 560 |
+
"pred_masks": current_out["pred_masks"][obj_slice],
|
| 561 |
+
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
| 562 |
+
}
|
| 563 |
+
if maskmem_features is not None:
|
| 564 |
+
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
| 565 |
+
if maskmem_pos_enc is not None:
|
| 566 |
+
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
| 567 |
+
obj_output_dict[storage_key][frame_idx] = obj_out
|
| 568 |
+
|
| 569 |
+
@torch.inference_mode()
|
| 570 |
+
def propagate_in_video_preflight(self, inference_state):
|
| 571 |
+
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
| 572 |
+
# Tracking has started and we don't allow adding new objects until session is reset.
|
| 573 |
+
inference_state["tracking_has_started"] = True
|
| 574 |
+
batch_size = _get_obj_num(inference_state)
|
| 575 |
+
|
| 576 |
+
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
| 577 |
+
# add them into "output_dict".
|
| 578 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
| 579 |
+
output_dict = inference_state["output_dict"]
|
| 580 |
+
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
| 581 |
+
# temporary outputs have been added (either in this call or any previous calls
|
| 582 |
+
# to `propagate_in_video_preflight`).
|
| 583 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
| 584 |
+
for is_cond in [False, True]:
|
| 585 |
+
# Separately consolidate conditioning and non-conditioning temp outptus
|
| 586 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 587 |
+
# Find all the frames that contain temporary outputs for any objects
|
| 588 |
+
# (these should be the frames that have just received clicks for mask inputs
|
| 589 |
+
# via `add_new_points` or `add_new_mask`)
|
| 590 |
+
temp_frame_inds = set()
|
| 591 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
| 592 |
+
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
| 593 |
+
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
| 594 |
+
# consolidate the temprary output across all objects on this frame
|
| 595 |
+
for frame_idx in temp_frame_inds:
|
| 596 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 597 |
+
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
| 598 |
+
)
|
| 599 |
+
# merge them into "output_dict" and also create per-object slices
|
| 600 |
+
output_dict[storage_key][frame_idx] = consolidated_out
|
| 601 |
+
self._add_output_per_object(
|
| 602 |
+
inference_state, frame_idx, consolidated_out, storage_key
|
| 603 |
+
)
|
| 604 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
| 605 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
| 606 |
+
)
|
| 607 |
+
if clear_non_cond_mem:
|
| 608 |
+
# clear non-conditioning memory of the surrounding frames
|
| 609 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
| 610 |
+
|
| 611 |
+
# clear temporary outputs in `temp_output_dict_per_obj`
|
| 612 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
| 613 |
+
obj_temp_output_dict[storage_key].clear()
|
| 614 |
+
|
| 615 |
+
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
| 616 |
+
# output on the same frame in "non_cond_frame_outputs"
|
| 617 |
+
for frame_idx in output_dict["cond_frame_outputs"]:
|
| 618 |
+
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
| 619 |
+
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
| 620 |
+
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
| 621 |
+
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
| 622 |
+
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
| 623 |
+
assert frame_idx in output_dict["cond_frame_outputs"]
|
| 624 |
+
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
| 625 |
+
|
| 626 |
+
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
| 627 |
+
# with either points or mask inputs (which should be true under a correct workflow).
|
| 628 |
+
all_consolidated_frame_inds = (
|
| 629 |
+
consolidated_frame_inds["cond_frame_outputs"]
|
| 630 |
+
| consolidated_frame_inds["non_cond_frame_outputs"]
|
| 631 |
+
)
|
| 632 |
+
input_frames_inds = set()
|
| 633 |
+
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
|
| 634 |
+
input_frames_inds.update(point_inputs_per_frame.keys())
|
| 635 |
+
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
|
| 636 |
+
input_frames_inds.update(mask_inputs_per_frame.keys())
|
| 637 |
+
|
| 638 |
+
# with language embd as input, there may not be point or box
|
| 639 |
+
# assert all_consolidated_frame_inds == input_frames_inds
|
| 640 |
+
|
| 641 |
+
@torch.inference_mode()
|
| 642 |
+
def propagate_in_video(
|
| 643 |
+
self,
|
| 644 |
+
inference_state,
|
| 645 |
+
start_frame_idx=None,
|
| 646 |
+
max_frame_num_to_track=None,
|
| 647 |
+
reverse=False,
|
| 648 |
+
):
|
| 649 |
+
"""Propagate the input points across frames to track in the entire video."""
|
| 650 |
+
self.propagate_in_video_preflight(inference_state)
|
| 651 |
+
|
| 652 |
+
output_dict = inference_state["output_dict"]
|
| 653 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
| 654 |
+
obj_ids = inference_state["obj_ids"]
|
| 655 |
+
num_frames = inference_state["num_frames"]
|
| 656 |
+
batch_size = _get_obj_num(inference_state)
|
| 657 |
+
if len(output_dict["cond_frame_outputs"]) == 0:
|
| 658 |
+
raise RuntimeError("No points are provided; please add points first")
|
| 659 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
| 660 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# set start index, end index, and processing order
|
| 664 |
+
if start_frame_idx is None:
|
| 665 |
+
# default: start from the earliest frame with input points
|
| 666 |
+
start_frame_idx = min(output_dict["cond_frame_outputs"])
|
| 667 |
+
if max_frame_num_to_track is None:
|
| 668 |
+
# default: track all the frames in the video
|
| 669 |
+
max_frame_num_to_track = num_frames
|
| 670 |
+
if reverse:
|
| 671 |
+
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
|
| 672 |
+
if start_frame_idx > 0:
|
| 673 |
+
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
|
| 674 |
+
else:
|
| 675 |
+
processing_order = [] # skip reverse tracking if starting from frame 0
|
| 676 |
+
else:
|
| 677 |
+
end_frame_idx = min(
|
| 678 |
+
start_frame_idx + max_frame_num_to_track, num_frames - 1
|
| 679 |
+
)
|
| 680 |
+
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
| 681 |
+
|
| 682 |
+
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
| 683 |
+
# We skip those frames already in consolidated outputs (these are frames
|
| 684 |
+
# that received input clicks or mask). Note that we cannot directly run
|
| 685 |
+
# batched forward on them via `_run_single_frame_inference` because the
|
| 686 |
+
# number of clicks on each object might be different.
|
| 687 |
+
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
| 688 |
+
storage_key = "cond_frame_outputs"
|
| 689 |
+
current_out = output_dict[storage_key][frame_idx]
|
| 690 |
+
pred_masks = current_out["pred_masks"]
|
| 691 |
+
if clear_non_cond_mem:
|
| 692 |
+
# clear non-conditioning memory of the surrounding frames
|
| 693 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
| 694 |
+
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
| 695 |
+
storage_key = "non_cond_frame_outputs"
|
| 696 |
+
current_out = output_dict[storage_key][frame_idx]
|
| 697 |
+
pred_masks = current_out["pred_masks"]
|
| 698 |
+
else:
|
| 699 |
+
storage_key = "non_cond_frame_outputs"
|
| 700 |
+
current_out, pred_masks = self._run_single_frame_inference(
|
| 701 |
+
inference_state=inference_state,
|
| 702 |
+
output_dict=output_dict,
|
| 703 |
+
frame_idx=frame_idx,
|
| 704 |
+
batch_size=batch_size,
|
| 705 |
+
is_init_cond_frame=False,
|
| 706 |
+
point_inputs=None,
|
| 707 |
+
mask_inputs=None,
|
| 708 |
+
reverse=reverse,
|
| 709 |
+
run_mem_encoder=True,
|
| 710 |
+
)
|
| 711 |
+
output_dict[storage_key][frame_idx] = current_out
|
| 712 |
+
# Create slices of per-object outputs for subsequent interaction with each
|
| 713 |
+
# individual object after tracking.
|
| 714 |
+
self._add_output_per_object(
|
| 715 |
+
inference_state, frame_idx, current_out, storage_key
|
| 716 |
+
)
|
| 717 |
+
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
|
| 718 |
+
|
| 719 |
+
# Resize the output mask to the original video resolution (we directly use
|
| 720 |
+
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
| 721 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
| 722 |
+
inference_state, pred_masks
|
| 723 |
+
)
|
| 724 |
+
yield frame_idx, obj_ids, video_res_masks
|
RynnEC/rynnec/model/processor.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 5 |
+
# and OPT implementations in this library. It has been modified from its
|
| 6 |
+
# original forms to accommodate minor architectural differences compared
|
| 7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
"""
|
| 21 |
+
Processor class for VideoLLaMA3.
|
| 22 |
+
"""
|
| 23 |
+
from abc import ABCMeta, abstractmethod
|
| 24 |
+
import copy
|
| 25 |
+
import warnings
|
| 26 |
+
from collections import defaultdict
|
| 27 |
+
from typing import List, Union, Dict, Optional, Any
|
| 28 |
+
|
| 29 |
+
import json
|
| 30 |
+
import torch
|
| 31 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 32 |
+
from transformers.image_utils import ImageInput, VideoInput
|
| 33 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
| 34 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 35 |
+
|
| 36 |
+
from rynnec.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
|
| 37 |
+
from rynnec.mm_utils import load_video, load_images
|
| 38 |
+
from rynnec.model.videollama3_encoder.image_processing_videollama3 import is_valid_image, is_valid_video
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Videollama3ProcessorKwargs(ProcessingKwargs, total=False):
|
| 42 |
+
_defaults = {
|
| 43 |
+
"text_kwargs": {
|
| 44 |
+
"padding": False,
|
| 45 |
+
},
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Videollama3BaseProcessor(ProcessorMixin, metaclass=ABCMeta):
|
| 50 |
+
r"""
|
| 51 |
+
Modified from Qwen2VLProcessor
|
| 52 |
+
Args:
|
| 53 |
+
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
| 54 |
+
The image processor is a required input.
|
| 55 |
+
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
| 56 |
+
The tokenizer is a required input.
|
| 57 |
+
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
| 58 |
+
in a chat into a tokenizable string.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
attributes = ["image_processor", "tokenizer"]
|
| 62 |
+
valid_kwargs = ["chat_template", "image_merge_size", "video_merge_size", "fps", "max_frames"]
|
| 63 |
+
image_processor_class = "AutoImageProcessor"
|
| 64 |
+
tokenizer_class = None
|
| 65 |
+
chat_template = None
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
image_processor=None,
|
| 70 |
+
tokenizer=None,
|
| 71 |
+
chat_template=None,
|
| 72 |
+
image_merge_size: int = 1,
|
| 73 |
+
video_merge_size: int = 2,
|
| 74 |
+
fps=1,
|
| 75 |
+
max_frames=180,
|
| 76 |
+
**kwargs
|
| 77 |
+
):
|
| 78 |
+
if chat_template is not None:
|
| 79 |
+
self.chat_template = chat_template
|
| 80 |
+
|
| 81 |
+
self.image_processor = image_processor
|
| 82 |
+
self.tokenizer = tokenizer
|
| 83 |
+
self.image_merge_size = image_merge_size
|
| 84 |
+
self.video_merge_size = video_merge_size
|
| 85 |
+
self.fps = fps
|
| 86 |
+
self.max_frames = max_frames
|
| 87 |
+
|
| 88 |
+
if self.chat_template is not None:
|
| 89 |
+
self.tokenizer.chat_template = self.chat_template
|
| 90 |
+
|
| 91 |
+
self.image_token = DEFAULT_IMAGE_TOKEN
|
| 92 |
+
self.think_start_token = "<think>"
|
| 93 |
+
self.think_end_token = "</think>"
|
| 94 |
+
self.tokenizer.add_tokens([self.image_token], special_tokens=True)
|
| 95 |
+
self.tokenizer.add_tokens([self.think_start_token, self.think_end_token], special_tokens=False)
|
| 96 |
+
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
|
| 97 |
+
self.think_start_token_id = self.tokenizer.convert_tokens_to_ids(self.think_start_token)
|
| 98 |
+
self.think_end_token_id = self.tokenizer.convert_tokens_to_ids(self.think_end_token)
|
| 99 |
+
self.newline_token_id = self.tokenizer.encode("\n")[0]
|
| 100 |
+
|
| 101 |
+
def load_video(self, *args, **kwargs):
|
| 102 |
+
return load_video(*args, **kwargs)
|
| 103 |
+
|
| 104 |
+
def load_images(self, *args, **kwargs):
|
| 105 |
+
return load_images(*args, **kwargs)
|
| 106 |
+
|
| 107 |
+
def _get_downsampled_grid_sizes(self, image_inputs: Dict[str, Any]):
|
| 108 |
+
grid_sizes = []
|
| 109 |
+
for grid_size, merge_size in zip(image_inputs.get("grid_sizes", []), image_inputs.get("merge_sizes", [])):
|
| 110 |
+
if not torch.all(grid_size[1:] % merge_size == 0):
|
| 111 |
+
warnings.warn(f"Grid size {grid_size} is not divisible by merge size. Some undesired errors may occur.")
|
| 112 |
+
if grid_size[0] == 1:
|
| 113 |
+
grid_sizes.append(grid_size[1:] / merge_size)
|
| 114 |
+
elif grid_size[0] > 1:
|
| 115 |
+
grid_sizes.extend([grid_size[1:] / merge_size] * grid_size[0])
|
| 116 |
+
return grid_sizes
|
| 117 |
+
|
| 118 |
+
def _get_visual_seq_len(self, grid_size: torch.Tensor):
|
| 119 |
+
num_tokens = int(grid_size.prod().item())
|
| 120 |
+
return num_tokens
|
| 121 |
+
|
| 122 |
+
@abstractmethod
|
| 123 |
+
def _process_text_with_label(
|
| 124 |
+
self,
|
| 125 |
+
text: List[Dict],
|
| 126 |
+
grid_sizes: torch.Tensor = None,
|
| 127 |
+
**kwargs,
|
| 128 |
+
):
|
| 129 |
+
return {}
|
| 130 |
+
|
| 131 |
+
def _process_text_without_label(
|
| 132 |
+
self,
|
| 133 |
+
text: Union[List[str], List[Dict]],
|
| 134 |
+
grid_sizes: torch.Tensor = None,
|
| 135 |
+
**kwargs,
|
| 136 |
+
):
|
| 137 |
+
if isinstance(text, (list, tuple)) and isinstance(text[0], dict):
|
| 138 |
+
warnings.warn("Input text is a list of messages. Automatically convert it to a string with 'apply_chat_template' with generation prompt.")
|
| 139 |
+
text = self.apply_chat_template(text, tokenize=False, add_generation_prompt=True)
|
| 140 |
+
|
| 141 |
+
if len(grid_sizes) > 0:
|
| 142 |
+
image_idx = 0
|
| 143 |
+
while self.image_token in text:
|
| 144 |
+
thw = grid_sizes[image_idx]
|
| 145 |
+
text = text.replace(self.image_token, "<placeholder>" * thw.prod().long(), 1)
|
| 146 |
+
image_idx += 1
|
| 147 |
+
text = text.replace("<placeholder>", self.image_token)
|
| 148 |
+
assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."
|
| 149 |
+
|
| 150 |
+
text_inputs = self.tokenizer(text, **kwargs)
|
| 151 |
+
return text_inputs
|
| 152 |
+
|
| 153 |
+
def process_text(
|
| 154 |
+
self,
|
| 155 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]],
|
| 156 |
+
image_inputs: Dict[str, torch.Tensor] = {},
|
| 157 |
+
return_labels: bool = False,
|
| 158 |
+
**kwargs,
|
| 159 |
+
):
|
| 160 |
+
kwargs.pop("padding", None)
|
| 161 |
+
kwargs.pop("padding_side", None)
|
| 162 |
+
|
| 163 |
+
grid_sizes = []
|
| 164 |
+
for grid_size, merge_size in zip(image_inputs.get("grid_sizes", []), image_inputs.get("merge_sizes", [])):
|
| 165 |
+
if not torch.all(grid_size[1:] % merge_size == 0):
|
| 166 |
+
warnings.warn(f"Grid size {grid_size} is not divisible by merge size. Some undesired errors may occur.")
|
| 167 |
+
if grid_size[0] == 1:
|
| 168 |
+
grid_sizes.append(grid_size[1:] / merge_size)
|
| 169 |
+
elif grid_size[0] > 1:
|
| 170 |
+
grid_sizes.extend([grid_size[1:] / merge_size] * grid_size[0])
|
| 171 |
+
|
| 172 |
+
if return_labels:
|
| 173 |
+
return self._process_text_with_label(text, grid_sizes, **kwargs)
|
| 174 |
+
return self._process_text_without_label(text, grid_sizes, **kwargs)
|
| 175 |
+
|
| 176 |
+
def process_images(
|
| 177 |
+
self,
|
| 178 |
+
images: ImageInput = None,
|
| 179 |
+
merge_size: Optional[int] = 1,
|
| 180 |
+
**kwargs,
|
| 181 |
+
):
|
| 182 |
+
if images is None:
|
| 183 |
+
return {}
|
| 184 |
+
image_inputs = self.image_processor(images=images, merge_size=merge_size, **kwargs)
|
| 185 |
+
return image_inputs
|
| 186 |
+
|
| 187 |
+
def __call__(
|
| 188 |
+
self,
|
| 189 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]] = None,
|
| 190 |
+
images: ImageInput = None,
|
| 191 |
+
merge_size: Optional[int] = 1,
|
| 192 |
+
return_labels: bool = False,
|
| 193 |
+
**kwargs: Unpack[Videollama3ProcessorKwargs],
|
| 194 |
+
) -> BatchFeature:
|
| 195 |
+
"""
|
| 196 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
| 197 |
+
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
| 198 |
+
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
|
| 199 |
+
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
| 203 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
| 204 |
+
tensor. Both channels-first and channels-last formats are supported.
|
| 205 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
| 206 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 207 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 208 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 209 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 210 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 211 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
| 212 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 213 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 214 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
| 218 |
+
|
| 219 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
| 220 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
| 221 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
| 222 |
+
`None`).
|
| 223 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
| 224 |
+
- **grid_sizes** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
| 225 |
+
"""
|
| 226 |
+
output_kwargs = self._merge_kwargs(
|
| 227 |
+
Videollama3ProcessorKwargs,
|
| 228 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 229 |
+
**kwargs,
|
| 230 |
+
)
|
| 231 |
+
output_kwargs["text_kwargs"].pop("padding", None)
|
| 232 |
+
output_kwargs["text_kwargs"].pop("padding_side", None)
|
| 233 |
+
|
| 234 |
+
image_inputs = self.process_images(images, merge_size, **output_kwargs["images_kwargs"])
|
| 235 |
+
text_inputs = self.process_text(text, image_inputs, return_labels, **output_kwargs["text_kwargs"])
|
| 236 |
+
|
| 237 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|
| 238 |
+
|
| 239 |
+
def batch_decode(self, *args, **kwargs):
|
| 240 |
+
"""
|
| 241 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 242 |
+
refer to the docstring of this method for more information.
|
| 243 |
+
"""
|
| 244 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 245 |
+
|
| 246 |
+
def decode(self, *args, **kwargs):
|
| 247 |
+
"""
|
| 248 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 249 |
+
the docstring of this method for more information.
|
| 250 |
+
"""
|
| 251 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 252 |
+
|
| 253 |
+
def _load_multimodal_data(self, conversation: List[Dict[str, Any]]):
|
| 254 |
+
multimodal_info = defaultdict(list)
|
| 255 |
+
new_conversation = []
|
| 256 |
+
for message in conversation:
|
| 257 |
+
new_message = {"role": message["role"]}
|
| 258 |
+
if not isinstance(message["content"], (list, tuple)):
|
| 259 |
+
new_message["content"] = message["content"]
|
| 260 |
+
new_conversation.append(new_message)
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
new_contents = []
|
| 264 |
+
for content in message["content"]:
|
| 265 |
+
if not isinstance(content, dict):
|
| 266 |
+
new_contents.append(content)
|
| 267 |
+
continue
|
| 268 |
+
assert "type" in content, "Content must have 'type' field."
|
| 269 |
+
if content["type"] in ["image", "video"] and content["type"] in content and isinstance(content[content["type"]], dict):
|
| 270 |
+
# TODO: support other types which are not compatible with json
|
| 271 |
+
load_args = content[content["type"]]
|
| 272 |
+
data_id = json.dumps({k: v for k, v in load_args.items() if not k in ["start_time", "end_time"]})
|
| 273 |
+
new_content = copy.deepcopy(content)
|
| 274 |
+
multimodal_info[data_id].append(new_content)
|
| 275 |
+
new_contents.append(new_content)
|
| 276 |
+
else:
|
| 277 |
+
new_contents.append(content)
|
| 278 |
+
|
| 279 |
+
new_message["content"] = new_contents
|
| 280 |
+
new_conversation.append(new_message)
|
| 281 |
+
|
| 282 |
+
for data_id, contents in multimodal_info.items():
|
| 283 |
+
data_type = contents[0]["type"]
|
| 284 |
+
if data_type == "image":
|
| 285 |
+
image = self.load_images(contents[0][data_type]["image_path"])[0]
|
| 286 |
+
for content in contents:
|
| 287 |
+
content["image"] = image.copy()
|
| 288 |
+
|
| 289 |
+
elif data_type == "video":
|
| 290 |
+
# TODO: start_time is None?
|
| 291 |
+
start_times = [content["video"].get("start_time", 0.) for content in contents]
|
| 292 |
+
end_times = [content["video"].get("end_time", float("inf")) for content in contents]
|
| 293 |
+
|
| 294 |
+
load_args = contents[0][data_type]
|
| 295 |
+
start_time, end_time = min(start_times), max(end_times)
|
| 296 |
+
if start_time > 0:
|
| 297 |
+
load_args["start_time"] = start_time
|
| 298 |
+
if end_time < float("inf"):
|
| 299 |
+
load_args["end_time"] = end_time
|
| 300 |
+
images, timestamps = self.load_video(**load_args)
|
| 301 |
+
|
| 302 |
+
for content, start_time, end_time in zip(contents, start_times, end_times):
|
| 303 |
+
cur_images, cur_timestamps = [], []
|
| 304 |
+
for image, timestamp in zip(images, timestamps):
|
| 305 |
+
if start_time <= timestamp <= end_time:
|
| 306 |
+
cur_images.append(image.copy())
|
| 307 |
+
cur_timestamps.append(timestamp)
|
| 308 |
+
|
| 309 |
+
content[data_type] = cur_images
|
| 310 |
+
content["num_frames"] = len(cur_images)
|
| 311 |
+
content["timestamps"] = cur_timestamps
|
| 312 |
+
|
| 313 |
+
return new_conversation
|
| 314 |
+
|
| 315 |
+
def _gather_multimodal_data(self, conversation: List[Dict[str, Any]]):
|
| 316 |
+
images = []
|
| 317 |
+
for message in conversation:
|
| 318 |
+
if not isinstance(message["content"], (list, tuple)):
|
| 319 |
+
continue
|
| 320 |
+
for content in message["content"]:
|
| 321 |
+
if not isinstance(content, dict):
|
| 322 |
+
continue
|
| 323 |
+
if content["type"] == "video":
|
| 324 |
+
video = content["video"]
|
| 325 |
+
assert is_valid_video(video), f"Invalid video data: {video}."
|
| 326 |
+
images.append(video)
|
| 327 |
+
if content["type"] == "image":
|
| 328 |
+
image = content["image"]
|
| 329 |
+
assert is_valid_image(image), f"Invalid image data: {image}."
|
| 330 |
+
images.append(image)
|
| 331 |
+
images = images if len(images) > 0 else None
|
| 332 |
+
return images
|
| 333 |
+
|
| 334 |
+
def apply_chat_template(
|
| 335 |
+
self,
|
| 336 |
+
conversation: List[Dict[str, Any]],
|
| 337 |
+
chat_template: Optional[str] = None,
|
| 338 |
+
tokenize: bool = False,
|
| 339 |
+
add_system_prompt: bool = False,
|
| 340 |
+
add_generation_prompt: bool = False,
|
| 341 |
+
add_think_prompt: bool = False,
|
| 342 |
+
return_dict: bool = False,
|
| 343 |
+
**kwargs,
|
| 344 |
+
) -> str:
|
| 345 |
+
"""
|
| 346 |
+
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
|
| 347 |
+
conversations to turn them into a single tokenizable string.
|
| 348 |
+
Args:
|
| 349 |
+
conversation (`List[Dict, str, str]`):
|
| 350 |
+
The conversation to format.
|
| 351 |
+
chat_template (`Optional[str]`, *optional*):
|
| 352 |
+
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
| 353 |
+
chat template is used.
|
| 354 |
+
tokenize (`bool`, *optional*, defaults to `False`):
|
| 355 |
+
Whether to tokenize the output or not.
|
| 356 |
+
add_system_prompt (`bool`, *optional*, defaults to `False`):
|
| 357 |
+
Whether to add the system prompt to the output or not.
|
| 358 |
+
add_generation_prompt (`bool`, *optional*, defaults to `False`):
|
| 359 |
+
Whether to add the generation prompt to the output or not.
|
| 360 |
+
image_token (`Optional[str]`, *optional*, defaults to `<image>`):
|
| 361 |
+
The token to use for indicating images in the conversation.
|
| 362 |
+
**kwargs:
|
| 363 |
+
Additional keyword arguments
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
if chat_template is None:
|
| 367 |
+
if self.chat_template is not None:
|
| 368 |
+
chat_template = self.chat_template
|
| 369 |
+
else:
|
| 370 |
+
raise ValueError(
|
| 371 |
+
"No chat template is set for this processor. Please either set the `chat_template` attribute, "
|
| 372 |
+
"or provide a chat template as an argument. See "
|
| 373 |
+
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
images = None
|
| 377 |
+
if return_dict:
|
| 378 |
+
conversation = self._load_multimodal_data(conversation)
|
| 379 |
+
images = self._gather_multimodal_data(conversation)
|
| 380 |
+
|
| 381 |
+
prompt = self.tokenizer.apply_chat_template(
|
| 382 |
+
conversation,
|
| 383 |
+
chat_template=chat_template,
|
| 384 |
+
tokenize=tokenize,
|
| 385 |
+
add_system_prompt=add_system_prompt,
|
| 386 |
+
add_generation_prompt=add_generation_prompt,
|
| 387 |
+
add_think_prompt=add_think_prompt,
|
| 388 |
+
image_token=self.image_token,
|
| 389 |
+
**kwargs
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
out = {"text": prompt, "images": images}
|
| 393 |
+
if return_dict:
|
| 394 |
+
return out
|
| 395 |
+
return out["text"]
|
| 396 |
+
|
| 397 |
+
@property
|
| 398 |
+
def model_input_names(self):
|
| 399 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 400 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 401 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
RynnEC/rynnec/model/projector.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Alibaba DAMO Academy
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
import einops
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from timm.models.layers import LayerNorm, LayerNorm2d
|
| 24 |
+
from timm.models.regnet import RegStage
|
| 25 |
+
from transformers import TRANSFORMERS_CACHE
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model"):
|
| 29 |
+
revision = "main"
|
| 30 |
+
# 1. parse the downloaded cache folder
|
| 31 |
+
if cache_dir is None:
|
| 32 |
+
cache_dir = TRANSFORMERS_CACHE
|
| 33 |
+
else:
|
| 34 |
+
cache_dir = cache_dir
|
| 35 |
+
object_id = repo_id.replace("/", "--")
|
| 36 |
+
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
|
| 37 |
+
# 2. resolve refs (for instance to convert main to the associated commit sha)
|
| 38 |
+
refs_dir = os.path.join(repo_cache, "refs")
|
| 39 |
+
if os.path.isdir(refs_dir):
|
| 40 |
+
revision_file = os.path.join(refs_dir, revision)
|
| 41 |
+
if os.path.isfile(revision_file):
|
| 42 |
+
with open(revision_file) as f:
|
| 43 |
+
revision = f.read()
|
| 44 |
+
# 3. acquire the snapshot folder
|
| 45 |
+
folder = os.path.join(repo_cache, "snapshots", revision)
|
| 46 |
+
|
| 47 |
+
return folder
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_mm_projector(model_path, cache_dir=None, token=None):
|
| 51 |
+
if os.path.exists(os.path.join(model_path, 'mm_projector.bin')):
|
| 52 |
+
is_local = True
|
| 53 |
+
folder = model_path
|
| 54 |
+
else:
|
| 55 |
+
is_local = False
|
| 56 |
+
folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
|
| 57 |
+
if not os.path.exists(os.path.join(folder, 'mm_projector.bin')):
|
| 58 |
+
# downloading from remote repo
|
| 59 |
+
from huggingface_hub import snapshot_download
|
| 60 |
+
snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)
|
| 61 |
+
|
| 62 |
+
mm_projector_weights = torch.load(os.path.join(folder, 'mm_projector.bin'), map_location='cpu')
|
| 63 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
| 64 |
+
return mm_projector_weights
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class IdentityMap(nn.Module):
|
| 68 |
+
|
| 69 |
+
def __init__(self):
|
| 70 |
+
super().__init__()
|
| 71 |
+
|
| 72 |
+
def forward(self, x, *args, **kwargs):
|
| 73 |
+
return x
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def config(self):
|
| 77 |
+
return {"mm_projector_type": 'identity'}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_mlp(depth, hidden_size, output_hidden_size):
|
| 81 |
+
modules = [nn.Linear(hidden_size, output_hidden_size)]
|
| 82 |
+
for _ in range(1, depth):
|
| 83 |
+
modules.append(nn.GELU())
|
| 84 |
+
modules.append(nn.Linear(output_hidden_size, output_hidden_size))
|
| 85 |
+
return nn.Sequential(*modules)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class SimSpatialConv(nn.Module):
|
| 89 |
+
|
| 90 |
+
def __init__(self, mm_hidden_size, hidden_size, downsample=(2, 2), padding=1, depth=1, mlp_depth=2):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.encoder_hidden_size = encoder_hidden_size = mm_hidden_size
|
| 93 |
+
self.output_hidden_size = output_hidden_size = hidden_size
|
| 94 |
+
self.downsample = downsample
|
| 95 |
+
self.padding = padding
|
| 96 |
+
self.sampler = nn.Sequential(
|
| 97 |
+
nn.Conv2d(
|
| 98 |
+
in_channels=self.encoder_hidden_size,
|
| 99 |
+
out_channels=4 * self.encoder_hidden_size,
|
| 100 |
+
kernel_size=self.downsample,
|
| 101 |
+
stride=self.downsample,
|
| 102 |
+
padding=self.padding,
|
| 103 |
+
bias=True
|
| 104 |
+
),
|
| 105 |
+
nn.SiLU(),
|
| 106 |
+
)
|
| 107 |
+
self.readout = build_mlp(mlp_depth, 4 * self.encoder_hidden_size, self.output_hidden_size)
|
| 108 |
+
|
| 109 |
+
def forward(self, x):
|
| 110 |
+
hw = int(x.size(1) ** 0.5)
|
| 111 |
+
x = einops.rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
|
| 112 |
+
x = self.sampler(x)
|
| 113 |
+
x = einops.rearrange(x, "b d h w -> b (h w) d")
|
| 114 |
+
x = self.readout(x)
|
| 115 |
+
return x
|
| 116 |
+
|
| 117 |
+
def cal_proj_size(self, input_size):
|
| 118 |
+
if isinstance(input_size, int):
|
| 119 |
+
input_size = (input_size, input_size)
|
| 120 |
+
height = math.ceil((input_size[0] + self.padding) / self.downsample[0])
|
| 121 |
+
width = math.ceil((input_size[1] + self.padding) / self.downsample[1])
|
| 122 |
+
return height * width
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class MlpGeluProjector(nn.Module):
|
| 126 |
+
def __init__(self, mm_hidden_size, hidden_size, projector_type):
|
| 127 |
+
super().__init__()
|
| 128 |
+
|
| 129 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
| 130 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
| 131 |
+
|
| 132 |
+
self.readout = build_mlp(mlp_depth, mm_hidden_size, hidden_size)
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
x = self.readout(x)
|
| 136 |
+
return x
|
| 137 |
+
|
| 138 |
+
def cal_proj_size(self, input_size):
|
| 139 |
+
if isinstance(input_size, int):
|
| 140 |
+
input_size = (input_size, input_size)
|
| 141 |
+
height = input_size[0]
|
| 142 |
+
width = input_size[1]
|
| 143 |
+
return height * width
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def build_vision_projector(config, mm_hidden_size, delay_load=False, **kwargs):
|
| 147 |
+
# rynnec projector only support image-wise operation now, i.e., prohibit the temporal aggregation
|
| 148 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
| 149 |
+
hidden_size = config.hidden_size
|
| 150 |
+
|
| 151 |
+
if projector_type == "linear":
|
| 152 |
+
# NOTE: for both linear and mlp2x_gelu projector type, mean pooling is adopted to aggreate video features
|
| 153 |
+
return nn.Linear(mm_hidden_size, hidden_size)
|
| 154 |
+
elif projector_type == "simp_spatial_conv":
|
| 155 |
+
return SimSpatialConv(mm_hidden_size, hidden_size)
|
| 156 |
+
elif projector_type.startswith("mlp"):
|
| 157 |
+
return MlpGeluProjector(mm_hidden_size, hidden_size, projector_type)
|
| 158 |
+
if projector_type == 'identity':
|
| 159 |
+
return IdentityMap()
|
| 160 |
+
|
| 161 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
RynnEC/rynnec/model/region_encoder.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
class MaskExtractor(nn.Module):
|
| 7 |
+
def __init__(self, config, mm_hidden_size, depth=2):
|
| 8 |
+
super(MaskExtractor, self).__init__()
|
| 9 |
+
self.mask_pooling = MaskPooling()
|
| 10 |
+
modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
|
| 11 |
+
for _ in range(1, depth):
|
| 12 |
+
modules.append(nn.GELU())
|
| 13 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 14 |
+
self.feat_linear = nn.Sequential(*modules)
|
| 15 |
+
|
| 16 |
+
def forward(self, feats, masks):
|
| 17 |
+
query_feats = []
|
| 18 |
+
|
| 19 |
+
if masks is None: #infer
|
| 20 |
+
return None
|
| 21 |
+
|
| 22 |
+
num_imgs = len(masks)
|
| 23 |
+
region_token_nums = []
|
| 24 |
+
image_idx = 0
|
| 25 |
+
for idx in range(num_imgs):
|
| 26 |
+
if masks[idx]==None:
|
| 27 |
+
continue
|
| 28 |
+
for mask_idx in range(len(masks[idx])):
|
| 29 |
+
mask = masks[idx][mask_idx].unsqueeze(0).unsqueeze(0).float()
|
| 30 |
+
if len(mask[0])==0:
|
| 31 |
+
mask = torch.zeros((1, 1, 336, 336)).to(feats.device).float()
|
| 32 |
+
|
| 33 |
+
feat = feats[image_idx].unsqueeze(0)
|
| 34 |
+
image_idx+=1
|
| 35 |
+
|
| 36 |
+
# h, w = feat.shape[1:3]
|
| 37 |
+
feat = feat.permute(0,3,1,2)
|
| 38 |
+
|
| 39 |
+
raw_dtype = feat.dtype
|
| 40 |
+
feat = feat.to(mask.dtype)
|
| 41 |
+
|
| 42 |
+
mask_feat_raw = self.mask_pooling(feat, mask) # [n, 1024]
|
| 43 |
+
|
| 44 |
+
query_feats.append(mask_feat_raw)
|
| 45 |
+
if len(query_feats)==0:
|
| 46 |
+
return None
|
| 47 |
+
mask_feats = torch.cat(query_feats, dim=0)
|
| 48 |
+
mask_feats = mask_feats.to(feats[0].dtype)
|
| 49 |
+
mask_feats_linear = self.feat_linear(mask_feats)
|
| 50 |
+
return mask_feats_linear
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MaskPooling(nn.Module):
|
| 54 |
+
def __init__(self):
|
| 55 |
+
super().__init__()
|
| 56 |
+
|
| 57 |
+
def forward(self, x, mask):
|
| 58 |
+
|
| 59 |
+
if not x.shape[-2:] == mask.shape[-2:]:
|
| 60 |
+
# reshape mask to x
|
| 61 |
+
mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
|
| 62 |
+
|
| 63 |
+
# b, c, h ,w = x.shape
|
| 64 |
+
# b, q, h, w = mask.shape
|
| 65 |
+
mask = (mask > 0).to(mask.dtype)
|
| 66 |
+
mask = mask.permute(1,0,2,3)
|
| 67 |
+
denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
|
| 68 |
+
|
| 69 |
+
mask_pooled_x = (x * mask/denorm).sum(-1).sum(-1)
|
| 70 |
+
|
| 71 |
+
return mask_pooled_x
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def build_region_encoder(config, mm_hidden_size):
|
| 75 |
+
|
| 76 |
+
return MaskExtractor(config, mm_hidden_size)
|
| 77 |
+
|
RynnEC/rynnec/model/rynnec_arch.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from: https://github.com/DAMO-NLP-SG/VideoLLaMA3.
|
| 2 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
| 3 |
+
# Copyright 2023 Haotian Liu
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import math
|
| 19 |
+
from abc import ABC, abstractmethod
|
| 20 |
+
from typing import List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import einops
|
| 23 |
+
import torch
|
| 24 |
+
import torch.distributed as dist
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
from ..constants import IGNORE_INDEX, MODAL_INDEX_MAP, NUM_FRAMES
|
| 29 |
+
from .encoder import build_vision_encoder
|
| 30 |
+
from .projector import build_vision_projector, load_mm_projector
|
| 31 |
+
from .region_encoder import build_region_encoder
|
| 32 |
+
from ..mm_utils import reshape_images_to_raw_grid
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def spatial_downsampling(features, grid_thws, stride=2):
|
| 36 |
+
n, c = features.shape
|
| 37 |
+
|
| 38 |
+
flatten_grid_thws = torch.cat([grid_thw for batch_grid_thws in grid_thws for grid_thw in batch_grid_thws])
|
| 39 |
+
split_sizes = [grid_thw.prod() for grid_thw in flatten_grid_thws]
|
| 40 |
+
features = torch.split(features, split_sizes)
|
| 41 |
+
|
| 42 |
+
new_features = []
|
| 43 |
+
for feature, grid_thw in zip(features, flatten_grid_thws):
|
| 44 |
+
# NOTE: adapted for reshape in image processor
|
| 45 |
+
feature = feature.view(grid_thw[0], grid_thw[1] // stride, grid_thw[2] // stride, stride, stride, c).permute(0, 1, 3, 2, 4, 5)
|
| 46 |
+
feature = feature.reshape(grid_thw[0], grid_thw[1], grid_thw[2], c).permute(0, 3, 1, 2)
|
| 47 |
+
# NOTE: previous version model is align_corners=True
|
| 48 |
+
new_feature = torch.nn.functional.interpolate(feature, (math.ceil(grid_thw[1] / stride), math.ceil(grid_thw[2] / stride)), mode='bilinear')
|
| 49 |
+
# new_feature = nn.functional.avg_pool2d(feature, stride)
|
| 50 |
+
# new_feature = nn.functional.max_pool2d(feature, stride)
|
| 51 |
+
new_features.append(new_feature.permute(0, 2, 3, 1).view(-1, c))
|
| 52 |
+
new_features = torch.cat(new_features)
|
| 53 |
+
|
| 54 |
+
return new_features
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RynnecMetaModel:
|
| 58 |
+
|
| 59 |
+
def __init__(self, config):
|
| 60 |
+
super(RynnecMetaModel, self).__init__(config)
|
| 61 |
+
|
| 62 |
+
if hasattr(config, "vision_encoder") or hasattr(config, "mm_vision_encoder"):
|
| 63 |
+
self.vision_encoder = build_vision_encoder(config, delay_load=False)
|
| 64 |
+
self.mm_projector = build_vision_projector(config, config.mm_hidden_size)
|
| 65 |
+
self.region_encoder = build_region_encoder(config, config.mm_hidden_size)
|
| 66 |
+
|
| 67 |
+
def get_vision_encoder(self):
|
| 68 |
+
vision_encoder = getattr(self, 'vision_encoder', None)
|
| 69 |
+
if type(vision_encoder) is list:
|
| 70 |
+
vision_encoder = vision_encoder[0]
|
| 71 |
+
return vision_encoder
|
| 72 |
+
|
| 73 |
+
def get_mm_projector(self):
|
| 74 |
+
return self.mm_projector
|
| 75 |
+
|
| 76 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
| 77 |
+
vision_encoder = model_args.vision_encoder
|
| 78 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
| 79 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
| 80 |
+
pretrain_mm_projector = model_args.pretrain_mm_projector
|
| 81 |
+
|
| 82 |
+
self.config.mm_vision_encoder = vision_encoder
|
| 83 |
+
|
| 84 |
+
if self.get_vision_encoder() is None:
|
| 85 |
+
vision_encoder = build_vision_encoder(model_args)
|
| 86 |
+
|
| 87 |
+
if fsdp is not None and len(fsdp) > 0:
|
| 88 |
+
self.vision_encoder = [vision_encoder]
|
| 89 |
+
else:
|
| 90 |
+
self.vision_encoder = vision_encoder
|
| 91 |
+
else:
|
| 92 |
+
if fsdp is not None and len(fsdp) > 0:
|
| 93 |
+
vision_encoder = self.vision_encoder[0]
|
| 94 |
+
else:
|
| 95 |
+
vision_encoder = self.vision_encoder
|
| 96 |
+
# NOTE: only compatible with delay_load encoder
|
| 97 |
+
# vision_encoder.load_model(vision_encoder.cfg_only)
|
| 98 |
+
|
| 99 |
+
self.config.use_mm_proj = True
|
| 100 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
| 101 |
+
self.config.mm_hidden_size = vision_encoder.hidden_size
|
| 102 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
| 103 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
| 104 |
+
|
| 105 |
+
if getattr(self, 'mm_projector', None) is None:
|
| 106 |
+
self.mm_projector = build_vision_projector(self.config)
|
| 107 |
+
else:
|
| 108 |
+
# In case it is frozen by LoRA
|
| 109 |
+
for p in self.mm_projector.parameters():
|
| 110 |
+
p.requires_grad = True
|
| 111 |
+
|
| 112 |
+
if pretrain_mm_projector is not None:
|
| 113 |
+
if os.path.exists(pretrain_mm_projector):
|
| 114 |
+
is_local = True
|
| 115 |
+
if os.path.isdir(pretrain_mm_projector):
|
| 116 |
+
mm_projector_weights = load_mm_projector(pretrain_mm_projector)
|
| 117 |
+
else:
|
| 118 |
+
mm_projector_weights = torch.load(pretrain_mm_projector, map_location='cpu')
|
| 119 |
+
else:
|
| 120 |
+
# Support loading projector weights from remote HuggingFace model hub
|
| 121 |
+
is_local = False
|
| 122 |
+
pretrain_mm_projector = pretrain_mm_projector.replace('mm_projector.bin', '')
|
| 123 |
+
pretrain_mm_projector = pretrain_mm_projector.strip('/').strip('\\').strip()
|
| 124 |
+
mm_projector_weights = load_mm_projector(pretrain_mm_projector)
|
| 125 |
+
|
| 126 |
+
def get_w(weights, keyword):
|
| 127 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
| 128 |
+
|
| 129 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class RynnecMetaForCausalLM(ABC):
|
| 133 |
+
|
| 134 |
+
@abstractmethod
|
| 135 |
+
def get_model(self):
|
| 136 |
+
pass
|
| 137 |
+
|
| 138 |
+
def num_frames(self):
|
| 139 |
+
if hasattr(self.config, 'num_frames'):
|
| 140 |
+
return self.config.num_frames
|
| 141 |
+
else:
|
| 142 |
+
return NUM_FRAMES
|
| 143 |
+
|
| 144 |
+
def spatial_merge_size(self):
|
| 145 |
+
if hasattr(self.config, 'spatial_merge_size'):
|
| 146 |
+
return self.config.spatial_merge_size
|
| 147 |
+
else:
|
| 148 |
+
return 1
|
| 149 |
+
|
| 150 |
+
def get_vision_encoder(self):
|
| 151 |
+
return self.get_model().get_vision_encoder()
|
| 152 |
+
|
| 153 |
+
def get_mm_projector(self):
|
| 154 |
+
return self.get_model().get_mm_projector()
|
| 155 |
+
|
| 156 |
+
def encode_images(
|
| 157 |
+
self,
|
| 158 |
+
pixel_values: torch.FloatTensor,
|
| 159 |
+
grid_sizes: torch.LongTensor,
|
| 160 |
+
merge_sizes: torch.LongTensor,
|
| 161 |
+
):
|
| 162 |
+
mm_features, mm_features_raw = self.get_model().get_vision_encoder()(
|
| 163 |
+
pixel_values=pixel_values,
|
| 164 |
+
grid_sizes=grid_sizes,
|
| 165 |
+
merge_sizes=merge_sizes,
|
| 166 |
+
)
|
| 167 |
+
mm_features = self.get_model().mm_projector(mm_features)
|
| 168 |
+
return mm_features, mm_features_raw
|
| 169 |
+
|
| 170 |
+
def _get_valid_visual_tokens(
|
| 171 |
+
self,
|
| 172 |
+
mm_features: torch.FloatTensor,
|
| 173 |
+
batched_num_patches: torch.LongTensor,
|
| 174 |
+
modals: List[str],
|
| 175 |
+
):
|
| 176 |
+
valid_masks = []
|
| 177 |
+
for num_patches, modal in zip(batched_num_patches, modals):
|
| 178 |
+
valid_mask = torch.full((num_patches, ), modal != "text", dtype=torch.bool, device=mm_features.device)
|
| 179 |
+
valid_masks.append(valid_mask)
|
| 180 |
+
mm_features = mm_features[torch.cat(valid_masks)]
|
| 181 |
+
return mm_features
|
| 182 |
+
|
| 183 |
+
def prepare_inputs_labels_for_multimodal(
|
| 184 |
+
self,
|
| 185 |
+
input_ids: torch.LongTensor = None,
|
| 186 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 187 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 188 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 189 |
+
labels: Optional[torch.LongTensor] = None,
|
| 190 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 191 |
+
grid_sizes: Optional[torch.LongTensor] = None,
|
| 192 |
+
merge_sizes: Optional[torch.LongTensor] = None,
|
| 193 |
+
modals: Optional[List[str]] = None,
|
| 194 |
+
masks=None,
|
| 195 |
+
mask_ids = None
|
| 196 |
+
):
|
| 197 |
+
vision_encoder = self.get_vision_encoder()
|
| 198 |
+
# NOTE: text-only situation
|
| 199 |
+
if vision_encoder is None or pixel_values is None or input_ids.shape[1] == 1:
|
| 200 |
+
return input_ids, attention_mask, position_ids, past_key_values, None, labels
|
| 201 |
+
|
| 202 |
+
# 1. flatten text inputs
|
| 203 |
+
B, N = input_ids.shape
|
| 204 |
+
input_ids = input_ids.view(B * N)
|
| 205 |
+
if attention_mask is not None:
|
| 206 |
+
attention_mask = attention_mask.view(B * N)
|
| 207 |
+
if position_ids is not None:
|
| 208 |
+
position_ids = position_ids.view(B * N)
|
| 209 |
+
if labels is not None:
|
| 210 |
+
labels = labels.view(B * N)
|
| 211 |
+
|
| 212 |
+
# 2. embed visual tokens
|
| 213 |
+
batched_num_patches = grid_sizes.prod(dim=1).div(merge_sizes ** 2).long()
|
| 214 |
+
|
| 215 |
+
mm_features, mm_features_raw = self.encode_images(pixel_values, grid_sizes, merge_sizes)
|
| 216 |
+
mm_features = mm_features.to(input_ids.device)
|
| 217 |
+
mm_features_raw = mm_features_raw.to(input_ids.device)
|
| 218 |
+
mm_features = self._get_valid_visual_tokens(mm_features, batched_num_patches, modals)
|
| 219 |
+
|
| 220 |
+
# 3. embed text tokens
|
| 221 |
+
image_selected = (input_ids == self.config.image_token_index)
|
| 222 |
+
# input_ids[image_selected] = 0
|
| 223 |
+
inputs_embeds = self.get_model().embed_tokens(input_ids).clone()
|
| 224 |
+
|
| 225 |
+
num_vision_tokens = image_selected.sum()
|
| 226 |
+
if mm_features.size(0) > num_vision_tokens:
|
| 227 |
+
print(f"Number of mm_features ({mm_features.size(0)}) exceeds the number of image tokens ({num_vision_tokens}). Automative truncated.")
|
| 228 |
+
mm_features = mm_features[:num_vision_tokens]
|
| 229 |
+
|
| 230 |
+
# 4. replace multimodal tokens with features
|
| 231 |
+
inputs_embeds[image_selected] = inputs_embeds[image_selected] * 0.0 + mm_features
|
| 232 |
+
|
| 233 |
+
# 5. embed region tokens
|
| 234 |
+
try:
|
| 235 |
+
|
| 236 |
+
mask_selected = (input_ids == self.config.region_token_index)
|
| 237 |
+
|
| 238 |
+
if mask_selected.sum() > 0:
|
| 239 |
+
reshaped_features = reshape_images_to_raw_grid(mm_features_raw, grid_sizes)
|
| 240 |
+
mask_additional_image_features = []
|
| 241 |
+
idx = 0
|
| 242 |
+
new_masks = []
|
| 243 |
+
for bs in range(len(masks)):
|
| 244 |
+
flag=True
|
| 245 |
+
for ml in range(len(masks[bs])):
|
| 246 |
+
if mask_ids[idx]>=0:
|
| 247 |
+
mask_additional_image_features.append(reshaped_features[mask_ids[idx]])
|
| 248 |
+
else:
|
| 249 |
+
flag=False
|
| 250 |
+
idx+=1
|
| 251 |
+
if flag:
|
| 252 |
+
new_masks.append(masks[bs])
|
| 253 |
+
|
| 254 |
+
mask_feats = self.get_model().region_encoder(mask_additional_image_features, new_masks)
|
| 255 |
+
inputs_embeds[mask_selected] = inputs_embeds[mask_selected]*0.0 + mask_feats
|
| 256 |
+
|
| 257 |
+
except Exception as e:
|
| 258 |
+
print(e)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# 6. reshape back to batched format
|
| 262 |
+
C = inputs_embeds.shape[-1]
|
| 263 |
+
inputs_embeds = inputs_embeds.reshape(B, -1, C)
|
| 264 |
+
if attention_mask is not None:
|
| 265 |
+
attention_mask = attention_mask.view(B, -1)
|
| 266 |
+
if labels is not None:
|
| 267 |
+
labels = labels.view(B, -1)
|
| 268 |
+
if position_ids is not None:
|
| 269 |
+
position_ids = position_ids.view(B, -1)
|
| 270 |
+
|
| 271 |
+
return None, attention_mask, position_ids, past_key_values, inputs_embeds, labels
|
RynnEC/rynnec/model/rynnec_qwen2.py
ADDED
|
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from: https://github.com/DAMO-NLP-SG/VideoLLaMA3.
|
| 2 |
+
# Adopted from: https://github.com/haotian-liu/LLaVA.
|
| 3 |
+
# Below is the original copyright:
|
| 4 |
+
# Copyright 2023 Haotian Liu
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from typing import List, Optional, Tuple, Union, Dict
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoImageProcessor,
|
| 25 |
+
Qwen2Config, Qwen2ForCausalLM, Qwen2Model)
|
| 26 |
+
from transformers.generation.utils import GenerateOutput
|
| 27 |
+
# from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from transformers.utils import ModelOutput
|
| 30 |
+
|
| 31 |
+
from .loss import cross_entropy_loss, CrossEntropyLoss, DiceLoss
|
| 32 |
+
from .processor import Videollama3BaseProcessor
|
| 33 |
+
from .rynnec_arch import RynnecMetaForCausalLM, RynnecMetaModel
|
| 34 |
+
from .videollama3_encoder import Videollama3ImageProcessor
|
| 35 |
+
from rynnec.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN
|
| 36 |
+
from .sam2_train import SAM2TrainRunner
|
| 37 |
+
from .sam2 import SAM2
|
| 38 |
+
from .utils import genetate_video_pred_embeddings, process_video_gt_masks
|
| 39 |
+
|
| 40 |
+
CHAT_TEMPLATE = """
|
| 41 |
+
{%- set identifier = 'im' %}
|
| 42 |
+
{% for message in messages %}
|
| 43 |
+
{% if message['role'] == 'stream' %}
|
| 44 |
+
{% set identifier = 'stream' %}
|
| 45 |
+
{% else %}
|
| 46 |
+
{% set identifier = 'im' %}
|
| 47 |
+
{% endif %}
|
| 48 |
+
{% if message['role'] is not none %}
|
| 49 |
+
{{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}
|
| 50 |
+
{% endif %}
|
| 51 |
+
{% if message['content'] is string %}
|
| 52 |
+
{{- message['content'] + '<|' + identifier + '_end|>\n' -}}
|
| 53 |
+
{% else %}
|
| 54 |
+
{% for content in message['content'] %}
|
| 55 |
+
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
|
| 56 |
+
{% if 'time' in content %}
|
| 57 |
+
{{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}
|
| 58 |
+
{% endif %}
|
| 59 |
+
{{- image_token + '\n' -}}
|
| 60 |
+
{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}
|
| 61 |
+
{% for i in range(content['num_frames']) %}
|
| 62 |
+
{% if 'timestamps' in content %}
|
| 63 |
+
{{- 'Time ' + content['timestamps'][i] | round(1) | string + 's:' -}}
|
| 64 |
+
{% endif %}
|
| 65 |
+
{% if i < content['num_frames'] - 1 %}
|
| 66 |
+
{{- image_token + ',' -}}
|
| 67 |
+
{% else %}
|
| 68 |
+
{{- image_token + '\n' -}}
|
| 69 |
+
{% endif %}
|
| 70 |
+
{% endfor %}
|
| 71 |
+
{% elif content['type'] == 'text' or 'text' in content %}
|
| 72 |
+
{{- content['text'] -}}
|
| 73 |
+
{% endif %}
|
| 74 |
+
{% endfor %}
|
| 75 |
+
{% if message['role'] is not none %}
|
| 76 |
+
{{- '<|' + identifier + '_end|>\n' -}}
|
| 77 |
+
{% endif %}
|
| 78 |
+
{% endif %}
|
| 79 |
+
{% endfor %}
|
| 80 |
+
{% if add_generation_prompt %}
|
| 81 |
+
{{- '<|im_start|>assistant\n' -}}
|
| 82 |
+
{% if add_think_prompt %}
|
| 83 |
+
{{- '<think>\n' -}}
|
| 84 |
+
{% endif %}
|
| 85 |
+
{% endif %}
|
| 86 |
+
"""
|
| 87 |
+
@dataclass
|
| 88 |
+
class CausalLMOutputWithPast(ModelOutput):
|
| 89 |
+
loss: Optional[torch.FloatTensor] = None
|
| 90 |
+
logits: torch.FloatTensor = None
|
| 91 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 92 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 93 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 94 |
+
rope_deltas: Optional[torch.LongTensor] = None
|
| 95 |
+
ce_loss: Optional[torch.FloatTensor] = None
|
| 96 |
+
mask_bce_loss: Optional[torch.FloatTensor] = None
|
| 97 |
+
mask_dice_loss: Optional[torch.FloatTensor] = None
|
| 98 |
+
mask_loss: Optional[torch.FloatTensor] = None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class Videollama3Qwen2Processor(Videollama3BaseProcessor):
|
| 102 |
+
|
| 103 |
+
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
| 104 |
+
chat_template = CHAT_TEMPLATE
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
image_processor=None,
|
| 109 |
+
tokenizer=None,
|
| 110 |
+
chat_template=None,
|
| 111 |
+
image_merge_size: int = 1,
|
| 112 |
+
video_merge_size: int = 2,
|
| 113 |
+
fps=1,
|
| 114 |
+
max_frames=180,
|
| 115 |
+
**kwargs
|
| 116 |
+
):
|
| 117 |
+
super().__init__(image_processor, tokenizer, chat_template, **kwargs)
|
| 118 |
+
self.generation_prompt = self._infer_generation_prompt()
|
| 119 |
+
self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
|
| 120 |
+
self.generation_prompt_length = len(self.generation_prompt_ids[0])
|
| 121 |
+
|
| 122 |
+
def _infer_generation_prompt(self):
|
| 123 |
+
pseudo_message = [{"role": "user", "content": ""}]
|
| 124 |
+
instruction = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
|
| 125 |
+
conversation = self.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
|
| 126 |
+
return instruction.replace(conversation, "")
|
| 127 |
+
|
| 128 |
+
def _process_text_with_label(
|
| 129 |
+
self,
|
| 130 |
+
text: List[Dict],
|
| 131 |
+
grid_sizes: torch.Tensor = None,
|
| 132 |
+
**kwargs,
|
| 133 |
+
):
|
| 134 |
+
assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
|
| 135 |
+
assert isinstance(text[0], dict), "When return_labels=True, text must be a list of messages."
|
| 136 |
+
|
| 137 |
+
input_ids_list = []
|
| 138 |
+
targets_list = []
|
| 139 |
+
image_idx = 0
|
| 140 |
+
|
| 141 |
+
for message_idx, message in enumerate(text):
|
| 142 |
+
# 1. set chat template and append image tokens
|
| 143 |
+
prompt = self.apply_chat_template([message], tokenize=False, add_generation_prompt=False)
|
| 144 |
+
prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
|
| 145 |
+
prompt = []
|
| 146 |
+
for chunk_idx in range(len(prompt_chunks) - 1):
|
| 147 |
+
prompt.append(prompt_chunks[chunk_idx])
|
| 148 |
+
thw = grid_sizes[image_idx]
|
| 149 |
+
prompt.append(DEFAULT_IMAGE_TOKEN * thw.prod().long())
|
| 150 |
+
image_idx += 1
|
| 151 |
+
prompt.append(prompt_chunks[-1])
|
| 152 |
+
prompt = "".join(prompt)
|
| 153 |
+
|
| 154 |
+
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")[0]
|
| 155 |
+
input_ids_list.append(input_ids)
|
| 156 |
+
|
| 157 |
+
targets = torch.full_like(input_ids, IGNORE_INDEX)
|
| 158 |
+
if message["role"] == "assistant" or message["role"] is None:
|
| 159 |
+
targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()
|
| 160 |
+
|
| 161 |
+
# NOTE: mask out image tokens
|
| 162 |
+
vision_mask = input_ids == self.image_token_id
|
| 163 |
+
targets[vision_mask] = IGNORE_INDEX
|
| 164 |
+
vision_indices = torch.nonzero(vision_mask, as_tuple=True)[0]
|
| 165 |
+
targets[vision_indices + 1] = IGNORE_INDEX
|
| 166 |
+
|
| 167 |
+
# NOTE: mask out <think> or <think>\n
|
| 168 |
+
think_mask = targets == self.think_start_token_id
|
| 169 |
+
targets[think_mask] = IGNORE_INDEX
|
| 170 |
+
think_indices = torch.nonzero(think_mask, as_tuple=True)[0]
|
| 171 |
+
newline_mask = torch.zeros_like(think_mask)
|
| 172 |
+
newline_mask[think_indices + 1] = targets[think_indices + 1] == self.newline_token_id
|
| 173 |
+
targets[newline_mask] = IGNORE_INDEX
|
| 174 |
+
|
| 175 |
+
targets_list.append(targets)
|
| 176 |
+
|
| 177 |
+
assert len(grid_sizes) == image_idx, "Number of images does not match the number of image tokens in the text."
|
| 178 |
+
|
| 179 |
+
text_inputs = {
|
| 180 |
+
"input_ids": torch.cat(input_ids_list),
|
| 181 |
+
"labels": torch.cat(targets_list),
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
return text_inputs
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class RynnecQwen2Config(Qwen2Config):
|
| 188 |
+
model_type = "rynnec_qwen2"
|
| 189 |
+
|
| 190 |
+
def __init__(self, **kwargs):
|
| 191 |
+
super().__init__(**kwargs)
|
| 192 |
+
self.model_type = "rynnec_qwen2"
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class RynnecQwen2Model(RynnecMetaModel, Qwen2Model):
|
| 196 |
+
config_class = RynnecQwen2Config
|
| 197 |
+
|
| 198 |
+
def __init__(self, config: RynnecQwen2Config):
|
| 199 |
+
super(RynnecQwen2Model, self).__init__(config)
|
| 200 |
+
|
| 201 |
+
if hasattr(config, "mm_mask_decoder"): # inference
|
| 202 |
+
self.build_mask_decoder(config)
|
| 203 |
+
else: # training
|
| 204 |
+
if 'out_dim' not in config:
|
| 205 |
+
config.out_dim = 256
|
| 206 |
+
|
| 207 |
+
def build_mask_decoder(self, config):
|
| 208 |
+
|
| 209 |
+
# Projection layer for lisa
|
| 210 |
+
in_dim = config.hidden_size
|
| 211 |
+
out_dim = config.out_dim
|
| 212 |
+
text_fc = [
|
| 213 |
+
nn.Linear(in_dim, in_dim),
|
| 214 |
+
nn.ReLU(inplace=True),
|
| 215 |
+
nn.Linear(in_dim, out_dim),
|
| 216 |
+
nn.Dropout(0.0),
|
| 217 |
+
]
|
| 218 |
+
self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
|
| 219 |
+
self.text_hidden_fcs.train()
|
| 220 |
+
for param in self.text_hidden_fcs.parameters():
|
| 221 |
+
param.requires_grad = True
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class RynnecQwen2ForCausalLM(Qwen2ForCausalLM, RynnecMetaForCausalLM):
|
| 225 |
+
config_class = RynnecQwen2Config
|
| 226 |
+
|
| 227 |
+
def __init__(self, config, **kwargs):
|
| 228 |
+
super(Qwen2ForCausalLM, self).__init__(config)
|
| 229 |
+
self.model = RynnecQwen2Model(config)
|
| 230 |
+
self.vocab_size = config.vocab_size
|
| 231 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 232 |
+
|
| 233 |
+
# Initialize weights and apply final processing
|
| 234 |
+
self.post_init()
|
| 235 |
+
|
| 236 |
+
if hasattr(config, "training") and config.training is True:
|
| 237 |
+
self.grounding_encoder = SAM2TrainRunner(ckpt_path=config.mask_decoder_model)
|
| 238 |
+
config.mm_mask_decoder = True
|
| 239 |
+
else:
|
| 240 |
+
self.grounding_encoder = SAM2(ckpt_path=config.mask_decoder_model)
|
| 241 |
+
|
| 242 |
+
self.loss_mask = CrossEntropyLoss(
|
| 243 |
+
use_sigmoid=True,
|
| 244 |
+
reduction='mean',
|
| 245 |
+
loss_weight=2.0
|
| 246 |
+
)
|
| 247 |
+
self.loss_dice = DiceLoss(
|
| 248 |
+
use_sigmoid=True,
|
| 249 |
+
activate=True,
|
| 250 |
+
reduction='mean',
|
| 251 |
+
naive_dice=True,
|
| 252 |
+
eps=1.0,
|
| 253 |
+
loss_weight=0.5
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
def load_sam2_weights(self, model_path):
|
| 257 |
+
sam2_model = torch.load(model_path, map_location='cpu')['model']
|
| 258 |
+
prefix = "sam2_model."
|
| 259 |
+
new_state_dict = {}
|
| 260 |
+
for param_name in sam2_model.keys():
|
| 261 |
+
new_param_name = prefix + param_name
|
| 262 |
+
new_state_dict[new_param_name] = sam2_model[param_name]
|
| 263 |
+
|
| 264 |
+
self.grounding_encoder.load_state_dict(new_state_dict, strict=False)
|
| 265 |
+
|
| 266 |
+
def get_model(self):
|
| 267 |
+
return self.model
|
| 268 |
+
# NOTE: arguments are copied from transformers==4.46.3
|
| 269 |
+
def forward(
|
| 270 |
+
self,
|
| 271 |
+
input_ids: torch.LongTensor = None,
|
| 272 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 273 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 274 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 275 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 276 |
+
labels: Optional[torch.LongTensor] = None,
|
| 277 |
+
use_cache: Optional[bool] = None,
|
| 278 |
+
output_attentions: Optional[bool] = None,
|
| 279 |
+
output_hidden_states: Optional[bool] = None,
|
| 280 |
+
return_dict: Optional[bool] = None,
|
| 281 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 282 |
+
num_logits_to_keep: int = 0,
|
| 283 |
+
# multimodal inputs
|
| 284 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 285 |
+
grid_sizes: Optional[torch.LongTensor] = None,
|
| 286 |
+
merge_sizes: Optional[torch.LongTensor] = None,
|
| 287 |
+
modals: Optional[List[str]] = None,
|
| 288 |
+
masks: Optional[List[torch.LongTensor]] = None,
|
| 289 |
+
mask_ids = None,
|
| 290 |
+
sam_images = None,
|
| 291 |
+
sam_size = None,
|
| 292 |
+
image2maskids = None,
|
| 293 |
+
**loss_kwargs,
|
| 294 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 295 |
+
torch.cuda.empty_cache()
|
| 296 |
+
if inputs_embeds is None:
|
| 297 |
+
input_ids_raw = input_ids.clone()
|
| 298 |
+
(
|
| 299 |
+
input_ids,
|
| 300 |
+
attention_mask,
|
| 301 |
+
position_ids,
|
| 302 |
+
past_key_values,
|
| 303 |
+
inputs_embeds,
|
| 304 |
+
labels,
|
| 305 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 306 |
+
input_ids=input_ids,
|
| 307 |
+
attention_mask=attention_mask,
|
| 308 |
+
position_ids=position_ids,
|
| 309 |
+
past_key_values=past_key_values,
|
| 310 |
+
labels=labels,
|
| 311 |
+
pixel_values=pixel_values,
|
| 312 |
+
grid_sizes=grid_sizes,
|
| 313 |
+
merge_sizes=merge_sizes,
|
| 314 |
+
modals=modals,
|
| 315 |
+
masks=masks,
|
| 316 |
+
mask_ids=mask_ids
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 320 |
+
output_hidden_states = (
|
| 321 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 322 |
+
)
|
| 323 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 324 |
+
|
| 325 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 326 |
+
outputs = self.model(
|
| 327 |
+
input_ids=input_ids,
|
| 328 |
+
attention_mask=attention_mask,
|
| 329 |
+
position_ids=position_ids,
|
| 330 |
+
past_key_values=past_key_values,
|
| 331 |
+
inputs_embeds=inputs_embeds,
|
| 332 |
+
use_cache=use_cache,
|
| 333 |
+
output_attentions=output_attentions,
|
| 334 |
+
output_hidden_states=output_hidden_states,
|
| 335 |
+
return_dict=return_dict,
|
| 336 |
+
cache_position=cache_position,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
hidden_states = outputs[0]
|
| 340 |
+
loss, logits = None, None
|
| 341 |
+
_valid = True
|
| 342 |
+
seg_valid = True
|
| 343 |
+
|
| 344 |
+
if labels is not None: #training
|
| 345 |
+
|
| 346 |
+
ce_loss = cross_entropy_loss(
|
| 347 |
+
hidden_states=hidden_states,
|
| 348 |
+
lm_head=self.lm_head,
|
| 349 |
+
position_ids=position_ids,
|
| 350 |
+
labels=labels,
|
| 351 |
+
reduction_scope=self.config.loss_reduction_scope,
|
| 352 |
+
**loss_kwargs,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
if self.config.has_mask:
|
| 356 |
+
|
| 357 |
+
hidden_states_sam = []
|
| 358 |
+
hidden_states_sam.append(self.model.text_hidden_fcs[0](hidden_states))
|
| 359 |
+
hidden_states_sam = torch.stack(hidden_states_sam, dim=-1).sum(dim=-1)
|
| 360 |
+
|
| 361 |
+
bs = input_ids_raw.shape[0]
|
| 362 |
+
gt_masks_list = []
|
| 363 |
+
pred_masks_list = []
|
| 364 |
+
mask_bce_loss = 0
|
| 365 |
+
mask_dice_loss = 0
|
| 366 |
+
num_masks = 0
|
| 367 |
+
for i in range(bs):
|
| 368 |
+
pred_masks = []
|
| 369 |
+
pred_embeddings = []
|
| 370 |
+
input_id = input_ids_raw[i]
|
| 371 |
+
seg_token_mask = input_id[1:]==self.config.seg_token_index
|
| 372 |
+
seg_token_mask = torch.cat(
|
| 373 |
+
[
|
| 374 |
+
seg_token_mask,
|
| 375 |
+
torch.zeros((1)).bool().cuda(),
|
| 376 |
+
],
|
| 377 |
+
dim=0,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
pred_embedding = hidden_states_sam[i][seg_token_mask]
|
| 381 |
+
if len(pred_embedding)>0:
|
| 382 |
+
pred_embeddings.append(pred_embedding)
|
| 383 |
+
else:
|
| 384 |
+
pred_embeddings.append(hidden_states_sam[i, :1])
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
gt_masks_video = [] # FIXME: Only support one segmentation now
|
| 388 |
+
gt_mask = masks[i]
|
| 389 |
+
mask_valid = True
|
| 390 |
+
|
| 391 |
+
if len(image2maskids[i])==0:
|
| 392 |
+
sam_images[i] = sam_images[i][:1]
|
| 393 |
+
gt_masks_video.append(torch.zeros((len(sam_images[i]), 224, 224)).to(sam_images[0].device))
|
| 394 |
+
mask_valid = False
|
| 395 |
+
|
| 396 |
+
else:
|
| 397 |
+
for mids in image2maskids[i]:
|
| 398 |
+
for mid in mids:
|
| 399 |
+
if mid is None:
|
| 400 |
+
gt_masks_video.append(torch.zeros((224, 224)).unsqueeze(0).to(gt_mask[0].device))
|
| 401 |
+
else:
|
| 402 |
+
gt_masks_video.append(gt_mask[mid].unsqueeze(0))
|
| 403 |
+
frames_per_batch = [len(sam_images[i])]
|
| 404 |
+
try:
|
| 405 |
+
pred_embeddings_list_video = genetate_video_pred_embeddings(pred_embeddings, frames_per_batch)
|
| 406 |
+
|
| 407 |
+
# pred_embeddings_list_video, gt_masks_video = check_obj_number(pred_embeddings_list_video, gt_masks_video)
|
| 408 |
+
|
| 409 |
+
g_pixel_values = sam_images[i]
|
| 410 |
+
num_objs = len(pred_embeddings_list_video[0])
|
| 411 |
+
|
| 412 |
+
# with torch.no_grad():
|
| 413 |
+
|
| 414 |
+
sam_states = self.grounding_encoder.get_sam2_embeddings(g_pixel_values, expand_size=num_objs)
|
| 415 |
+
language_embeddings = torch.cat(pred_embeddings_list_video, dim=0)[:, None]#.contiguous()
|
| 416 |
+
|
| 417 |
+
num_frames = len(pred_embeddings_list_video)
|
| 418 |
+
gt_masks_video = process_video_gt_masks(gt_masks_video, num_frames, num_objs)
|
| 419 |
+
pred_masks = self.grounding_encoder.inject_language_embd(sam_states, language_embeddings, nf_nobj=(num_frames, num_objs))
|
| 420 |
+
|
| 421 |
+
gt_masks = [F.interpolate(gt_mask.unsqueeze(0), size=pred_masks[0].shape[-2:], mode='nearest').squeeze(0) for gt_mask in gt_masks_video]
|
| 422 |
+
gt_masks = torch.cat(gt_masks, dim=0)
|
| 423 |
+
pred_masks = pred_masks.flatten(0, 1)
|
| 424 |
+
|
| 425 |
+
if not mask_valid:
|
| 426 |
+
pred_masks = pred_masks*0.0
|
| 427 |
+
|
| 428 |
+
if len(pred_masks) != len(gt_masks):
|
| 429 |
+
# drop this data
|
| 430 |
+
print(f"Pred mask shape {pred_masks.shape} is not equal to gt_mask shape {gt_masks.shape} !!!")
|
| 431 |
+
min_num = min(len(pred_masks), len(gt_masks))
|
| 432 |
+
pred_masks = pred_masks[:min_num]
|
| 433 |
+
gt_masks = gt_masks[:min_num]
|
| 434 |
+
seg_valid = False
|
| 435 |
+
|
| 436 |
+
if not seg_valid or not mask_valid:
|
| 437 |
+
_scale = 0.0
|
| 438 |
+
else:
|
| 439 |
+
_scale = 1.0
|
| 440 |
+
|
| 441 |
+
mask_bce_loss_ = self.loss_mask(pred_masks, gt_masks) * len(pred_masks) * _scale
|
| 442 |
+
mask_dice_loss_ = self.loss_dice(pred_masks, gt_masks) * len(gt_masks) * _scale
|
| 443 |
+
mask_bce_loss += mask_bce_loss_
|
| 444 |
+
mask_dice_loss += mask_dice_loss_
|
| 445 |
+
num_masks += len(pred_masks)
|
| 446 |
+
except Exception as exp:
|
| 447 |
+
print(exp)
|
| 448 |
+
_valid = False
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
if num_masks>0:
|
| 452 |
+
mask_bce_loss = mask_bce_loss / num_masks
|
| 453 |
+
mask_dice_loss = mask_dice_loss / num_masks
|
| 454 |
+
|
| 455 |
+
mask_bce_loss = self.config.bce_loss_weight * mask_bce_loss
|
| 456 |
+
mask_dice_loss = self.config.dice_loss_weight * mask_dice_loss
|
| 457 |
+
if _valid==False:
|
| 458 |
+
mask_bce_loss = mask_bce_loss * 0.0
|
| 459 |
+
mask_dice_loss = mask_dice_loss* 0.0
|
| 460 |
+
|
| 461 |
+
mask_loss = mask_bce_loss + mask_dice_loss
|
| 462 |
+
loss = mask_loss + ce_loss
|
| 463 |
+
else:
|
| 464 |
+
loss = ce_loss
|
| 465 |
+
|
| 466 |
+
else:
|
| 467 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 468 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
| 469 |
+
|
| 470 |
+
if not return_dict:
|
| 471 |
+
output = (logits,) + outputs[1:]
|
| 472 |
+
return (loss,) + output if loss is not None else output
|
| 473 |
+
|
| 474 |
+
if loss is not None:
|
| 475 |
+
if self.config.has_mask:
|
| 476 |
+
return CausalLMOutputWithPast(
|
| 477 |
+
loss=loss,
|
| 478 |
+
ce_loss=ce_loss.detach(),
|
| 479 |
+
mask_bce_loss=mask_bce_loss.detach(),
|
| 480 |
+
mask_dice_loss=mask_dice_loss.detach(),
|
| 481 |
+
mask_loss=mask_loss.detach(),
|
| 482 |
+
logits=logits,
|
| 483 |
+
past_key_values=outputs.past_key_values,
|
| 484 |
+
hidden_states=outputs.hidden_states,
|
| 485 |
+
attentions=outputs.attentions,
|
| 486 |
+
)
|
| 487 |
+
else:
|
| 488 |
+
return CausalLMOutputWithPast(
|
| 489 |
+
loss=loss,
|
| 490 |
+
logits=logits,
|
| 491 |
+
past_key_values=outputs.past_key_values,
|
| 492 |
+
hidden_states=outputs.hidden_states,
|
| 493 |
+
attentions=outputs.attentions,
|
| 494 |
+
)
|
| 495 |
+
else: #infer
|
| 496 |
+
return CausalLMOutputWithPast(
|
| 497 |
+
loss=loss,
|
| 498 |
+
logits=logits,
|
| 499 |
+
past_key_values=outputs.past_key_values,
|
| 500 |
+
hidden_states=outputs.hidden_states,
|
| 501 |
+
attentions=outputs.attentions,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
@torch.no_grad()
|
| 505 |
+
def inference(
|
| 506 |
+
self,
|
| 507 |
+
# multimodal inputs
|
| 508 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 509 |
+
grid_sizes: Optional[torch.LongTensor] = None,
|
| 510 |
+
merge_sizes: Optional[torch.LongTensor] = None,
|
| 511 |
+
modals: Optional[List[str]] = None,
|
| 512 |
+
masks: Optional[List[torch.LongTensor]] = None,
|
| 513 |
+
mask_ids = None,
|
| 514 |
+
sam_images = None,
|
| 515 |
+
sam_size = None,
|
| 516 |
+
image2maskids = None,
|
| 517 |
+
seg_start_idx = 0,
|
| 518 |
+
**kwargs,
|
| 519 |
+
):
|
| 520 |
+
outputs = self.generate(
|
| 521 |
+
pixel_values=pixel_values,
|
| 522 |
+
grid_sizes=grid_sizes,
|
| 523 |
+
merge_sizes=merge_sizes,
|
| 524 |
+
modals=modals,
|
| 525 |
+
masks=masks,
|
| 526 |
+
mask_ids=mask_ids,
|
| 527 |
+
output_hidden_states=True,
|
| 528 |
+
return_dict_in_generate=True,
|
| 529 |
+
**kwargs
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
input_ids = kwargs.pop('input_ids')
|
| 533 |
+
last_hidden_state = []
|
| 534 |
+
for hs in outputs.hidden_states: # round
|
| 535 |
+
last_hidden_state.append(hs[-1])
|
| 536 |
+
last_hidden_state = torch.cat(last_hidden_state, dim=1)
|
| 537 |
+
|
| 538 |
+
output_ids = outputs.sequences
|
| 539 |
+
|
| 540 |
+
concat_ids = torch.cat((input_ids, output_ids), dim=1)
|
| 541 |
+
seg_token_mask = concat_ids[:, 1:] == self.config.seg_token_index
|
| 542 |
+
|
| 543 |
+
last_hidden_state_sam = self.model.text_hidden_fcs[0](last_hidden_state)
|
| 544 |
+
|
| 545 |
+
pred_embeddings = last_hidden_state_sam[seg_token_mask]
|
| 546 |
+
seg_token_counts = seg_token_mask.int().sum()
|
| 547 |
+
|
| 548 |
+
if seg_token_counts>0:
|
| 549 |
+
|
| 550 |
+
g_pixel_values = torch.cat(sam_images, dim=0).contiguous()
|
| 551 |
+
num_objs = 1 #FIXME: Only support one segmentation now
|
| 552 |
+
if seg_start_idx>0:
|
| 553 |
+
# before start idx
|
| 554 |
+
g_pixel_values_beg = g_pixel_values[:seg_start_idx+1].flip(0)
|
| 555 |
+
num_frames = len(g_pixel_values_beg)
|
| 556 |
+
sam_states_beg = self.grounding_encoder.get_sam2_embeddings(g_pixel_values_beg)
|
| 557 |
+
pred_masks_beg = self.grounding_encoder.language_embd_inference(sam_states_beg, [pred_embeddings]*num_frames)
|
| 558 |
+
else:
|
| 559 |
+
pred_masks_beg = torch.zeros((1, 1, 1024, 1024)).to(pixel_values.device)
|
| 560 |
+
|
| 561 |
+
if seg_start_idx<=len(g_pixel_values)-1:
|
| 562 |
+
g_pixel_values_end = g_pixel_values[seg_start_idx:]
|
| 563 |
+
num_frames = len(g_pixel_values_end)
|
| 564 |
+
sam_states_end = self.grounding_encoder.get_sam2_embeddings(g_pixel_values_end)
|
| 565 |
+
pred_masks_end = self.grounding_encoder.language_embd_inference(sam_states_end, [pred_embeddings]*num_frames)
|
| 566 |
+
else:
|
| 567 |
+
pred_masks_end = torch.zeros((0, 1, 1024, 1024)).to(pixel_values.device)
|
| 568 |
+
|
| 569 |
+
pred_masks = torch.cat([pred_masks_beg[1:].flip(0), pred_masks_end], dim=0)
|
| 570 |
+
|
| 571 |
+
return output_ids, pred_masks
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
@torch.no_grad()
|
| 575 |
+
def generate(
|
| 576 |
+
self,
|
| 577 |
+
# multimodal inputs
|
| 578 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 579 |
+
grid_sizes: Optional[torch.LongTensor] = None,
|
| 580 |
+
merge_sizes: Optional[torch.LongTensor] = None,
|
| 581 |
+
modals: Optional[List[str]] = None,
|
| 582 |
+
masks: Optional[List[torch.LongTensor]] = None,
|
| 583 |
+
mask_ids = None,
|
| 584 |
+
**kwargs,
|
| 585 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
| 586 |
+
input_ids = kwargs.pop("input_ids", None)
|
| 587 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
| 588 |
+
position_ids = kwargs.pop("position_ids", None)
|
| 589 |
+
past_key_values = kwargs.pop("past_key_values", None)
|
| 590 |
+
|
| 591 |
+
if "inputs_embeds" in kwargs:
|
| 592 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
| 593 |
+
|
| 594 |
+
if pixel_values is not None:
|
| 595 |
+
(
|
| 596 |
+
input_ids,
|
| 597 |
+
attention_mask,
|
| 598 |
+
position_ids,
|
| 599 |
+
past_key_values,
|
| 600 |
+
inputs_embeds,
|
| 601 |
+
labels,
|
| 602 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
| 603 |
+
input_ids=input_ids,
|
| 604 |
+
attention_mask=attention_mask,
|
| 605 |
+
position_ids=position_ids,
|
| 606 |
+
past_key_values=past_key_values,
|
| 607 |
+
labels=None,
|
| 608 |
+
pixel_values=pixel_values,
|
| 609 |
+
grid_sizes=grid_sizes,
|
| 610 |
+
merge_sizes=merge_sizes,
|
| 611 |
+
modals=modals,
|
| 612 |
+
masks=masks,
|
| 613 |
+
mask_ids=mask_ids
|
| 614 |
+
)
|
| 615 |
+
else:
|
| 616 |
+
inputs_embeds = self.get_model().embed_tokens(input_ids)
|
| 617 |
+
|
| 618 |
+
return super().generate(
|
| 619 |
+
position_ids=position_ids,
|
| 620 |
+
attention_mask=attention_mask,
|
| 621 |
+
inputs_embeds=inputs_embeds,
|
| 622 |
+
**kwargs
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
| 626 |
+
images = kwargs.pop("images", None)
|
| 627 |
+
_inputs = super().prepare_inputs_for_generation(
|
| 628 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
| 629 |
+
)
|
| 630 |
+
if images is not None:
|
| 631 |
+
_inputs['images'] = images
|
| 632 |
+
return _inputs
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
AutoConfig.register("rynnec_qwen2", RynnecQwen2Config)
|
| 636 |
+
AutoModelForCausalLM.register(RynnecQwen2Config, RynnecQwen2ForCausalLM)
|
| 637 |
+
AutoProcessor.register(RynnecQwen2Config, Videollama3Qwen2Processor)
|
| 638 |
+
AutoImageProcessor.register(RynnecQwen2Config, Videollama3ImageProcessor)
|
RynnEC/rynnec/model/sam2.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/sam2.py.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# coding=utf-8
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import os.path
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from hydra import compose
|
| 23 |
+
from hydra.utils import instantiate
|
| 24 |
+
from omegaconf import OmegaConf
|
| 25 |
+
|
| 26 |
+
from .utils import load_checkpoint_with_prefix, load_state_dict_to_model
|
| 27 |
+
|
| 28 |
+
class SAM2(nn.Module):
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
cfg_path: str = "sam2_hiera_l.yaml",
|
| 32 |
+
ckpt_path: str = "sam2_hiera_large.pt",
|
| 33 |
+
hydra_overrides_extra=None,
|
| 34 |
+
apply_postprocessing=True,
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
import third_parts.sam2 # noqa: F401
|
| 39 |
+
|
| 40 |
+
if hydra_overrides_extra is None:
|
| 41 |
+
hydra_overrides_extra = []
|
| 42 |
+
hydra_overrides = [
|
| 43 |
+
## Extension: LLM prompt
|
| 44 |
+
"++model._target_=rynnec.model.predictor.SAM2VideoPredictor",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
if apply_postprocessing:
|
| 48 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 49 |
+
hydra_overrides_extra += [
|
| 50 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
| 51 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
| 52 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 53 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 54 |
+
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
|
| 55 |
+
# "++model.binarize_mask_from_pts_for_mem_enc=true",
|
| 56 |
+
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
|
| 57 |
+
# "++model.fill_hole_area=8",
|
| 58 |
+
]
|
| 59 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
| 60 |
+
|
| 61 |
+
# Read config and init model
|
| 62 |
+
cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
|
| 63 |
+
OmegaConf.resolve(cfg)
|
| 64 |
+
sam2_model = instantiate(cfg.model, _recursive_=True)
|
| 65 |
+
state_dict = load_checkpoint_with_prefix(ckpt_path)
|
| 66 |
+
load_state_dict_to_model(sam2_model, state_dict)
|
| 67 |
+
|
| 68 |
+
self.sam2_model = sam2_model
|
| 69 |
+
|
| 70 |
+
self.hidden_dim = self.sam2_model.hidden_dim
|
| 71 |
+
|
| 72 |
+
self.img_mean = (0.485, 0.456, 0.406)
|
| 73 |
+
self.img_std = (0.229, 0.224, 0.225)
|
| 74 |
+
|
| 75 |
+
def inject_language_embd(self, inference_state, language_embd):
|
| 76 |
+
num_frame = len(language_embd)
|
| 77 |
+
num_obj = len(language_embd[0])
|
| 78 |
+
mask_out = []
|
| 79 |
+
for frame_idx in range(num_frame):
|
| 80 |
+
frame_mask_out = []
|
| 81 |
+
for obj_idx in range(num_obj):
|
| 82 |
+
_language_embd = language_embd[frame_idx][obj_idx][None][None]
|
| 83 |
+
_, _, out_mask_logits = self.sam2_model.add_language_embd(inference_state, frame_idx, obj_idx + 100, _language_embd)
|
| 84 |
+
frame_mask_out.append(out_mask_logits)
|
| 85 |
+
frame_mask_out = torch.cat(frame_mask_out, dim=1)
|
| 86 |
+
mask_out.append(frame_mask_out)
|
| 87 |
+
mask_out = torch.cat(mask_out, dim=0)
|
| 88 |
+
return mask_out
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def language_embd_inference(self, inference_state, language_embd):
|
| 92 |
+
num_frame = len(language_embd)
|
| 93 |
+
num_obj = len(language_embd[0])
|
| 94 |
+
mask_out = []
|
| 95 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 96 |
+
for frame_idx in range(num_frame):
|
| 97 |
+
frame_mask_out = []
|
| 98 |
+
|
| 99 |
+
for obj_idx in range(num_obj):
|
| 100 |
+
_language_embd = language_embd[frame_idx][obj_idx][None][None]
|
| 101 |
+
_, _, out_mask_logits = self.sam2_model.add_language_embd(
|
| 102 |
+
inference_state,
|
| 103 |
+
frame_idx,
|
| 104 |
+
obj_idx + 100,
|
| 105 |
+
_language_embd,
|
| 106 |
+
inference=True,
|
| 107 |
+
)
|
| 108 |
+
frame_mask_out.append(out_mask_logits)
|
| 109 |
+
frame_mask_out = torch.cat(frame_mask_out, dim=1)
|
| 110 |
+
mask_out.append(frame_mask_out)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
mask_out = []
|
| 114 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_model.propagate_in_video(inference_state):
|
| 115 |
+
mask_out.append(out_mask_logits)
|
| 116 |
+
mask_out = torch.cat(mask_out, dim=0)
|
| 117 |
+
return mask_out
|
| 118 |
+
|
| 119 |
+
def get_sam2_embeddings(self, images):
|
| 120 |
+
return self.sam2_model.init_state(images)
|
| 121 |
+
|
| 122 |
+
def forward(self, batch):
|
| 123 |
+
raise NotImplementedError
|
| 124 |
+
|
| 125 |
+
def preprocess_image(self, image: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
|
| 126 |
+
image = image / 255.
|
| 127 |
+
|
| 128 |
+
img_mean = torch.tensor(self.img_mean, dtype=dtype, device=image.device)[:, None, None]
|
| 129 |
+
img_std = torch.tensor(self.img_std, dtype=dtype, device=image.device)[:, None, None]
|
| 130 |
+
image -= img_mean
|
| 131 |
+
image /= img_std
|
| 132 |
+
|
| 133 |
+
return image
|
RynnEC/rynnec/model/sam2_train.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/magic-research/Sa2VA/blob/main/projects/llava_sam2/models/sam2_train.py.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# coding=utf-8
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import os.path
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
from hydra import compose
|
| 24 |
+
from hydra.utils import instantiate
|
| 25 |
+
from omegaconf import OmegaConf
|
| 26 |
+
|
| 27 |
+
from .utils import load_checkpoint_with_prefix, load_state_dict_to_model
|
| 28 |
+
|
| 29 |
+
BASE_DIR = 'pretrained/'
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SAM2TrainRunner(nn.Module):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
cfg_path: str = "sam2_hiera_l.yaml",
|
| 36 |
+
ckpt_path: str = "sam2_hiera_large.pt",
|
| 37 |
+
hydra_overrides_extra=None,
|
| 38 |
+
apply_postprocessing=True,
|
| 39 |
+
):
|
| 40 |
+
super().__init__()
|
| 41 |
+
|
| 42 |
+
import third_parts.sam2 # noqa: F401
|
| 43 |
+
|
| 44 |
+
if hydra_overrides_extra is None:
|
| 45 |
+
hydra_overrides_extra = []
|
| 46 |
+
hydra_overrides = [
|
| 47 |
+
## Extension: LLM prompt
|
| 48 |
+
"++model._target_=rynnec.model.extension.SAM2Base",
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
if apply_postprocessing:
|
| 52 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 53 |
+
|
| 54 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
| 55 |
+
|
| 56 |
+
# Read config and init model
|
| 57 |
+
cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
|
| 58 |
+
OmegaConf.resolve(cfg)
|
| 59 |
+
sam2_model = instantiate(cfg.model, _recursive_=True)
|
| 60 |
+
state_dict = load_checkpoint_with_prefix(ckpt_path)
|
| 61 |
+
load_state_dict_to_model(sam2_model, state_dict)
|
| 62 |
+
|
| 63 |
+
self.sam2_model = sam2_model
|
| 64 |
+
|
| 65 |
+
self.hidden_dim = self.sam2_model.hidden_dim
|
| 66 |
+
self.img_mean = (0.485, 0.456, 0.406)
|
| 67 |
+
self.img_std = (0.229, 0.224, 0.225)
|
| 68 |
+
|
| 69 |
+
def preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
|
| 70 |
+
image = image / 255.
|
| 71 |
+
img_mean = torch.tensor(self.img_mean, dtype=image.dtype, device=image.device)[:, None, None]
|
| 72 |
+
img_std = torch.tensor(self.img_std, dtype=image.dtype, device=image.device)[:, None, None]
|
| 73 |
+
image -= img_mean
|
| 74 |
+
image /= img_std
|
| 75 |
+
return image
|
| 76 |
+
|
| 77 |
+
def inject_language_embd(self, sam_states, language_embd, nf_nobj=None):
|
| 78 |
+
high_res_features = [
|
| 79 |
+
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
| 80 |
+
for x, s in zip(sam_states['current_vision_feats'][:-1], sam_states['feat_sizes'][:-1])
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
B = sam_states['current_vision_feats'][-1].size(1) # batch size on this frame
|
| 84 |
+
C = self.hidden_dim
|
| 85 |
+
H, W = sam_states['feat_sizes'][-1]
|
| 86 |
+
|
| 87 |
+
if self.sam2_model.directly_add_no_mem_embed:
|
| 88 |
+
# directly add no-mem embedding (instead of using the transformer encoder)
|
| 89 |
+
pix_feat_with_mem = sam_states['current_vision_feats'][-1] + self.sam2_model.no_mem_embed
|
| 90 |
+
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
| 91 |
+
else:
|
| 92 |
+
raise NotImplementedError("directly add no memory embedding is not implemented")
|
| 93 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 94 |
+
_, _, _, low_res_masks, high_res_masks, obj_ptr, _, = self.sam2_model._forward_sam_heads(
|
| 95 |
+
backbone_features=pix_feat_with_mem,
|
| 96 |
+
point_inputs=None,
|
| 97 |
+
mask_inputs=None,
|
| 98 |
+
high_res_features=high_res_features,
|
| 99 |
+
multimask_output=self.sam2_model._use_multimask(is_init_cond_frame=True, point_inputs=None),
|
| 100 |
+
# Inject language Embed if possible
|
| 101 |
+
language_embd=language_embd,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
if nf_nobj is not None:
|
| 105 |
+
pred_masks = low_res_masks.squeeze(1)
|
| 106 |
+
pred_masks = pred_masks.unflatten(0, nf_nobj)
|
| 107 |
+
else:
|
| 108 |
+
pred_masks = low_res_masks
|
| 109 |
+
return pred_masks
|
| 110 |
+
|
| 111 |
+
def get_sam2_embeddings(self, images, expand_size=1):
|
| 112 |
+
# Step 1: inference the backbone with the images
|
| 113 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 114 |
+
feats = self.sam2_model.forward_image(images)
|
| 115 |
+
|
| 116 |
+
if expand_size > 1:
|
| 117 |
+
# feats['vision_features'] = feats['vision_features'][:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
|
| 118 |
+
for i, feat in enumerate(feats["backbone_fpn"]):
|
| 119 |
+
feats["backbone_fpn"][i] = feat[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
|
| 120 |
+
for i, pos in enumerate(feats["vision_pos_enc"]):
|
| 121 |
+
pos = pos[:, None].expand(-1, expand_size, -1, -1, -1).flatten(0, 1)
|
| 122 |
+
feats["vision_pos_enc"][i] = pos
|
| 123 |
+
|
| 124 |
+
# Step 2: Process the features to output
|
| 125 |
+
_, current_vision_feats, current_vision_pos_embeds, feat_sizes = self.sam2_model._prepare_backbone_features(feats)
|
| 126 |
+
|
| 127 |
+
return {
|
| 128 |
+
"current_vision_feats": current_vision_feats,
|
| 129 |
+
"current_vision_pos_embeds": current_vision_pos_embeds,
|
| 130 |
+
"feat_sizes": feat_sizes,
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
def forward(self, batch):
|
| 134 |
+
raise NotImplementedError
|
RynnEC/rynnec/model/utils.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
import logging
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
import functools
|
| 8 |
+
from typing import Callable, Optional
|
| 9 |
+
|
| 10 |
+
def process_video_gt_masks(gt_masks, num_frames, num_objs):
|
| 11 |
+
gt_masks_processed = []
|
| 12 |
+
for i in range(num_frames):
|
| 13 |
+
for j in range(num_objs):
|
| 14 |
+
gt_masks_processed.append(gt_masks[j*num_frames+i])
|
| 15 |
+
return gt_masks_processed
|
| 16 |
+
|
| 17 |
+
def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
|
| 18 |
+
HF_HUB_PREFIX = 'hf-hub:'
|
| 19 |
+
if filename.startswith(HF_HUB_PREFIX):
|
| 20 |
+
model_id = filename[len(HF_HUB_PREFIX):]
|
| 21 |
+
filename = hf_hub_download(model_id, 'pytorch_model.bin')
|
| 22 |
+
|
| 23 |
+
checkpoint = torch.load(filename, map_location=map_location)
|
| 24 |
+
|
| 25 |
+
if 'state_dict' in checkpoint:
|
| 26 |
+
state_dict = checkpoint['state_dict']
|
| 27 |
+
elif 'model' in checkpoint:
|
| 28 |
+
state_dict = checkpoint['model']
|
| 29 |
+
else:
|
| 30 |
+
state_dict = checkpoint
|
| 31 |
+
if not prefix:
|
| 32 |
+
return state_dict
|
| 33 |
+
if not prefix.endswith('.'):
|
| 34 |
+
prefix += '.'
|
| 35 |
+
prefix_len = len(prefix)
|
| 36 |
+
|
| 37 |
+
state_dict = {
|
| 38 |
+
k[prefix_len:]: v
|
| 39 |
+
for k, v in state_dict.items() if k.startswith(prefix)
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
assert state_dict, f'{prefix} is not in the pretrained model'
|
| 43 |
+
return state_dict
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_state_dict_to_model(model, state_dict, logger='current'):
|
| 47 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict)
|
| 48 |
+
if missing_keys:
|
| 49 |
+
raise RuntimeError()
|
| 50 |
+
if unexpected_keys:
|
| 51 |
+
raise RuntimeError()
|
| 52 |
+
|
| 53 |
+
def genetate_video_pred_embeddings(pred_embeddings_list, frames_per_batch):
|
| 54 |
+
assert len(pred_embeddings_list) == len(frames_per_batch), \
|
| 55 |
+
f"Lengths do not match: len(pred_embeddings_list)={len(pred_embeddings_list)}, len(frames_per_batch)={len(frames_per_batch)}"
|
| 56 |
+
|
| 57 |
+
pred_embeddings_list_video = []
|
| 58 |
+
for pred_embedding_batch, frame_nums in zip(pred_embeddings_list, frames_per_batch):
|
| 59 |
+
pred_embeddings_list_video += [pred_embedding_batch] * frame_nums
|
| 60 |
+
return pred_embeddings_list_video
|
| 61 |
+
|
RynnEC/rynnec/model/videollama3_encoder/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_videollama3_encoder import Videollama3VisionEncoderConfig
|
| 2 |
+
from .image_processing_videollama3 import Videollama3ImageProcessor
|
| 3 |
+
from .modeling_videollama3_encoder import Videollama3VisionEncoderModel
|
RynnEC/rynnec/model/videollama3_encoder/configuration_videollama3_encoder.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/siglip/configuration_siglip.py.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# coding=utf-8
|
| 4 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
"""VideoLLaMA3 vision encoder model configuration."""
|
| 18 |
+
import os
|
| 19 |
+
from typing import Union
|
| 20 |
+
|
| 21 |
+
from transformers import PretrainedConfig
|
| 22 |
+
from transformers.utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
class Videollama3VisionEncoderConfig(PretrainedConfig):
|
| 28 |
+
|
| 29 |
+
model_type = "videollama3_vision_encoder"
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
hidden_size=768,
|
| 34 |
+
intermediate_size=3072,
|
| 35 |
+
num_hidden_layers=12,
|
| 36 |
+
num_attention_heads=12,
|
| 37 |
+
num_channels=3,
|
| 38 |
+
patch_size=16,
|
| 39 |
+
hidden_act="gelu_pytorch_tanh",
|
| 40 |
+
layer_norm_eps=1e-6,
|
| 41 |
+
attention_dropout=0.0,
|
| 42 |
+
**kwargs,
|
| 43 |
+
):
|
| 44 |
+
super().__init__(**kwargs)
|
| 45 |
+
|
| 46 |
+
self.hidden_size = hidden_size
|
| 47 |
+
self.intermediate_size = intermediate_size
|
| 48 |
+
self.num_hidden_layers = num_hidden_layers
|
| 49 |
+
self.num_attention_heads = num_attention_heads
|
| 50 |
+
self.num_channels = num_channels
|
| 51 |
+
self.patch_size = patch_size
|
| 52 |
+
self.attention_dropout = attention_dropout
|
| 53 |
+
self.layer_norm_eps = layer_norm_eps
|
| 54 |
+
self.hidden_act = hidden_act
|
| 55 |
+
|
| 56 |
+
# @classmethod
|
| 57 |
+
# def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
| 58 |
+
# cls._set_token_in_kwargs(kwargs)
|
| 59 |
+
|
| 60 |
+
# config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 61 |
+
|
| 62 |
+
# p
|
| 63 |
+
# config_dict = config_dict["vision_config"]
|
| 64 |
+
|
| 65 |
+
# if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 66 |
+
# logger.warning(
|
| 67 |
+
# f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 68 |
+
# f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 69 |
+
# )
|
| 70 |
+
|
| 71 |
+
# return cls.from_dict(config_dict, **kwargs)
|
RynnEC/rynnec/model/videollama3_encoder/image_processing_videollama3.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
"""Image processor class for VideoLLaMA3."""
|
| 22 |
+
|
| 23 |
+
import math
|
| 24 |
+
from typing import Dict, List, Optional, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 30 |
+
from transformers.image_utils import ImageInput
|
| 31 |
+
from transformers.image_transforms import (
|
| 32 |
+
convert_to_rgb,
|
| 33 |
+
resize,
|
| 34 |
+
to_channel_dimension_format,
|
| 35 |
+
)
|
| 36 |
+
from transformers.image_utils import (
|
| 37 |
+
OPENAI_CLIP_MEAN,
|
| 38 |
+
OPENAI_CLIP_STD,
|
| 39 |
+
ChannelDimension,
|
| 40 |
+
ImageInput,
|
| 41 |
+
PILImageResampling,
|
| 42 |
+
VideoInput,
|
| 43 |
+
get_image_size,
|
| 44 |
+
infer_channel_dimension_format,
|
| 45 |
+
is_scaled_image,
|
| 46 |
+
is_valid_image,
|
| 47 |
+
make_list_of_images,
|
| 48 |
+
to_numpy_array,
|
| 49 |
+
)
|
| 50 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
logger = logging.get_logger(__name__)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if is_vision_available():
|
| 57 |
+
from PIL import Image
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def is_valid_video(video) -> bool:
|
| 61 |
+
if isinstance(video, (list, tuple)):
|
| 62 |
+
return all(is_valid_image(frame) for frame in video)
|
| 63 |
+
elif isinstance(video, np.ndarray):
|
| 64 |
+
return video.ndim == 4
|
| 65 |
+
elif isinstance(video, torch.Tensor):
|
| 66 |
+
return video.ndim == 4
|
| 67 |
+
return False
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def make_batched_images(images) -> List[List[ImageInput]]:
|
| 71 |
+
"""
|
| 72 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
| 76 |
+
The input image.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
list: A list of images.
|
| 80 |
+
"""
|
| 81 |
+
if isinstance(images, (list, tuple)):
|
| 82 |
+
# list of images/videos
|
| 83 |
+
if not all(is_valid_video(image) or is_valid_image(image) for image in images):
|
| 84 |
+
raise ValueError(f"Could not make batched images from {images}")
|
| 85 |
+
return images
|
| 86 |
+
elif is_valid_video(images) or is_valid_image(images):
|
| 87 |
+
# single image/video
|
| 88 |
+
return [images]
|
| 89 |
+
|
| 90 |
+
raise ValueError(f"Could not make batched images from {images}")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def simple_batched_resize(
|
| 94 |
+
images, factor: int = 28, min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
|
| 95 |
+
):
|
| 96 |
+
min_pixels = min_tokens * factor * factor
|
| 97 |
+
max_pixels = max_tokens * factor * factor
|
| 98 |
+
|
| 99 |
+
num_images = 0
|
| 100 |
+
for image in images:
|
| 101 |
+
if is_valid_video(image):
|
| 102 |
+
num_images += len(image)
|
| 103 |
+
else:
|
| 104 |
+
num_images += 1
|
| 105 |
+
|
| 106 |
+
image_sizes = []
|
| 107 |
+
for image in images:
|
| 108 |
+
if is_valid_video(image):
|
| 109 |
+
image = image[0]
|
| 110 |
+
if isinstance(image, Image.Image):
|
| 111 |
+
width, height = image.size
|
| 112 |
+
else:
|
| 113 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
| 114 |
+
image_sizes.append([height, width])
|
| 115 |
+
|
| 116 |
+
tmp_image_sizes = []
|
| 117 |
+
for height, width in image_sizes:
|
| 118 |
+
h_bar = round(height / factor) * factor
|
| 119 |
+
w_bar = round(width / factor) * factor
|
| 120 |
+
if h_bar * w_bar > (max_pixels // num_images):
|
| 121 |
+
beta = math.sqrt((height * width) / (max_pixels // num_images))
|
| 122 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 123 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 124 |
+
# per image min_pixels
|
| 125 |
+
if h_bar * w_bar < min_pixels:
|
| 126 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 127 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 128 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 129 |
+
tmp_image_sizes.append((h_bar, w_bar))
|
| 130 |
+
image_sizes = tmp_image_sizes
|
| 131 |
+
return image_sizes
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def batched_resize(
|
| 135 |
+
images, factors: List[int], min_tokens: int = 4 * 4, max_tokens: int = 16384, input_data_format: str = None
|
| 136 |
+
):
|
| 137 |
+
image_sizes = []
|
| 138 |
+
for image in images:
|
| 139 |
+
if is_valid_video(image):
|
| 140 |
+
num_frame = len(image)
|
| 141 |
+
image = image[0]
|
| 142 |
+
else:
|
| 143 |
+
num_frame = 1
|
| 144 |
+
if isinstance(image, Image.Image):
|
| 145 |
+
width, height = image.size
|
| 146 |
+
else:
|
| 147 |
+
height, width = get_image_size(image, channel_dim=input_data_format)
|
| 148 |
+
image_sizes.append([num_frame, height, width])
|
| 149 |
+
|
| 150 |
+
# global max_pixels
|
| 151 |
+
smart_scale_factors = 1.0
|
| 152 |
+
total_tokens = 0
|
| 153 |
+
for (num_frame, height, width), factor in zip(image_sizes, factors):
|
| 154 |
+
total_tokens += num_frame * math.ceil(height / factor) * math.ceil(width / factor)
|
| 155 |
+
|
| 156 |
+
# TODO: add min_pixels
|
| 157 |
+
if total_tokens > max_tokens:
|
| 158 |
+
beta = math.sqrt(total_tokens / max_tokens)
|
| 159 |
+
tmp_image_sizes = []
|
| 160 |
+
for (_, height, width), factor in zip(image_sizes, factors):
|
| 161 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 162 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 163 |
+
tmp_image_sizes.append((h_bar, w_bar))
|
| 164 |
+
image_sizes = tmp_image_sizes
|
| 165 |
+
else:
|
| 166 |
+
tmp_image_sizes = []
|
| 167 |
+
for (_, height, width), factor in zip(image_sizes, factors):
|
| 168 |
+
height = round(height / factor) * factor
|
| 169 |
+
width = round(width / factor) * factor
|
| 170 |
+
tmp_image_sizes.append((height, width))
|
| 171 |
+
image_sizes = tmp_image_sizes
|
| 172 |
+
|
| 173 |
+
return image_sizes
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class Videollama3ImageProcessor(BaseImageProcessor):
|
| 177 |
+
r"""
|
| 178 |
+
Constructs a DAMOVL image processor that dynamically resizes images based on the original images.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 182 |
+
Whether to resize the image's (height, width) dimensions.
|
| 183 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 184 |
+
Resampling filter to use when resizing the image.
|
| 185 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 186 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
| 187 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 188 |
+
Scale factor to use if rescaling the image.
|
| 189 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 190 |
+
Whether to normalize the image.
|
| 191 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
| 192 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 193 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
| 194 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
| 195 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
| 196 |
+
Whether to convert the image to RGB.
|
| 197 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
| 198 |
+
The min pixels of the image to resize the image.
|
| 199 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
| 200 |
+
The max pixels of the image to resize the image.
|
| 201 |
+
patch_size (`int`, *optional*, defaults to 14):
|
| 202 |
+
The spacial patch size of the vision encoder.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
model_input_names = ["pixel_values", "grid_sizes", "merge_sizes"]
|
| 206 |
+
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
do_resize: bool = True,
|
| 210 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 211 |
+
do_rescale: bool = True,
|
| 212 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 213 |
+
do_normalize: bool = True,
|
| 214 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 215 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 216 |
+
do_convert_rgb: bool = True,
|
| 217 |
+
min_tokens: int = 4 * 4,
|
| 218 |
+
max_tokens: int = 16384,
|
| 219 |
+
patch_size: int = 14,
|
| 220 |
+
**kwargs,
|
| 221 |
+
) -> None:
|
| 222 |
+
super().__init__(**kwargs)
|
| 223 |
+
self.do_resize = do_resize
|
| 224 |
+
self.resample = resample
|
| 225 |
+
self.do_rescale = do_rescale
|
| 226 |
+
self.rescale_factor = rescale_factor
|
| 227 |
+
self.do_normalize = do_normalize
|
| 228 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
| 229 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
| 230 |
+
self.min_tokens = min_tokens
|
| 231 |
+
self.max_tokens = max_tokens
|
| 232 |
+
self.patch_size = patch_size
|
| 233 |
+
self.do_convert_rgb = do_convert_rgb
|
| 234 |
+
|
| 235 |
+
def _preprocess(
|
| 236 |
+
self,
|
| 237 |
+
images: Union[ImageInput, VideoInput],
|
| 238 |
+
target_size: List[int],
|
| 239 |
+
merge_size: int = 1,
|
| 240 |
+
do_resize: bool = None,
|
| 241 |
+
resample: PILImageResampling = None,
|
| 242 |
+
do_rescale: bool = None,
|
| 243 |
+
rescale_factor: float = None,
|
| 244 |
+
do_normalize: bool = None,
|
| 245 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 246 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 247 |
+
do_convert_rgb: bool = None,
|
| 248 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 249 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 250 |
+
):
|
| 251 |
+
"""
|
| 252 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
images (`ImageInput`):
|
| 256 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
| 257 |
+
target_size (`List[int]`):
|
| 258 |
+
The target size to resize the image to. Should be a list of two integers: [target_height, target_width].
|
| 259 |
+
merge_size (`int`, *optional*, defaults to `1`):
|
| 260 |
+
The merge size after the vision encoder.
|
| 261 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 262 |
+
Whether to resize the image.
|
| 263 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 264 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
| 265 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 266 |
+
Whether to rescale the image.
|
| 267 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 268 |
+
Scale factor to use if rescaling the image.
|
| 269 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 270 |
+
Whether to normalize the image.
|
| 271 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 272 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 273 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 274 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
| 275 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 276 |
+
Whether to convert the image to RGB.
|
| 277 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 278 |
+
The channel dimension format for the output image. Can be one of:
|
| 279 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 280 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 281 |
+
- Unset: Use the channel dimension format of the input image.
|
| 282 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 283 |
+
The channel dimension format for the input image. Can be one of:
|
| 284 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 285 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 286 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 287 |
+
"""
|
| 288 |
+
images = make_list_of_images(images)
|
| 289 |
+
|
| 290 |
+
if do_convert_rgb:
|
| 291 |
+
images = [convert_to_rgb(image) for image in images]
|
| 292 |
+
|
| 293 |
+
# All transformations expect numpy arrays.
|
| 294 |
+
images = [to_numpy_array(image) for image in images]
|
| 295 |
+
|
| 296 |
+
if is_scaled_image(images[0]) and do_rescale:
|
| 297 |
+
logger.warning_once(
|
| 298 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 299 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 300 |
+
)
|
| 301 |
+
if input_data_format is None:
|
| 302 |
+
# We assume that all images have the same channel dimension format.
|
| 303 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 304 |
+
|
| 305 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
| 306 |
+
resized_height, resized_width = height, width
|
| 307 |
+
processed_images = []
|
| 308 |
+
for image in images:
|
| 309 |
+
if do_resize:
|
| 310 |
+
resized_height, resized_width = target_size
|
| 311 |
+
image = resize(
|
| 312 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
if do_rescale:
|
| 316 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
| 317 |
+
|
| 318 |
+
if do_normalize:
|
| 319 |
+
image = self.normalize(
|
| 320 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 324 |
+
processed_images.append(image)
|
| 325 |
+
|
| 326 |
+
patches = np.array(processed_images)
|
| 327 |
+
if data_format == ChannelDimension.LAST:
|
| 328 |
+
patches = patches.transpose(0, 3, 1, 2)
|
| 329 |
+
t = patches.shape[0]
|
| 330 |
+
channel = patches.shape[1]
|
| 331 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
| 332 |
+
patches = patches.reshape(
|
| 333 |
+
t,
|
| 334 |
+
channel,
|
| 335 |
+
grid_h // merge_size,
|
| 336 |
+
merge_size,
|
| 337 |
+
self.patch_size,
|
| 338 |
+
grid_w // merge_size,
|
| 339 |
+
merge_size,
|
| 340 |
+
self.patch_size,
|
| 341 |
+
)
|
| 342 |
+
patches = patches.transpose(0, 2, 5, 3, 6, 1, 4, 7)
|
| 343 |
+
flatten_patches = patches.reshape(
|
| 344 |
+
t * grid_h * grid_w, channel * self.patch_size * self.patch_size
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
return flatten_patches, (t, grid_h, grid_w)
|
| 348 |
+
|
| 349 |
+
def preprocess(
|
| 350 |
+
self,
|
| 351 |
+
images: ImageInput,
|
| 352 |
+
do_resize: bool = None,
|
| 353 |
+
resample: PILImageResampling = None,
|
| 354 |
+
do_rescale: bool = None,
|
| 355 |
+
rescale_factor: float = None,
|
| 356 |
+
do_normalize: bool = None,
|
| 357 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 358 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 359 |
+
do_convert_rgb: bool = None,
|
| 360 |
+
merge_size: Optional[Union[int, List[int]]] = None,
|
| 361 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 362 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 363 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 364 |
+
):
|
| 365 |
+
"""
|
| 366 |
+
Args:
|
| 367 |
+
images (`ImageInput`):
|
| 368 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 369 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 370 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 371 |
+
Whether to resize the image.
|
| 372 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 373 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 374 |
+
has an effect if `do_resize` is set to `True`.
|
| 375 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 376 |
+
Whether to rescale the image.
|
| 377 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 378 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 379 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 380 |
+
Whether to normalize the image.
|
| 381 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 382 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
| 383 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 384 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
| 385 |
+
`True`.
|
| 386 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 387 |
+
Whether to convert the image to RGB.
|
| 388 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 389 |
+
The type of tensors to return. Can be one of:
|
| 390 |
+
- Unset: Return a list of `np.ndarray`.
|
| 391 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 392 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 393 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 394 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 395 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 396 |
+
The channel dimension format for the output image. Can be one of:
|
| 397 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 398 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 399 |
+
- Unset: Use the channel dimension format of the input image.
|
| 400 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 401 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 402 |
+
from the input image. Can be one of:
|
| 403 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 404 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 405 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 406 |
+
|
| 407 |
+
"""
|
| 408 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 409 |
+
resample = resample if resample is not None else self.resample
|
| 410 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 411 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 412 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 413 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 414 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 415 |
+
merge_size = merge_size if merge_size is not None else self.merge_size
|
| 416 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 417 |
+
|
| 418 |
+
images = make_batched_images(images)
|
| 419 |
+
|
| 420 |
+
if isinstance(merge_size, (list, tuple)):
|
| 421 |
+
assert len(merge_size) == len(images), "Merge size must be the same length as images."
|
| 422 |
+
merge_sizes = merge_size
|
| 423 |
+
else:
|
| 424 |
+
merge_sizes = [merge_size for _ in images]
|
| 425 |
+
|
| 426 |
+
if all(merge_size == merge_sizes[0] for merge_size in merge_sizes):
|
| 427 |
+
target_sizes = simple_batched_resize(
|
| 428 |
+
images,
|
| 429 |
+
factor=self.patch_size * merge_sizes[0],
|
| 430 |
+
min_tokens=self.min_tokens,
|
| 431 |
+
max_tokens=self.max_tokens,
|
| 432 |
+
input_data_format=input_data_format,
|
| 433 |
+
)
|
| 434 |
+
else:
|
| 435 |
+
target_sizes = batched_resize(
|
| 436 |
+
images,
|
| 437 |
+
factors=[self.patch_size * merge_size for merge_size in merge_sizes],
|
| 438 |
+
min_tokens=self.min_tokens,
|
| 439 |
+
max_tokens=self.max_tokens,
|
| 440 |
+
input_data_format=input_data_format,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
pixel_values, grid_sizes = [], []
|
| 444 |
+
for image, merge_size, target_size in zip(images, merge_sizes, target_sizes):
|
| 445 |
+
patches, grid_size = self._preprocess(
|
| 446 |
+
image,
|
| 447 |
+
target_size=target_size,
|
| 448 |
+
merge_size=merge_size,
|
| 449 |
+
do_resize=do_resize,
|
| 450 |
+
resample=resample,
|
| 451 |
+
do_rescale=do_rescale,
|
| 452 |
+
rescale_factor=rescale_factor,
|
| 453 |
+
do_normalize=do_normalize,
|
| 454 |
+
image_mean=image_mean,
|
| 455 |
+
image_std=image_std,
|
| 456 |
+
data_format=data_format,
|
| 457 |
+
do_convert_rgb=do_convert_rgb,
|
| 458 |
+
input_data_format=input_data_format,
|
| 459 |
+
)
|
| 460 |
+
pixel_values.append(patches)
|
| 461 |
+
grid_sizes.append(grid_size)
|
| 462 |
+
|
| 463 |
+
pixel_values = np.concatenate(pixel_values, axis=0)
|
| 464 |
+
grid_sizes = np.array(grid_sizes)
|
| 465 |
+
merge_sizes = np.array(merge_sizes)
|
| 466 |
+
|
| 467 |
+
data = {
|
| 468 |
+
"pixel_values": pixel_values,
|
| 469 |
+
"grid_sizes": grid_sizes,
|
| 470 |
+
"merge_sizes": merge_sizes,
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
RynnEC/rynnec/model/videollama3_encoder/modeling_videollama3_encoder.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py.
|
| 2 |
+
# Below is the original copyright:
|
| 3 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
"""PyTorch VideoLLaMA3 vision encoder model."""
|
| 22 |
+
|
| 23 |
+
import importlib.util
|
| 24 |
+
import os.path as osp
|
| 25 |
+
import math
|
| 26 |
+
import warnings
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
import torch.utils.checkpoint
|
| 32 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
| 33 |
+
|
| 34 |
+
from transformers.activations import ACT2FN
|
| 35 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 36 |
+
from transformers.utils import is_flash_attn_2_available
|
| 37 |
+
|
| 38 |
+
if is_flash_attn_2_available():
|
| 39 |
+
from flash_attn import flash_attn_varlen_func
|
| 40 |
+
else:
|
| 41 |
+
flash_attn_varlen_func = None
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
from .configuration_videollama3_encoder import Videollama3VisionEncoderConfig
|
| 45 |
+
except ImportError:
|
| 46 |
+
spec = importlib.util.spec_from_file_location(
|
| 47 |
+
"configuration_videollama3_encoder",
|
| 48 |
+
osp.join(osp.dirname(__file__), "configuration_videollama3_encoder.py"),
|
| 49 |
+
)
|
| 50 |
+
configuration_videollama3_encoder = importlib.util.module_from_spec(spec)
|
| 51 |
+
spec.loader.exec_module(configuration_videollama3_encoder)
|
| 52 |
+
Videollama3VisionEncoderConfig = getattr(
|
| 53 |
+
configuration_videollama3_encoder,
|
| 54 |
+
"Videollama3VisionEncoderConfig",
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
LayerNorm = nn.LayerNorm
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _trunc_normal_(tensor, mean, std, a, b):
|
| 61 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 62 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 63 |
+
def norm_cdf(x):
|
| 64 |
+
# Computes standard normal cumulative distribution function
|
| 65 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
| 66 |
+
|
| 67 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 68 |
+
warnings.warn(
|
| 69 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 70 |
+
"The distribution of values may be incorrect.",
|
| 71 |
+
stacklevel=2,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Values are generated by using a truncated uniform distribution and
|
| 75 |
+
# then using the inverse CDF for the normal distribution.
|
| 76 |
+
# Get upper and lower cdf values
|
| 77 |
+
l = norm_cdf((a - mean) / std)
|
| 78 |
+
u = norm_cdf((b - mean) / std)
|
| 79 |
+
|
| 80 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 81 |
+
# [2l-1, 2u-1].
|
| 82 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 83 |
+
|
| 84 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 85 |
+
# standard normal
|
| 86 |
+
tensor.erfinv_()
|
| 87 |
+
|
| 88 |
+
# Transform to proper mean, std
|
| 89 |
+
tensor.mul_(std * math.sqrt(2.0))
|
| 90 |
+
tensor.add_(mean)
|
| 91 |
+
|
| 92 |
+
# Clamp to ensure it's in the proper range
|
| 93 |
+
tensor.clamp_(min=a, max=b)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def trunc_normal_tf_(
|
| 97 |
+
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
"""Fills the input Tensor with values drawn from a truncated
|
| 100 |
+
normal distribution. The values are effectively drawn from the
|
| 101 |
+
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
|
| 102 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
| 103 |
+
the bounds. The method used for generating the random values works
|
| 104 |
+
best when :math:`a \\leq \text{mean} \\leq b`.
|
| 105 |
+
|
| 106 |
+
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
| 107 |
+
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
| 108 |
+
and the result is subsequently scaled and shifted by the mean and std args.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 112 |
+
mean: the mean of the normal distribution
|
| 113 |
+
std: the standard deviation of the normal distribution
|
| 114 |
+
a: the minimum cutoff value
|
| 115 |
+
b: the maximum cutoff value
|
| 116 |
+
"""
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
_trunc_normal_(tensor, 0, 1.0, a, b)
|
| 119 |
+
tensor.mul_(std).add_(mean)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
| 123 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
| 124 |
+
if mode == "fan_in":
|
| 125 |
+
denom = fan_in
|
| 126 |
+
elif mode == "fan_out":
|
| 127 |
+
denom = fan_out
|
| 128 |
+
elif mode == "fan_avg":
|
| 129 |
+
denom = (fan_in + fan_out) / 2
|
| 130 |
+
|
| 131 |
+
variance = scale / denom
|
| 132 |
+
|
| 133 |
+
if distribution == "truncated_normal":
|
| 134 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
| 135 |
+
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
| 136 |
+
elif distribution == "normal":
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
tensor.normal_(std=math.sqrt(variance))
|
| 139 |
+
elif distribution == "uniform":
|
| 140 |
+
bound = math.sqrt(3 * variance)
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
tensor.uniform_(-bound, bound)
|
| 143 |
+
else:
|
| 144 |
+
raise ValueError(f"invalid distribution {distribution}")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def lecun_normal_(tensor):
|
| 148 |
+
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def default_flax_embed_init(tensor):
|
| 152 |
+
variance_scaling_(tensor, mode="fan_in", distribution="normal")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 156 |
+
def rotate_half(x):
|
| 157 |
+
"""Rotates half the hidden dims of the input."""
|
| 158 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 159 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 160 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
| 164 |
+
# orig_dtype = tensor.dtype
|
| 165 |
+
# tensor = tensor.float()
|
| 166 |
+
# cos = freqs.cos()
|
| 167 |
+
# sin = freqs.sin()
|
| 168 |
+
# cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 169 |
+
# sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 170 |
+
# output = (tensor * cos) + (rotate_half(tensor) * sin)
|
| 171 |
+
# output = output.to(orig_dtype)
|
| 172 |
+
# return output
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def apply_rotary_pos_emb_vision(q, k, cos, sin) -> torch.Tensor:
|
| 176 |
+
orig_dtype = q.dtype
|
| 177 |
+
q, k = q.float(), k.float()
|
| 178 |
+
cos = cos.unsqueeze(1).float()
|
| 179 |
+
sin = sin.unsqueeze(1).float()
|
| 180 |
+
q = (q * cos) + (rotate_half(q) * sin)
|
| 181 |
+
k = (k * cos) + (rotate_half(k) * sin)
|
| 182 |
+
return q.to(orig_dtype), k.to(orig_dtype)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 186 |
+
|
| 187 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 188 |
+
super().__init__()
|
| 189 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 190 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 191 |
+
|
| 192 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
| 193 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 194 |
+
freqs = torch.outer(seq, self.inv_freq)
|
| 195 |
+
return freqs
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class Videollama3VisionEmbeddings(nn.Module):
|
| 199 |
+
|
| 200 |
+
def __init__(self, config: Videollama3VisionEncoderConfig):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.config = config
|
| 203 |
+
self.embed_dim = config.hidden_size
|
| 204 |
+
self.patch_size = config.patch_size
|
| 205 |
+
|
| 206 |
+
self.patch_embedding = nn.Conv2d(
|
| 207 |
+
in_channels=config.num_channels,
|
| 208 |
+
out_channels=self.embed_dim,
|
| 209 |
+
kernel_size=self.patch_size,
|
| 210 |
+
stride=self.patch_size,
|
| 211 |
+
padding="valid",
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 215 |
+
hidden_states = hidden_states.view(
|
| 216 |
+
-1, self.config.num_channels, self.patch_size, self.patch_size
|
| 217 |
+
)
|
| 218 |
+
patch_embeds = self.patch_embedding(hidden_states) # shape = [*, width, grid, grid]
|
| 219 |
+
# embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
| 220 |
+
embeddings = patch_embeds.view(-1, self.embed_dim)
|
| 221 |
+
|
| 222 |
+
return embeddings
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class VisionAttention(nn.Module):
|
| 226 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 227 |
+
|
| 228 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
| 229 |
+
def __init__(self, config):
|
| 230 |
+
super().__init__()
|
| 231 |
+
self.config = config
|
| 232 |
+
self.embed_dim = config.hidden_size
|
| 233 |
+
self.num_heads = config.num_attention_heads
|
| 234 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 235 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 236 |
+
raise ValueError(
|
| 237 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
| 238 |
+
f" {self.num_heads})."
|
| 239 |
+
)
|
| 240 |
+
self.scale = self.head_dim**-0.5
|
| 241 |
+
self.dropout = config.attention_dropout
|
| 242 |
+
|
| 243 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 244 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 245 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 246 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
| 247 |
+
|
| 248 |
+
def forward(
|
| 249 |
+
self,
|
| 250 |
+
hidden_states: torch.Tensor,
|
| 251 |
+
cu_seqlens: torch.Tensor,
|
| 252 |
+
rotary_pos_emb: torch.Tensor = None,
|
| 253 |
+
) -> torch.Tensor:
|
| 254 |
+
"""Input shape: Time x Channel"""
|
| 255 |
+
|
| 256 |
+
q_len, _ = hidden_states.size()
|
| 257 |
+
|
| 258 |
+
query_states = self.q_proj(hidden_states)
|
| 259 |
+
key_states = self.k_proj(hidden_states)
|
| 260 |
+
value_states = self.v_proj(hidden_states)
|
| 261 |
+
|
| 262 |
+
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
|
| 263 |
+
key_states = key_states.view(q_len, self.num_heads, self.head_dim)
|
| 264 |
+
value_states = value_states.view(q_len, self.num_heads, self.head_dim)
|
| 265 |
+
|
| 266 |
+
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 267 |
+
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 268 |
+
|
| 269 |
+
attention_mask = torch.zeros([1, q_len, q_len], device=query_states.device, dtype=torch.bool)
|
| 270 |
+
for i in range(1, len(cu_seqlens)):
|
| 271 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
| 272 |
+
|
| 273 |
+
query_states = query_states.transpose(0, 1)
|
| 274 |
+
key_states = key_states.transpose(0, 1)
|
| 275 |
+
value_states = value_states.transpose(0, 1)
|
| 276 |
+
|
| 277 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
|
| 278 |
+
attn_weights = attn_weights + attention_mask
|
| 279 |
+
|
| 280 |
+
# upcast attention to fp32
|
| 281 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 282 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 283 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 284 |
+
|
| 285 |
+
attn_output = attn_output.transpose(0, 1)
|
| 286 |
+
attn_output = attn_output.reshape(q_len, -1)
|
| 287 |
+
attn_output = self.out_proj(attn_output)
|
| 288 |
+
|
| 289 |
+
return attn_output
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class VisionFlashAttention2(VisionAttention):
|
| 293 |
+
|
| 294 |
+
def __init__(self, *args, **kwargs):
|
| 295 |
+
super().__init__(*args, **kwargs)
|
| 296 |
+
|
| 297 |
+
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
|
| 298 |
+
def forward(
|
| 299 |
+
self,
|
| 300 |
+
hidden_states: torch.Tensor,
|
| 301 |
+
cu_seqlens: torch.Tensor,
|
| 302 |
+
rotary_pos_emb: torch.Tensor = None,
|
| 303 |
+
) -> torch.Tensor:
|
| 304 |
+
q_len, _ = hidden_states.size()
|
| 305 |
+
|
| 306 |
+
query_states = self.q_proj(hidden_states)
|
| 307 |
+
key_states = self.k_proj(hidden_states)
|
| 308 |
+
value_states = self.v_proj(hidden_states)
|
| 309 |
+
|
| 310 |
+
# Flash attention requires the input to have the shape
|
| 311 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 312 |
+
# therefore we just need to keep the original shape
|
| 313 |
+
query_states = query_states.view(1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 314 |
+
key_states = key_states.view(1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 315 |
+
value_states = value_states.view(q_len, self.num_heads, self.head_dim)
|
| 316 |
+
# query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 317 |
+
# key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 318 |
+
query_states, key_states = apply_rotary_pos_emb_vision(
|
| 319 |
+
query_states,
|
| 320 |
+
key_states,
|
| 321 |
+
rotary_pos_emb.cos().unsqueeze(0).repeat(1, 1, 2),
|
| 322 |
+
rotary_pos_emb.sin().unsqueeze(0).repeat(1, 1, 2),
|
| 323 |
+
)
|
| 324 |
+
query_states = query_states.transpose(1, 2).squeeze(0)
|
| 325 |
+
key_states = key_states.transpose(1, 2).squeeze(0)
|
| 326 |
+
|
| 327 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 328 |
+
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 329 |
+
q_len, -1
|
| 330 |
+
)
|
| 331 |
+
attn_output = self.out_proj(attn_output)
|
| 332 |
+
|
| 333 |
+
return attn_output
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class VisionSdpaAttention(VisionAttention):
|
| 337 |
+
|
| 338 |
+
def forward(
|
| 339 |
+
self,
|
| 340 |
+
hidden_states: torch.Tensor,
|
| 341 |
+
cu_seqlens: torch.Tensor,
|
| 342 |
+
rotary_pos_emb: torch.Tensor = None,
|
| 343 |
+
) -> torch.Tensor:
|
| 344 |
+
seq_length = hidden_states.shape[0]
|
| 345 |
+
query_states = self.q_proj(hidden_states)
|
| 346 |
+
key_states = self.k_proj(hidden_states)
|
| 347 |
+
value_states = self.v_proj(hidden_states)
|
| 348 |
+
|
| 349 |
+
query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
|
| 350 |
+
key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
|
| 351 |
+
value_states = value_states.view(seq_length, self.num_heads, self.head_dim)
|
| 352 |
+
|
| 353 |
+
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 354 |
+
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 355 |
+
|
| 356 |
+
attention_mask = torch.zeros([1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
|
| 357 |
+
for i in range(1, len(cu_seqlens)):
|
| 358 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
| 359 |
+
|
| 360 |
+
query_states = query_states.transpose(0, 1)
|
| 361 |
+
key_states = key_states.transpose(0, 1)
|
| 362 |
+
value_states = value_states.transpose(0, 1)
|
| 363 |
+
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
|
| 364 |
+
attn_output = attn_output.transpose(0, 1)
|
| 365 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 366 |
+
attn_output = self.out_proj(attn_output)
|
| 367 |
+
return attn_output
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
VISION_ATTENTION_CLASSES = {
|
| 371 |
+
"eager": VisionAttention,
|
| 372 |
+
"flash_attention_2": VisionFlashAttention2,
|
| 373 |
+
"sdpa": VisionSdpaAttention,
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Videollama3
|
| 378 |
+
class Videollama3VisionMLP(nn.Module):
|
| 379 |
+
|
| 380 |
+
def __init__(self, config):
|
| 381 |
+
super().__init__()
|
| 382 |
+
self.config = config
|
| 383 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 384 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 385 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 386 |
+
|
| 387 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 388 |
+
hidden_states = self.fc1(hidden_states)
|
| 389 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 390 |
+
hidden_states = self.fc2(hidden_states)
|
| 391 |
+
return hidden_states
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class Videollama3VisionEncoderLayer(nn.Module):
|
| 395 |
+
|
| 396 |
+
def __init__(self, config: Videollama3VisionEncoderConfig):
|
| 397 |
+
super().__init__()
|
| 398 |
+
self.embed_dim = config.hidden_size
|
| 399 |
+
self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
|
| 400 |
+
self.layer_norm1 = LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 401 |
+
self.mlp = Videollama3VisionMLP(config)
|
| 402 |
+
self.layer_norm2 = LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 403 |
+
|
| 404 |
+
# Ignore copy
|
| 405 |
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
| 406 |
+
hidden_states = hidden_states + self.self_attn(
|
| 407 |
+
self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
| 408 |
+
)
|
| 409 |
+
hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
|
| 410 |
+
return hidden_states
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class Videollama3VisionTransformerEncoder(nn.Module):
|
| 414 |
+
|
| 415 |
+
def __init__(self, config: Videollama3VisionEncoderConfig):
|
| 416 |
+
super().__init__()
|
| 417 |
+
self.config = config
|
| 418 |
+
head_dim = config.hidden_size // config.num_attention_heads
|
| 419 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 420 |
+
self.layers = nn.ModuleList([Videollama3VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 421 |
+
self.gradient_checkpointing = False
|
| 422 |
+
|
| 423 |
+
def rot_pos_emb(self, grid_sizes, merge_sizes):
|
| 424 |
+
pos_ids = []
|
| 425 |
+
for (t, h, w), merge_size in zip(grid_sizes, merge_sizes):
|
| 426 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 427 |
+
hpos_ids = hpos_ids.reshape(
|
| 428 |
+
h // merge_size,
|
| 429 |
+
merge_size,
|
| 430 |
+
w // merge_size,
|
| 431 |
+
merge_size,
|
| 432 |
+
)
|
| 433 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
| 434 |
+
hpos_ids = hpos_ids.flatten()
|
| 435 |
+
|
| 436 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 437 |
+
wpos_ids = wpos_ids.reshape(
|
| 438 |
+
h // merge_size,
|
| 439 |
+
merge_size,
|
| 440 |
+
w // merge_size,
|
| 441 |
+
merge_size,
|
| 442 |
+
)
|
| 443 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
| 444 |
+
wpos_ids = wpos_ids.flatten()
|
| 445 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 446 |
+
|
| 447 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 448 |
+
max_grid_size = grid_sizes[:, 1:].max()
|
| 449 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 450 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
| 451 |
+
|
| 452 |
+
return rotary_pos_emb
|
| 453 |
+
|
| 454 |
+
def forward(self, hidden_states, grid_sizes, merge_sizes) -> torch.Tensor:
|
| 455 |
+
rotary_pos_emb = self.rot_pos_emb(grid_sizes, merge_sizes)
|
| 456 |
+
|
| 457 |
+
cu_seqlens = torch.repeat_interleave(grid_sizes[:, 1] * grid_sizes[:, 2], grid_sizes[:, 0]).cumsum(dim=0, dtype=torch.int32)
|
| 458 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 459 |
+
|
| 460 |
+
for blk in self.layers:
|
| 461 |
+
if self.gradient_checkpointing and self.training:
|
| 462 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 463 |
+
blk.__call__,
|
| 464 |
+
hidden_states,
|
| 465 |
+
cu_seqlens,
|
| 466 |
+
rotary_pos_emb
|
| 467 |
+
)
|
| 468 |
+
else:
|
| 469 |
+
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
| 470 |
+
|
| 471 |
+
return hidden_states
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
class Videollama3VisionEncoderModel(PreTrainedModel):
|
| 475 |
+
|
| 476 |
+
config_class = Videollama3VisionEncoderConfig
|
| 477 |
+
base_model_prefix = "videollama3"
|
| 478 |
+
main_input_name = "pixel_values"
|
| 479 |
+
supports_gradient_checkpointing = True
|
| 480 |
+
_no_split_modules = [
|
| 481 |
+
"Videollama3VisionEncoderLayer",
|
| 482 |
+
"Videollama3VisionEmbeddings",
|
| 483 |
+
]
|
| 484 |
+
_supports_flash_attn_2 = True
|
| 485 |
+
_supports_sdpa = True
|
| 486 |
+
|
| 487 |
+
def __init__(self, config: Videollama3VisionEncoderConfig):
|
| 488 |
+
super().__init__(config=config)
|
| 489 |
+
embed_dim = config.hidden_size
|
| 490 |
+
|
| 491 |
+
self.embeddings = Videollama3VisionEmbeddings(config)
|
| 492 |
+
self.encoder = Videollama3VisionTransformerEncoder(config)
|
| 493 |
+
self.post_layernorm = LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 494 |
+
|
| 495 |
+
self.post_init()
|
| 496 |
+
|
| 497 |
+
def forward(self, pixel_values, grid_sizes, merge_sizes=None):
|
| 498 |
+
hidden_states = self.embeddings(pixel_values)
|
| 499 |
+
hidden_states = self.encoder(hidden_states, grid_sizes, merge_sizes)
|
| 500 |
+
hidden_states = self.post_layernorm(hidden_states)
|
| 501 |
+
hidden_states_raw = hidden_states.clone()
|
| 502 |
+
|
| 503 |
+
hidden_states_chunks = hidden_states.split(grid_sizes.prod(dim=1).tolist(), dim=0)
|
| 504 |
+
outputs = []
|
| 505 |
+
|
| 506 |
+
for hidden_states, grid_size, merge_size in zip(hidden_states_chunks, grid_sizes, merge_sizes):
|
| 507 |
+
# NOTE: previous implementation, which supports downsampling with any factor
|
| 508 |
+
c = hidden_states.shape[-1]
|
| 509 |
+
hidden_states = hidden_states.view(
|
| 510 |
+
grid_size[0], grid_size[1] // merge_size, grid_size[2] // merge_size, merge_size, merge_size, c
|
| 511 |
+
).permute(0, 1, 3, 2, 4, 5)
|
| 512 |
+
hidden_states = hidden_states.reshape(
|
| 513 |
+
grid_size[0], grid_size[1], grid_size[2], c
|
| 514 |
+
).permute(0, 3, 1, 2)
|
| 515 |
+
hidden_states = torch.nn.functional.interpolate(
|
| 516 |
+
hidden_states,
|
| 517 |
+
size=(grid_size[1] // merge_size, grid_size[2] // merge_size),
|
| 518 |
+
mode='bilinear'
|
| 519 |
+
)
|
| 520 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).view(-1, c)
|
| 521 |
+
|
| 522 |
+
# NOTE: simplified implementation, which only supports downsampling with integer factor
|
| 523 |
+
# NOTE: this implementation is mathematically equivalent to the previous one when merge_size is 1 or 2 but may cause slightly different results
|
| 524 |
+
# hidden_states = hidden_states.view(-1, merge_size * merge_size, hidden_states.size(-1))
|
| 525 |
+
# hidden_states = hidden_states.mean(dim=1)
|
| 526 |
+
|
| 527 |
+
outputs.append(hidden_states)
|
| 528 |
+
|
| 529 |
+
return torch.cat(outputs, dim=0), hidden_states_raw
|
| 530 |
+
|
| 531 |
+
def _init_weights(self, module):
|
| 532 |
+
"""Initialize the weights"""
|
| 533 |
+
if isinstance(module, nn.Embedding):
|
| 534 |
+
default_flax_embed_init(module.weight)
|
| 535 |
+
elif isinstance(module, VisionAttention):
|
| 536 |
+
nn.init.xavier_uniform_(module.q_proj.weight)
|
| 537 |
+
nn.init.xavier_uniform_(module.k_proj.weight)
|
| 538 |
+
nn.init.xavier_uniform_(module.v_proj.weight)
|
| 539 |
+
nn.init.xavier_uniform_(module.out_proj.weight)
|
| 540 |
+
nn.init.zeros_(module.q_proj.bias)
|
| 541 |
+
nn.init.zeros_(module.k_proj.bias)
|
| 542 |
+
nn.init.zeros_(module.v_proj.bias)
|
| 543 |
+
nn.init.zeros_(module.out_proj.bias)
|
| 544 |
+
elif isinstance(module, Videollama3VisionMLP):
|
| 545 |
+
nn.init.xavier_uniform_(module.fc1.weight)
|
| 546 |
+
nn.init.xavier_uniform_(module.fc2.weight)
|
| 547 |
+
nn.init.normal_(module.fc1.bias, std=1e-6)
|
| 548 |
+
nn.init.normal_(module.fc2.bias, std=1e-6)
|
| 549 |
+
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 550 |
+
lecun_normal_(module.weight)
|
| 551 |
+
if module.bias is not None:
|
| 552 |
+
nn.init.zeros_(module.bias)
|
| 553 |
+
elif isinstance(module, LayerNorm):
|
| 554 |
+
module.bias.data.zero_()
|
| 555 |
+
module.weight.data.fill_(1.0)
|
RynnEC/rynnec/rynnec_trainer.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import Sampler
|
| 9 |
+
|
| 10 |
+
from transformers import Trainer
|
| 11 |
+
from transformers.trainer import (
|
| 12 |
+
is_sagemaker_mp_enabled,
|
| 13 |
+
get_parameter_names,
|
| 14 |
+
has_length,
|
| 15 |
+
ALL_LAYERNORM_LAYERS,
|
| 16 |
+
logger,
|
| 17 |
+
TRAINER_STATE_NAME,
|
| 18 |
+
)
|
| 19 |
+
from transformers.utils import (
|
| 20 |
+
is_sagemaker_mp_enabled,
|
| 21 |
+
logging,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
| 25 |
+
from deepspeed import zero
|
| 26 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
| 27 |
+
if hasattr(param, "ds_id"):
|
| 28 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
| 29 |
+
if not ignore_status:
|
| 30 |
+
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
|
| 31 |
+
with zero.GatheredParameters([param]):
|
| 32 |
+
param = param.data.detach().cpu().clone()
|
| 33 |
+
else:
|
| 34 |
+
param = param.detach().cpu().clone()
|
| 35 |
+
return param
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
| 39 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
| 40 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
|
| 41 |
+
return to_return
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Borrowed from peft.utils.get_peft_model_state_dict
|
| 45 |
+
def get_peft_state_maybe_zero_3(named_params, bias):
|
| 46 |
+
if bias == "none":
|
| 47 |
+
to_return = {k: t for k, t in named_params if "lora_" in k}
|
| 48 |
+
elif bias == "all":
|
| 49 |
+
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
|
| 50 |
+
elif bias == "lora_only":
|
| 51 |
+
to_return = {}
|
| 52 |
+
maybe_lora_bias = {}
|
| 53 |
+
lora_bias_names = set()
|
| 54 |
+
for k, t in named_params:
|
| 55 |
+
if "lora_" in k:
|
| 56 |
+
to_return[k] = t
|
| 57 |
+
bias_name = k.split("lora_")[0] + "bias"
|
| 58 |
+
lora_bias_names.add(bias_name)
|
| 59 |
+
elif "bias" in k:
|
| 60 |
+
maybe_lora_bias[k] = t
|
| 61 |
+
for k, t in maybe_lora_bias:
|
| 62 |
+
if bias_name in lora_bias_names:
|
| 63 |
+
to_return[bias_name] = t
|
| 64 |
+
else:
|
| 65 |
+
raise NotImplementedError
|
| 66 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
|
| 67 |
+
return to_return
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
|
| 71 |
+
to_return = {k: t for k, t in named_params if "lora_" not in k}
|
| 72 |
+
if require_grad_only:
|
| 73 |
+
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
|
| 74 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
| 75 |
+
return to_return
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def find_all_linear_names(model):
|
| 79 |
+
cls = torch.nn.Linear
|
| 80 |
+
lora_module_names = set()
|
| 81 |
+
multimodal_keywords = ['mm_projector', 'vision_encoder', 'vision_resampler', 'text_hidden_fcs', 'region_encoder', 'grounding_encoder']
|
| 82 |
+
for name, module in model.named_modules():
|
| 83 |
+
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
| 84 |
+
continue
|
| 85 |
+
if isinstance(module, cls):
|
| 86 |
+
if 'lm_head' in name:
|
| 87 |
+
continue
|
| 88 |
+
lora_module_names.add(name)
|
| 89 |
+
|
| 90 |
+
return list(lora_module_names)
|
| 91 |
+
|
| 92 |
+
def safe_save_model_for_hf_trainer(trainer: Trainer,
|
| 93 |
+
output_dir: str):
|
| 94 |
+
"""Collects the state dict and dump to disk."""
|
| 95 |
+
|
| 96 |
+
if getattr(trainer.args, "is_alignment", False):
|
| 97 |
+
# Only save Adapter
|
| 98 |
+
keys_to_match = ['mm_projector']
|
| 99 |
+
|
| 100 |
+
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
|
| 101 |
+
trainer.model.config.save_pretrained(output_dir)
|
| 102 |
+
|
| 103 |
+
current_folder = output_dir.split('/')[-1]
|
| 104 |
+
parent_folder = os.path.dirname(output_dir)
|
| 105 |
+
# if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
|
| 106 |
+
if torch.distributed.get_rank() == 0:
|
| 107 |
+
if current_folder.startswith('checkpoint-'):
|
| 108 |
+
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
|
| 109 |
+
os.makedirs(mm_projector_folder, exist_ok=True)
|
| 110 |
+
torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
|
| 111 |
+
else:
|
| 112 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
if trainer.deepspeed:
|
| 116 |
+
torch.cuda.synchronize()
|
| 117 |
+
trainer.save_model(output_dir)
|
| 118 |
+
return
|
| 119 |
+
|
| 120 |
+
state_dict = trainer.model.state_dict()
|
| 121 |
+
if trainer.args.should_save:
|
| 122 |
+
cpu_state_dict = {
|
| 123 |
+
key: value.cpu()
|
| 124 |
+
for key, value in state_dict.items()
|
| 125 |
+
}
|
| 126 |
+
del state_dict
|
| 127 |
+
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def split_to_even_chunks(indices, lengths, num_chunks):
|
| 131 |
+
"""
|
| 132 |
+
Split a list of indices into `chunks` chunks of roughly equal lengths.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
if len(indices) % num_chunks != 0:
|
| 136 |
+
return [indices[i::num_chunks] for i in range(num_chunks)]
|
| 137 |
+
|
| 138 |
+
num_indices_per_chunk = len(indices) // num_chunks
|
| 139 |
+
|
| 140 |
+
chunks = [[] for _ in range(num_chunks)]
|
| 141 |
+
chunks_lengths = [0 for _ in range(num_chunks)]
|
| 142 |
+
for index in indices:
|
| 143 |
+
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
|
| 144 |
+
chunks[shortest_chunk].append(index)
|
| 145 |
+
chunks_lengths[shortest_chunk] += lengths[index]
|
| 146 |
+
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
|
| 147 |
+
chunks_lengths[shortest_chunk] = float("inf")
|
| 148 |
+
|
| 149 |
+
return chunks
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
|
| 153 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
| 154 |
+
assert all(l != 0 for l in lengths), "Should not have zero length."
|
| 155 |
+
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
|
| 156 |
+
# all samples are in the same modality
|
| 157 |
+
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
|
| 158 |
+
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
|
| 159 |
+
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
|
| 160 |
+
|
| 161 |
+
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
|
| 162 |
+
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
|
| 163 |
+
megabatch_size = world_size * batch_size
|
| 164 |
+
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
|
| 165 |
+
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
|
| 166 |
+
|
| 167 |
+
# last_mm = mm_megabatches[-1]
|
| 168 |
+
# last_lang = lang_megabatches[-1]
|
| 169 |
+
# additional_batch = last_mm + last_lang
|
| 170 |
+
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
|
| 171 |
+
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
|
| 172 |
+
megabatches = [megabatches[i] for i in megabatch_indices]
|
| 173 |
+
|
| 174 |
+
# if len(additional_batch) > 0:
|
| 175 |
+
# megabatches.append(sorted(additional_batch))
|
| 176 |
+
|
| 177 |
+
return [i for megabatch in megabatches for i in megabatch]
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
|
| 181 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
| 182 |
+
indices = torch.randperm(len(lengths), generator=generator)
|
| 183 |
+
megabatch_size = world_size * batch_size
|
| 184 |
+
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
|
| 185 |
+
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
|
| 186 |
+
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
|
| 187 |
+
|
| 188 |
+
return [i for megabatch in megabatches for batch in megabatch for i in batch]
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class LengthGroupedSampler(Sampler):
|
| 192 |
+
r"""
|
| 193 |
+
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
|
| 194 |
+
keeping a bit of randomness.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(
|
| 198 |
+
self,
|
| 199 |
+
batch_size: int,
|
| 200 |
+
world_size: int,
|
| 201 |
+
lengths: Optional[List[int]] = None,
|
| 202 |
+
generator=None,
|
| 203 |
+
group_by_modality: bool = False,
|
| 204 |
+
):
|
| 205 |
+
if lengths is None:
|
| 206 |
+
raise ValueError("Lengths must be provided.")
|
| 207 |
+
|
| 208 |
+
self.batch_size = batch_size
|
| 209 |
+
self.world_size = world_size
|
| 210 |
+
self.lengths = lengths
|
| 211 |
+
self.generator = generator
|
| 212 |
+
self.group_by_modality = group_by_modality
|
| 213 |
+
|
| 214 |
+
def __len__(self):
|
| 215 |
+
return len(self.lengths)
|
| 216 |
+
|
| 217 |
+
def __iter__(self):
|
| 218 |
+
if self.group_by_modality:
|
| 219 |
+
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
| 220 |
+
else:
|
| 221 |
+
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
| 222 |
+
return iter(indices)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class RynnECTrainer(Trainer):
|
| 226 |
+
|
| 227 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
| 228 |
+
print('_get_train_sampler')
|
| 229 |
+
print('world size: ', self.args.world_size * self.args.gradient_accumulation_steps)
|
| 230 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
| 231 |
+
return None
|
| 232 |
+
print('group_by_modality_length...')
|
| 233 |
+
if self.args.group_by_modality_length:
|
| 234 |
+
lengths = self.train_dataset.modality_lengths
|
| 235 |
+
return LengthGroupedSampler(
|
| 236 |
+
self.args.train_batch_size,
|
| 237 |
+
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
|
| 238 |
+
lengths=lengths,
|
| 239 |
+
group_by_modality=True,
|
| 240 |
+
)
|
| 241 |
+
else:
|
| 242 |
+
return super()._get_train_sampler()
|
| 243 |
+
|
| 244 |
+
def update_history_loss_dict(self,outputs):
|
| 245 |
+
if not hasattr(self,'history_loss_dict'):
|
| 246 |
+
self.history_loss_dict = {}
|
| 247 |
+
for name, value in outputs.items():
|
| 248 |
+
if 'loss' in name and name != 'loss':
|
| 249 |
+
if name not in self.history_loss_dict:
|
| 250 |
+
self.history_loss_dict[name] = value.item()
|
| 251 |
+
else:
|
| 252 |
+
if value != 0:
|
| 253 |
+
self.history_loss_dict[name] = value.item()
|
| 254 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 255 |
+
"""
|
| 256 |
+
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
| 257 |
+
|
| 258 |
+
Subclass and override for custom behavior.
|
| 259 |
+
"""
|
| 260 |
+
if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
|
| 261 |
+
labels = inputs.pop("labels")
|
| 262 |
+
else:
|
| 263 |
+
labels = None
|
| 264 |
+
if self.model_accepts_loss_kwargs:
|
| 265 |
+
loss_kwargs = {}
|
| 266 |
+
if num_items_in_batch is not None:
|
| 267 |
+
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
| 268 |
+
inputs = {**inputs, **loss_kwargs}
|
| 269 |
+
outputs = model(**inputs)
|
| 270 |
+
# Save past state if it exists
|
| 271 |
+
# TODO: this needs to be fixed and made cleaner later.
|
| 272 |
+
if self.args.past_index >= 0:
|
| 273 |
+
self._past = outputs[self.args.past_index]
|
| 274 |
+
|
| 275 |
+
if labels is not None:
|
| 276 |
+
unwrapped_model = self.accelerator.unwrap_model(model)
|
| 277 |
+
if _is_peft_model(unwrapped_model):
|
| 278 |
+
model_name = unwrapped_model.base_model.model._get_name()
|
| 279 |
+
else:
|
| 280 |
+
model_name = unwrapped_model._get_name()
|
| 281 |
+
# User-defined compute_loss function
|
| 282 |
+
if self.compute_loss_func is not None:
|
| 283 |
+
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
|
| 284 |
+
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
| 285 |
+
loss = self.label_smoother(outputs, labels, shift_labels=True)
|
| 286 |
+
else:
|
| 287 |
+
loss = self.label_smoother(outputs, labels)
|
| 288 |
+
else:
|
| 289 |
+
if isinstance(outputs, dict) and "loss" not in outputs:
|
| 290 |
+
raise ValueError(
|
| 291 |
+
"The model did not return a loss from the inputs, only the following keys: "
|
| 292 |
+
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
|
| 293 |
+
)
|
| 294 |
+
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
| 295 |
+
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
| 296 |
+
if isinstance(outputs, dict) and 'mask_bce_loss' in outputs:
|
| 297 |
+
loss_dict = {}
|
| 298 |
+
for name,value in outputs.items():
|
| 299 |
+
if 'loss' in name and name != 'loss':
|
| 300 |
+
loss_value = value.item()
|
| 301 |
+
if loss_value == 0 and hasattr(self,'history_loss_dict'):
|
| 302 |
+
loss_value = self.history_loss_dict[name]
|
| 303 |
+
loss_dict[name] = loss_value
|
| 304 |
+
self.update_history_loss_dict(outputs)
|
| 305 |
+
self.log(loss_dict)
|
| 306 |
+
|
| 307 |
+
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
| 308 |
+
loss *= self.accelerator.num_processes
|
| 309 |
+
|
| 310 |
+
return (loss, outputs) if return_outputs else loss
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def create_optimizer(self):
|
| 314 |
+
"""
|
| 315 |
+
Setup the optimizer.
|
| 316 |
+
|
| 317 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
| 318 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
| 319 |
+
"""
|
| 320 |
+
if is_sagemaker_mp_enabled():
|
| 321 |
+
return super().create_optimizer()
|
| 322 |
+
|
| 323 |
+
opt_model = self.model
|
| 324 |
+
|
| 325 |
+
if self.optimizer is None:
|
| 326 |
+
optimized_parameters = [(n, p) for n, p in opt_model.named_parameters() if p.requires_grad]
|
| 327 |
+
optimizer_grouped_parameters = []
|
| 328 |
+
|
| 329 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
| 330 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
| 331 |
+
|
| 332 |
+
if self.args.llm_lr is not None:
|
| 333 |
+
lm_parameters = [
|
| 334 |
+
name for name, _ in optimized_parameters if "vision_encoder" not in name and "mm_projector" not in name and "region_encoder" not in name and "grounding_encoder" not in name
|
| 335 |
+
]
|
| 336 |
+
decay_lm_parameters = [name for name in lm_parameters if name in decay_parameters]
|
| 337 |
+
nodecay_lm_parameters = [name for name in lm_parameters if name not in decay_parameters]
|
| 338 |
+
optimizer_grouped_parameters.extend([
|
| 339 |
+
{
|
| 340 |
+
"params": [p for n, p in optimized_parameters if n in decay_lm_parameters],
|
| 341 |
+
"weight_decay": self.args.weight_decay,
|
| 342 |
+
"lr": self.args.llm_lr,
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"params": [p for n, p in optimized_parameters if n in nodecay_lm_parameters],
|
| 346 |
+
"weight_decay": 0.0,
|
| 347 |
+
"lr": self.args.llm_lr,
|
| 348 |
+
}
|
| 349 |
+
])
|
| 350 |
+
|
| 351 |
+
if self.args.mm_projector_lr is not None:
|
| 352 |
+
projector_parameters = [name for name, _ in optimized_parameters if "mm_projector" in name]
|
| 353 |
+
decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
|
| 354 |
+
nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
|
| 355 |
+
optimizer_grouped_parameters.extend([
|
| 356 |
+
{
|
| 357 |
+
"params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
|
| 358 |
+
"weight_decay": self.args.weight_decay,
|
| 359 |
+
"lr": self.args.mm_projector_lr,
|
| 360 |
+
},
|
| 361 |
+
{
|
| 362 |
+
"params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
|
| 363 |
+
"weight_decay": 0.0,
|
| 364 |
+
"lr": self.args.mm_projector_lr,
|
| 365 |
+
}
|
| 366 |
+
])
|
| 367 |
+
|
| 368 |
+
if self.args.vision_encoder_lr is not None:
|
| 369 |
+
vision_encoder_parameters = [name for name, _ in optimized_parameters if "vision_encoder" in name]
|
| 370 |
+
decay_vision_encoder_parameters = [name for name in vision_encoder_parameters if name in decay_parameters]
|
| 371 |
+
nodecay_vision_encoder_parameters = [name for name in vision_encoder_parameters if name not in decay_parameters]
|
| 372 |
+
optimizer_grouped_parameters.extend([
|
| 373 |
+
{
|
| 374 |
+
"params": [p for n, p in optimized_parameters if n in decay_vision_encoder_parameters],
|
| 375 |
+
"weight_decay": self.args.weight_decay,
|
| 376 |
+
"lr": self.args.vision_encoder_lr,
|
| 377 |
+
},
|
| 378 |
+
{
|
| 379 |
+
"params": [p for n, p in optimized_parameters if n in nodecay_vision_encoder_parameters],
|
| 380 |
+
"weight_decay": 0.0,
|
| 381 |
+
"lr": self.args.vision_encoder_lr,
|
| 382 |
+
}
|
| 383 |
+
])
|
| 384 |
+
|
| 385 |
+
if self.args.region_encoder_lr is not None:
|
| 386 |
+
projector_parameters = [name for name, _ in optimized_parameters if "region_encoder" in name]
|
| 387 |
+
decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
|
| 388 |
+
nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
|
| 389 |
+
optimizer_grouped_parameters.extend([
|
| 390 |
+
{
|
| 391 |
+
"params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
|
| 392 |
+
"weight_decay": self.args.weight_decay,
|
| 393 |
+
"lr": self.args.region_encoder_lr,
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
|
| 397 |
+
"weight_decay": 0.0,
|
| 398 |
+
"lr": self.args.region_encoder_lr,
|
| 399 |
+
}
|
| 400 |
+
])
|
| 401 |
+
if self.args.sam_decoder_lr is not None:
|
| 402 |
+
projector_parameters = [name for name, _ in optimized_parameters if "grounding_encoder" in name and "image_encoder" not in name]
|
| 403 |
+
decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
|
| 404 |
+
nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
|
| 405 |
+
optimizer_grouped_parameters.extend([
|
| 406 |
+
{
|
| 407 |
+
"params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
|
| 408 |
+
"weight_decay": self.args.weight_decay,
|
| 409 |
+
"lr": self.args.sam_decoder_lr,
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
|
| 413 |
+
"weight_decay": 0.0,
|
| 414 |
+
"lr": self.args.sam_decoder_lr,
|
| 415 |
+
}
|
| 416 |
+
])
|
| 417 |
+
|
| 418 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
| 419 |
+
|
| 420 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
| 421 |
+
if optimizer_cls.__name__ == "Adam8bit":
|
| 422 |
+
import bitsandbytes
|
| 423 |
+
|
| 424 |
+
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
| 425 |
+
|
| 426 |
+
skipped = 0
|
| 427 |
+
for module in opt_model.modules():
|
| 428 |
+
if isinstance(module, nn.Embedding):
|
| 429 |
+
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
| 430 |
+
logger.info(f"skipped {module}: {skipped/2**20}M params")
|
| 431 |
+
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
| 432 |
+
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
| 433 |
+
logger.info(f"skipped: {skipped/2**20}M params")
|
| 434 |
+
|
| 435 |
+
return self.optimizer
|
| 436 |
+
|
| 437 |
+
def _save_checkpoint(self, model, trial, metrics=None):
|
| 438 |
+
if getattr(self.args, 'is_alignment', False):
|
| 439 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 440 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
| 441 |
+
|
| 442 |
+
run_dir = self._get_output_dir(trial=trial)
|
| 443 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
| 444 |
+
|
| 445 |
+
# Only save Adapter
|
| 446 |
+
keys_to_match = ['mm_projector', 'vision_resampler']
|
| 447 |
+
|
| 448 |
+
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
|
| 449 |
+
|
| 450 |
+
if self.args.local_rank == 0 or self.args.local_rank == -1:
|
| 451 |
+
self.model.config.save_pretrained(output_dir)
|
| 452 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
|
| 453 |
+
# Save optimizer and scheduler
|
| 454 |
+
self._save_optimizer_and_scheduler(output_dir)
|
| 455 |
+
# Save RNG state
|
| 456 |
+
self._save_rng_state(output_dir)
|
| 457 |
+
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
| 458 |
+
self.args.distributed_state.wait_for_everyone()
|
| 459 |
+
else:
|
| 460 |
+
# NOTE: Supporting save complete lora checkpoint during training.
|
| 461 |
+
if self.args.lora_enable:
|
| 462 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 463 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
| 464 |
+
|
| 465 |
+
run_dir = self._get_output_dir(trial=trial)
|
| 466 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
| 467 |
+
|
| 468 |
+
state_dict = get_peft_state_maybe_zero_3(self.model.named_parameters(), self.args.lora_bias)
|
| 469 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters())
|
| 470 |
+
|
| 471 |
+
# add for qwen2
|
| 472 |
+
if hasattr(self.model, 'base_model') and hasattr(self.model.base_model, 'lm_head'):
|
| 473 |
+
lm_head_weight = self.model.base_model.lm_head.weight.cpu()
|
| 474 |
+
non_lora_state_dict['base_model.lm_head.weight'] = lm_head_weight
|
| 475 |
+
print("add base_model.lm_head.weight")
|
| 476 |
+
else:
|
| 477 |
+
print("The model does not have 'base_model.lm_head.weight' attribute.")
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
if self.args.local_rank == 0 or self.args.local_rank == -1:
|
| 481 |
+
# save for acquring `config.json`
|
| 482 |
+
self.model.config.save_pretrained(output_dir)
|
| 483 |
+
# save for acquring `adapter_config.json`, `adapter_model.bin`
|
| 484 |
+
# self.model.save_pretrained(output_dir, state_dict=state_dict)
|
| 485 |
+
torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))
|
| 486 |
+
|
| 487 |
+
# save for acquring lora adapter parameters & trainer states: `adapter_config.json`, `adapter_model.safetensors`
|
| 488 |
+
super(RynnECTrainer, self)._save_checkpoint(model, trial, metrics)
|
| 489 |
+
else:
|
| 490 |
+
super(RynnECTrainer, self)._save_checkpoint(model, trial, metrics)
|
| 491 |
+
|
| 492 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
| 493 |
+
if getattr(self.args, 'is_alignment', False):
|
| 494 |
+
pass
|
| 495 |
+
else:
|
| 496 |
+
super(RynnECTrainer, self)._save(output_dir, state_dict)
|
RynnEC/rynnec/train.py
ADDED
|
@@ -0,0 +1,832 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adopted from https://github.com/DAMO-NLP-SG/VideoLLaMA3. Below is the original copyright:
|
| 2 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
| 3 |
+
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
|
| 4 |
+
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
|
| 5 |
+
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import copy
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
import pathlib
|
| 24 |
+
import random
|
| 25 |
+
import re
|
| 26 |
+
import sys
|
| 27 |
+
import warnings
|
| 28 |
+
import traceback
|
| 29 |
+
from packaging import version
|
| 30 |
+
from dataclasses import dataclass, field
|
| 31 |
+
from typing import Dict, List, Optional, Sequence
|
| 32 |
+
import numpy as np
|
| 33 |
+
import pyarrow as pa
|
| 34 |
+
|
| 35 |
+
# torch-related packages
|
| 36 |
+
# NOTE: torch must be imported before transformers. Otherwise, `Segmentation fault (core dumped)` will occur.
|
| 37 |
+
import torch
|
| 38 |
+
import transformers
|
| 39 |
+
from packaging import version
|
| 40 |
+
import datasets
|
| 41 |
+
from datasets import load_dataset, concatenate_datasets
|
| 42 |
+
from torch.utils.data import Dataset
|
| 43 |
+
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
| 44 |
+
from transformers import logging
|
| 45 |
+
# logging.set_verbosity_error()
|
| 46 |
+
|
| 47 |
+
sys.path.append('./')
|
| 48 |
+
|
| 49 |
+
from rynnec.constants import (IGNORE_INDEX, MODAL_INDEX_MAP,
|
| 50 |
+
NUM_FRAMES, DEFAULT_IMAGE_TOKEN, STREAM_MAX_FRAMES,
|
| 51 |
+
STREAM_DOWNSAMPLING, STREAM_FPS, STREAM_IMAGE_SIZE,
|
| 52 |
+
STREAM_START_TOKEN, STREAM_END_TOKEN, REGION_TOKEN, SEG_TOKEN, REGION_TOKEN_REPLACE)
|
| 53 |
+
from rynnec.mm_utils import (load_images, load_video, DirectResize, load_video_from_ids,
|
| 54 |
+
tokenizer_multimodal_token, annToMask, sam_preprocess_batch)
|
| 55 |
+
from rynnec.model import *
|
| 56 |
+
from rynnec.rynnec_trainer import (
|
| 57 |
+
RynnECTrainer, find_all_linear_names, get_peft_state_maybe_zero_3,
|
| 58 |
+
get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer)
|
| 59 |
+
|
| 60 |
+
# NOTE: fast tokenizer warning issue: https://github.com/huggingface/transformers/issues/5486
|
| 61 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
| 62 |
+
|
| 63 |
+
local_rank = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def rank0_print(*args):
|
| 67 |
+
if local_rank == 0:
|
| 68 |
+
print(*args)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def set_seed(seed=42):
|
| 72 |
+
"""
|
| 73 |
+
Set the random seed for reproducible results.
|
| 74 |
+
|
| 75 |
+
:param seed: An integer value to be used as the random seed.
|
| 76 |
+
"""
|
| 77 |
+
torch.manual_seed(seed)
|
| 78 |
+
torch.cuda.manual_seed(seed)
|
| 79 |
+
torch.cuda.manual_seed_all(seed) # for multi-GPU setups
|
| 80 |
+
torch.backends.cudnn.deterministic = True
|
| 81 |
+
torch.backends.cudnn.benchmark = False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def int_with_none(value):
|
| 85 |
+
if value == 'None':
|
| 86 |
+
return None
|
| 87 |
+
return int(value)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
|
| 91 |
+
class ModelArguments:
|
| 92 |
+
# LLM Arguments
|
| 93 |
+
model_type: Optional[str] = field(default="rynnec", metadata={"help": "Model type selected in the list: " + ", ".join('rynnec_qwen2')})
|
| 94 |
+
model_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.5")
|
| 95 |
+
version: Optional[str] = field(default="v1", metadata={"help": "Version of the conversation template."})
|
| 96 |
+
freeze_backbone: bool = field(default=False, metadata={"help": "Whether to freeze the LLM backbone."})
|
| 97 |
+
# Connector Arguments
|
| 98 |
+
mm_projector_type: Optional[str] = field(default='linear')
|
| 99 |
+
pretrain_mm_projector: Optional[str] = field(default=None)
|
| 100 |
+
# Vision tower Arguments
|
| 101 |
+
vision_encoder: Optional[str] = field(default=None)
|
| 102 |
+
mm_vision_select_layer: Optional[int] = field(default=-1)
|
| 103 |
+
mm_vision_select_feature: Optional[str] = field(default="patch")
|
| 104 |
+
mm_attn_implementation: Optional[str] = field(default="flash_attention_2")
|
| 105 |
+
# Token downsampling Arguments
|
| 106 |
+
spatial_merge_size: Optional[int] = field(default=1)
|
| 107 |
+
mm_max_length: Optional[int] = field(default=10240)
|
| 108 |
+
use_token_compression: Optional[bool] = field(default=False)
|
| 109 |
+
mask_decoder_model: Optional[str] = field(default="./checkpoints/sam2_hiera_large.pt")
|
| 110 |
+
load_sam2_weight: Optional[bool] = field(default=False)
|
| 111 |
+
training: Optional[bool] = field(default=True)
|
| 112 |
+
has_mask: Optional[bool] = field(default=True)
|
| 113 |
+
|
| 114 |
+
@dataclass
|
| 115 |
+
class DataArguments:
|
| 116 |
+
# Path Arguments
|
| 117 |
+
data_path: List[str] = field(default=None, metadata={"help": "Path to the training data."})
|
| 118 |
+
# image_folder: Optional[str] = field(default=None)
|
| 119 |
+
# video_folder: Optional[str] = field(default=None)
|
| 120 |
+
data_folder: Optional[str] = field(default=None)
|
| 121 |
+
# Loading Arguments
|
| 122 |
+
is_multimodal: bool = False
|
| 123 |
+
fps: Optional[int] = field(default=None)
|
| 124 |
+
max_frames: Optional[int_with_none] = field(default=None)
|
| 125 |
+
# Preprocess Arguments
|
| 126 |
+
image_aspect_ratio: str = 'square'
|
| 127 |
+
use_batch_flattening: bool = field(default=False, metadata={"help": "Whether to flatten the in-batch sequences of variable lengths."})
|
| 128 |
+
dataset_cache_dir: Optional[str] = field(default=None)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@dataclass
|
| 132 |
+
class TrainingArguments(transformers.TrainingArguments):
|
| 133 |
+
# shut auto processing (_remove_unused_columns) of transformers Trainer
|
| 134 |
+
remove_unused_columns: bool = field(default=False)
|
| 135 |
+
|
| 136 |
+
optim: str = field(default="adamw_torch")
|
| 137 |
+
# Training learning rate Arguments
|
| 138 |
+
vision_encoder_lr: Optional[float] = None
|
| 139 |
+
mm_projector_lr: Optional[float] = None
|
| 140 |
+
llm_lr: Optional[float] = None
|
| 141 |
+
region_encoder_lr: Optional[float] = None
|
| 142 |
+
sam_encoder_lr: Optional[float] = None
|
| 143 |
+
sam_decoder_lr: Optional[float] = None
|
| 144 |
+
# Training Data Arguments
|
| 145 |
+
group_by_modality_length: bool = field(default=False)
|
| 146 |
+
model_max_length: int = field(
|
| 147 |
+
default=512,
|
| 148 |
+
metadata={
|
| 149 |
+
"help":
|
| 150 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
| 151 |
+
},
|
| 152 |
+
)
|
| 153 |
+
# Lora or Quant Arguments
|
| 154 |
+
double_quant: bool = field(
|
| 155 |
+
default=True,
|
| 156 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
| 157 |
+
)
|
| 158 |
+
quant_type: str = field(
|
| 159 |
+
default="nf4",
|
| 160 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
| 161 |
+
)
|
| 162 |
+
bits: int = field(
|
| 163 |
+
default=16,
|
| 164 |
+
metadata={"help": "How many bits to use."}
|
| 165 |
+
)
|
| 166 |
+
lora_enable: bool = False
|
| 167 |
+
lora_r: int = 64
|
| 168 |
+
lora_alpha: int = 16
|
| 169 |
+
lora_dropout: float = 0.05
|
| 170 |
+
lora_weight_path: str = ""
|
| 171 |
+
lora_bias: str = "none"
|
| 172 |
+
|
| 173 |
+
use_workload_balancing: bool = field(default=False, metadata={"help": "Whether to use data balancing."})
|
| 174 |
+
loss_reduction_scope: str = field(default="batch", metadata={"help": "Loss reduction scope."})
|
| 175 |
+
context_parallel_size: int = field(default=1, metadata={"help": "Context parallel size."})
|
| 176 |
+
use_liger_kernel: bool = field(default=False, metadata={"help": "Whether to use Liger Kernel."})
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class LazySupervisedDataset(Dataset):
|
| 181 |
+
"""Dataset for supervised fine-tuning."""
|
| 182 |
+
|
| 183 |
+
def __init__(self, data_path: str, vlprocessor, data_args: DataArguments):
|
| 184 |
+
super(LazySupervisedDataset, self).__init__()
|
| 185 |
+
data_objs = []
|
| 186 |
+
|
| 187 |
+
try:
|
| 188 |
+
for data in data_path:
|
| 189 |
+
# NOTE: load_dataset can process both json or jsonl files
|
| 190 |
+
if data.endswith(".json") or data.endswith(".jsonl"):
|
| 191 |
+
data_objs.append(load_dataset("json", data_files=data, cache_dir=data_args.dataset_cache_dir)["train"])
|
| 192 |
+
else:
|
| 193 |
+
raise Exception(f"Unsupported file format (<{data}>)!")
|
| 194 |
+
list_data_dict = concatenate_datasets(data_objs)
|
| 195 |
+
except:
|
| 196 |
+
traceback.print_exc()
|
| 197 |
+
# NOTE: compatible with the old version
|
| 198 |
+
list_data_dict = []
|
| 199 |
+
for data in data_path:
|
| 200 |
+
if data.endswith(".json"):
|
| 201 |
+
data = json.load(open(data, "r"))
|
| 202 |
+
for i in data:
|
| 203 |
+
i['id'] = len(list_data_dict)
|
| 204 |
+
list_data_dict.append(i)
|
| 205 |
+
elif data.endswith(".jsonl"):
|
| 206 |
+
with open(data, "r", encoding="utf-8") as fp:
|
| 207 |
+
for line in fp:
|
| 208 |
+
line = line.strip()
|
| 209 |
+
obj = json.loads(line)
|
| 210 |
+
obj["id"] = len(list_data_dict)
|
| 211 |
+
list_data_dict.append(obj)
|
| 212 |
+
else:
|
| 213 |
+
raise Exception(f"Unsupported file format (<{data}>)!!!")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
rank0_print("Formatting inputs...Skip in lazy mode")
|
| 217 |
+
self.vlprocessor = vlprocessor
|
| 218 |
+
self.list_data_dict = list_data_dict
|
| 219 |
+
self.data_args = data_args
|
| 220 |
+
|
| 221 |
+
img_size=1024
|
| 222 |
+
self.img_size = img_size
|
| 223 |
+
self.sam_transform = DirectResize(img_size)
|
| 224 |
+
|
| 225 |
+
def __len__(self):
|
| 226 |
+
return len(self.list_data_dict)
|
| 227 |
+
|
| 228 |
+
@property
|
| 229 |
+
def lengths(self):
|
| 230 |
+
length_list = []
|
| 231 |
+
for sample in self.list_data_dict:
|
| 232 |
+
img_tokens = 576 if 'image' in sample else 0
|
| 233 |
+
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
|
| 234 |
+
return length_list
|
| 235 |
+
|
| 236 |
+
@property
|
| 237 |
+
def modality_lengths(self):
|
| 238 |
+
length_list = []
|
| 239 |
+
for sample in self.list_data_dict:
|
| 240 |
+
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
|
| 241 |
+
if cur_len==0:
|
| 242 |
+
cur_len = 1
|
| 243 |
+
cur_len = cur_len if 'masks' in sample and sample['masks'] is not None and ('seg' not in sample or sample['seg'] is None) else -cur_len
|
| 244 |
+
length_list.append(cur_len)
|
| 245 |
+
return length_list
|
| 246 |
+
|
| 247 |
+
def _convert_normal(self, data_dict):
|
| 248 |
+
data_folder = self.data_args.data_folder
|
| 249 |
+
conversation = copy.deepcopy(data_dict["conversations"])
|
| 250 |
+
|
| 251 |
+
# data sanity check and repair
|
| 252 |
+
start_idx = 0
|
| 253 |
+
for sentence in conversation:
|
| 254 |
+
if sentence["from"] == "human" or sentence["from"] == "system":
|
| 255 |
+
break
|
| 256 |
+
start_idx += 1
|
| 257 |
+
if start_idx > 0:
|
| 258 |
+
warnings.warn(f"Find {start_idx} non-user sentences at the beginning of the conversation, remove them automatically!")
|
| 259 |
+
conversation = conversation[start_idx:]
|
| 260 |
+
assert len(conversation) > 1, f"Invalid conversation"
|
| 261 |
+
|
| 262 |
+
mask_ids = []
|
| 263 |
+
|
| 264 |
+
if 'image' in data_dict and data_dict['image'] is not None:
|
| 265 |
+
modal = 'image'
|
| 266 |
+
if all(not "<image>" in sentence["value"] for sentence in conversation):
|
| 267 |
+
warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!")
|
| 268 |
+
conversation[0]["value"] = "<image>" + conversation[0]["value"]
|
| 269 |
+
image_file = data_dict['image']
|
| 270 |
+
if isinstance(image_file, list):
|
| 271 |
+
image_file = [os.path.join(data_folder, f) for f in image_file]
|
| 272 |
+
else:
|
| 273 |
+
image_file = os.path.join(data_folder, image_file)
|
| 274 |
+
images = load_images(image_file)
|
| 275 |
+
|
| 276 |
+
masks = []
|
| 277 |
+
if 'masks' in data_dict and data_dict['masks'] is not None:
|
| 278 |
+
if 'height' in data_dict:
|
| 279 |
+
h = data_dict['height']
|
| 280 |
+
w = data_dict['width']
|
| 281 |
+
else:
|
| 282 |
+
h = None
|
| 283 |
+
w = None
|
| 284 |
+
|
| 285 |
+
if isinstance(data_dict['masks'], str):
|
| 286 |
+
masks_ = json.load(open(data_dict['masks']))
|
| 287 |
+
else:
|
| 288 |
+
masks_= data_dict['masks']
|
| 289 |
+
image2maskids = []
|
| 290 |
+
mask_idx = 0
|
| 291 |
+
for ann in masks_:
|
| 292 |
+
image2maskids_ = []
|
| 293 |
+
mask = annToMask(ann, h, w)
|
| 294 |
+
masks.append(mask)
|
| 295 |
+
mask_ids.append(0)
|
| 296 |
+
image2maskids_.append(mask_idx)
|
| 297 |
+
mask_idx += 1
|
| 298 |
+
image2maskids.append(image2maskids_)
|
| 299 |
+
masks = np.stack(masks, axis=0)
|
| 300 |
+
masks = torch.from_numpy(masks)
|
| 301 |
+
|
| 302 |
+
seg_flag = False
|
| 303 |
+
for conv in conversation:
|
| 304 |
+
conv['value'] = conv['value'].replace(REGION_TOKEN_REPLACE, f'[{REGION_TOKEN}]')
|
| 305 |
+
if SEG_TOKEN in conv['value']:
|
| 306 |
+
seg_flag = True
|
| 307 |
+
if seg_flag is False:
|
| 308 |
+
image2maskids = []
|
| 309 |
+
else:
|
| 310 |
+
mask_ids = [-10000 for i in range(len(mask_ids))]
|
| 311 |
+
else:
|
| 312 |
+
image2maskids = []
|
| 313 |
+
masks = torch.zeros((1, 336, 336))
|
| 314 |
+
mask_ids.append(-10000)
|
| 315 |
+
|
| 316 |
+
elif 'video' in data_dict and data_dict['video'] is not None:
|
| 317 |
+
modal = 'video'
|
| 318 |
+
if all(not "<video>" in sentence["value"] for sentence in conversation):
|
| 319 |
+
warnings.warn(f"Video tag not found in the conversation, add it automatically at the beginning!")
|
| 320 |
+
conversation[0]["value"] = "<video>" + conversation[0]["value"]
|
| 321 |
+
if 'video_root' in data_dict and data_dict['video_root'] is not None:
|
| 322 |
+
video_root = data_dict['video_root']
|
| 323 |
+
video_file = [os.path.join(video_root,d) for d in data_dict['video']]
|
| 324 |
+
else:
|
| 325 |
+
video_file = data_dict['video']
|
| 326 |
+
|
| 327 |
+
if not isinstance(video_file, list):
|
| 328 |
+
video_file = [video_file]
|
| 329 |
+
if isinstance(video_file, list) and len(video_file) == 1 and ('timestamps' not in data_dict or data_dict['timestamps'] is None):
|
| 330 |
+
video_file = os.path.join(data_folder, video_file[0])
|
| 331 |
+
must_sample_frames = []
|
| 332 |
+
if 'masks' in data_dict and data_dict['masks'] is not None:
|
| 333 |
+
if isinstance(data_dict['masks'], str):
|
| 334 |
+
masks_ = json.load(open(data_dict['masks']))
|
| 335 |
+
else:
|
| 336 |
+
masks_= data_dict['masks']
|
| 337 |
+
for ann in masks_:
|
| 338 |
+
for k in ann.keys():
|
| 339 |
+
must_sample_frames.append(int(k))
|
| 340 |
+
images, timestamps, mask_ids = load_video_from_ids(video_file, fps=self.data_args.fps, max_frames=self.data_args.max_frames, must_sample_frames=must_sample_frames)
|
| 341 |
+
elif isinstance(video_file, list): #images
|
| 342 |
+
images = []
|
| 343 |
+
for vf in video_file:
|
| 344 |
+
images+=load_images(os.path.join(data_folder, vf))
|
| 345 |
+
timestamps = data_dict['timestamps']
|
| 346 |
+
|
| 347 |
+
else:
|
| 348 |
+
raise ValueError(f"Unsupported video format: {video_file}")
|
| 349 |
+
images = [images]
|
| 350 |
+
masks = []
|
| 351 |
+
mask_nums = []
|
| 352 |
+
image2maskids = []
|
| 353 |
+
maskid = 0
|
| 354 |
+
|
| 355 |
+
if 'masks' in data_dict and data_dict['masks'] is not None:
|
| 356 |
+
if 'mask_ids' in data_dict and data_dict['mask_ids'] is not None:
|
| 357 |
+
mask_ids = data_dict["mask_ids"]
|
| 358 |
+
if 'height' in data_dict:
|
| 359 |
+
h = data_dict['height']
|
| 360 |
+
w = data_dict['width']
|
| 361 |
+
else:
|
| 362 |
+
h = None
|
| 363 |
+
w = None
|
| 364 |
+
|
| 365 |
+
if isinstance(data_dict['masks'], str):
|
| 366 |
+
masks_ = json.load(open(data_dict['masks']))
|
| 367 |
+
else:
|
| 368 |
+
masks_= data_dict['masks']
|
| 369 |
+
for ann in masks_:
|
| 370 |
+
image2maskids_ = [None]*len(video_file)
|
| 371 |
+
for k in ann.keys():
|
| 372 |
+
mask = annToMask(ann[k], h, w)
|
| 373 |
+
masks.append(mask)
|
| 374 |
+
image2maskids_[mask_ids[maskid]] = maskid
|
| 375 |
+
maskid+=1
|
| 376 |
+
image2maskids.append(image2maskids_)
|
| 377 |
+
|
| 378 |
+
mask_nums.append(len(ann.keys()))
|
| 379 |
+
masks = np.stack(masks, axis=0)
|
| 380 |
+
masks = torch.from_numpy(masks)
|
| 381 |
+
|
| 382 |
+
conv_i = 0
|
| 383 |
+
region_num = 0
|
| 384 |
+
seg_flag = False
|
| 385 |
+
for idx in range(len(mask_nums)):
|
| 386 |
+
while '<region>' not in conversation[conv_i]['value'] and conv_i<len(conversation)-1:
|
| 387 |
+
conv_i+=1
|
| 388 |
+
conversation[conv_i]['value'] = conversation[conv_i]['value'].replace('<region>', "["+REGION_TOKEN*mask_nums[idx]+"]", 1)
|
| 389 |
+
region_num += mask_nums[idx]
|
| 390 |
+
if '[SEG]' in conversation[conv_i]['value']:
|
| 391 |
+
seg_flag = True
|
| 392 |
+
|
| 393 |
+
if seg_flag is False:
|
| 394 |
+
image2maskids = []
|
| 395 |
+
else:
|
| 396 |
+
mask_ids = [-10000 for i in range(len(mask_ids))]
|
| 397 |
+
# assert region_num == len(masks), f"error in {conversation}"
|
| 398 |
+
|
| 399 |
+
else:
|
| 400 |
+
image2maskids = []
|
| 401 |
+
masks = torch.zeros((1, 336, 336))
|
| 402 |
+
mask_ids.append(-10000)
|
| 403 |
+
|
| 404 |
+
else:
|
| 405 |
+
modal = 'text'
|
| 406 |
+
image2maskids = []
|
| 407 |
+
images = None
|
| 408 |
+
masks = torch.zeros((1, 336, 336))
|
| 409 |
+
sam_size = (336, 336)
|
| 410 |
+
sam_images = torch.zeros(1, 3, self.img_size, self.img_size)
|
| 411 |
+
mask_ids = [-10000]
|
| 412 |
+
|
| 413 |
+
if images is not None and len(images)>0:
|
| 414 |
+
sam_images = []
|
| 415 |
+
sam_size = None
|
| 416 |
+
if modal=='video':
|
| 417 |
+
for image in images[0]:
|
| 418 |
+
sam_image = self.sam_transform.apply_image(np.array(image))
|
| 419 |
+
sam_images.append(sam_image)
|
| 420 |
+
if sam_size is None:
|
| 421 |
+
sam_size = sam_image.shape[:2]
|
| 422 |
+
else:
|
| 423 |
+
for image in images:
|
| 424 |
+
sam_image = self.sam_transform.apply_image(np.array(image))
|
| 425 |
+
sam_images.append(sam_image)
|
| 426 |
+
if sam_size is None:
|
| 427 |
+
sam_size = sam_image.shape[:2]
|
| 428 |
+
sam_images = np.array(sam_images)
|
| 429 |
+
sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
|
| 430 |
+
sam_images = sam_preprocess_batch(sam_images)
|
| 431 |
+
|
| 432 |
+
messages = []
|
| 433 |
+
for conv in conversation:
|
| 434 |
+
if conv["from"] == "human":
|
| 435 |
+
# replace video tag to image tag for unified processing
|
| 436 |
+
# conv["value"] = conv["value"].replace("<video>", "<image>" * len(images))
|
| 437 |
+
chunks = conv["value"].split("<image>" if modal == 'image' else "<video>")
|
| 438 |
+
messages.append({
|
| 439 |
+
"role": "user",
|
| 440 |
+
"content": []
|
| 441 |
+
})
|
| 442 |
+
|
| 443 |
+
for chunk_idx in range(1, 2 * len(chunks)):
|
| 444 |
+
if chunk_idx % 2 == 1:
|
| 445 |
+
chunk = chunks[chunk_idx // 2].strip()
|
| 446 |
+
messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
|
| 447 |
+
else:
|
| 448 |
+
if modal == 'image':
|
| 449 |
+
messages[-1]["content"].append({"type": "image"})
|
| 450 |
+
elif modal == 'video':
|
| 451 |
+
messages[-1]["content"].append({"type": "video", "num_frames": len(images[0]), "time": timestamps})
|
| 452 |
+
else:
|
| 453 |
+
messages.append({
|
| 454 |
+
"role": "assistant",
|
| 455 |
+
"content": conv['value']
|
| 456 |
+
})
|
| 457 |
+
|
| 458 |
+
# TODO: dynamic downsampling
|
| 459 |
+
# image_downsampling = self.data_args.spatial_merge_size
|
| 460 |
+
image_downsampling = 2 if modal == "video" else 1
|
| 461 |
+
|
| 462 |
+
return modal, images, messages, image_downsampling, masks, mask_ids, sam_images, sam_size, image2maskids
|
| 463 |
+
|
| 464 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
| 465 |
+
data_dict = self.list_data_dict[i]
|
| 466 |
+
|
| 467 |
+
try:
|
| 468 |
+
modal, images, messages, image_downsampling, masks, mask_ids, sam_images, sam_size, image2maskids = self._convert_normal(data_dict)
|
| 469 |
+
|
| 470 |
+
data_dict = self.vlprocessor(
|
| 471 |
+
images=images,
|
| 472 |
+
text=messages,
|
| 473 |
+
merge_size=image_downsampling,
|
| 474 |
+
return_labels=True,
|
| 475 |
+
return_tensors="pt",
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
if modal == 'text':
|
| 479 |
+
unit_size = self.vlprocessor.image_processor.patch_size**2 * 3
|
| 480 |
+
data_dict['pixel_values'] = torch.zeros(self.vlprocessor.image_merge_size**2, unit_size)
|
| 481 |
+
data_dict['grid_sizes'] = torch.as_tensor([[1, self.vlprocessor.image_merge_size, self.vlprocessor.image_merge_size]])
|
| 482 |
+
data_dict['merge_sizes'] = torch.as_tensor([self.vlprocessor.image_merge_size])
|
| 483 |
+
elif modal == 'image' or modal == 'video':
|
| 484 |
+
assert len(data_dict['pixel_values']) > 0 and len(data_dict['grid_sizes']) > 0, f"Invalid image data: {data_dict['pixel_values']}, {data_dict['grid_sizes']}"
|
| 485 |
+
|
| 486 |
+
data_dict['modals'] = [modal] if isinstance(modal, str) else modal
|
| 487 |
+
data_dict['masks'] = masks
|
| 488 |
+
data_dict['mask_ids'] = mask_ids
|
| 489 |
+
data_dict['idx'] = i
|
| 490 |
+
data_dict['sam_images'] = sam_images
|
| 491 |
+
data_dict['sam_size'] = sam_size
|
| 492 |
+
data_dict['image2maskids'] = image2maskids
|
| 493 |
+
|
| 494 |
+
except Exception as e:
|
| 495 |
+
traceback.print_exc()
|
| 496 |
+
backup_idx = random.randint(0, len(self.list_data_dict) - 1)
|
| 497 |
+
print(f"Encounted error when process {i}-th example: {data_dict}, use {backup_idx}-th example instead!!!")
|
| 498 |
+
return self.__getitem__(backup_idx)
|
| 499 |
+
|
| 500 |
+
return data_dict
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
@dataclass
|
| 504 |
+
class DataCollatorForSupervisedDataset(object):
|
| 505 |
+
"""Collate examples for supervised fine-tuning."""
|
| 506 |
+
|
| 507 |
+
vlprocessor: transformers.ProcessorMixin
|
| 508 |
+
|
| 509 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
| 510 |
+
input_ids, labels = tuple([instance[key] for instance in instances]
|
| 511 |
+
for key in ("input_ids", "labels"))
|
| 512 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
| 513 |
+
input_ids,
|
| 514 |
+
batch_first=True,
|
| 515 |
+
padding_value=self.vlprocessor.tokenizer.pad_token_id)
|
| 516 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
| 517 |
+
batch_first=True,
|
| 518 |
+
padding_value=IGNORE_INDEX)
|
| 519 |
+
input_ids = input_ids[:, :self.vlprocessor.tokenizer.model_max_length]
|
| 520 |
+
labels = labels[:, :self.vlprocessor.tokenizer.model_max_length]
|
| 521 |
+
attention_mask = input_ids.ne(self.vlprocessor.tokenizer.pad_token_id)
|
| 522 |
+
position_ids = attention_mask.cumsum(dim=-1) - 1
|
| 523 |
+
|
| 524 |
+
batch = dict(
|
| 525 |
+
input_ids=input_ids,
|
| 526 |
+
labels=labels,
|
| 527 |
+
attention_mask=input_ids.ne(self.vlprocessor.tokenizer.pad_token_id),
|
| 528 |
+
position_ids=position_ids
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
# work for 'images' argument in `prepare_inputs_labels_for_multimodal`
|
| 532 |
+
batch["pixel_values"] = torch.cat([x["pixel_values"] for x in instances])
|
| 533 |
+
batch["grid_sizes"] = torch.cat([x["grid_sizes"] for x in instances])
|
| 534 |
+
batch["merge_sizes"] = torch.cat([x["merge_sizes"] for x in instances])
|
| 535 |
+
batch["modals"] = sum([x["modals"] for x in instances], [])
|
| 536 |
+
|
| 537 |
+
batch['mask_ids'] = []
|
| 538 |
+
mask_idx_start = 0
|
| 539 |
+
for instance in instances:
|
| 540 |
+
if len(instance['mask_ids'])>0:
|
| 541 |
+
batch['mask_ids'].extend([idx+mask_idx_start for idx in instance['mask_ids']])
|
| 542 |
+
# print(int(instance['grid_sizes'][0][0]))
|
| 543 |
+
|
| 544 |
+
mask_idx_start += int(instance['grid_sizes'][0][0])
|
| 545 |
+
batch["masks"] = [x["masks"] for x in instances]
|
| 546 |
+
batch["sam_images"] = [x["sam_images"] for x in instances]
|
| 547 |
+
batch["sam_size"] = [x["sam_size"] for x in instances]
|
| 548 |
+
batch["image2maskids"] = [x["image2maskids"] for x in instances]
|
| 549 |
+
batch["idxes"] = [x["idx"] for x in instances]
|
| 550 |
+
return batch
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def make_supervised_data_module(vlprocessor, data_args) -> Dict:
|
| 554 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
| 555 |
+
train_dataset = LazySupervisedDataset(
|
| 556 |
+
vlprocessor=vlprocessor,
|
| 557 |
+
# data_folder=data_args.data_folder,
|
| 558 |
+
data_path=data_args.data_path,
|
| 559 |
+
data_args=data_args
|
| 560 |
+
)
|
| 561 |
+
data_collator = DataCollatorForSupervisedDataset(vlprocessor=vlprocessor)
|
| 562 |
+
return dict(train_dataset=train_dataset,
|
| 563 |
+
eval_dataset=None,
|
| 564 |
+
data_collator=data_collator)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def train(attn_implementation=None):
|
| 568 |
+
global local_rank
|
| 569 |
+
set_seed(42)
|
| 570 |
+
|
| 571 |
+
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
| 572 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 573 |
+
|
| 574 |
+
local_rank = training_args.local_rank
|
| 575 |
+
|
| 576 |
+
if local_rank == 0:
|
| 577 |
+
print('------model args------')
|
| 578 |
+
print(model_args)
|
| 579 |
+
print('------data args------')
|
| 580 |
+
print(data_args)
|
| 581 |
+
print('------training args------')
|
| 582 |
+
print(training_args)
|
| 583 |
+
|
| 584 |
+
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 585 |
+
|
| 586 |
+
bnb_model_from_pretrained_args = {}
|
| 587 |
+
if training_args.bits in [4, 8]:
|
| 588 |
+
from transformers import BitsAndBytesConfig
|
| 589 |
+
bnb_model_from_pretrained_args.update(dict(
|
| 590 |
+
# device_map={"": training_args.device},
|
| 591 |
+
# BUG: High version transformers report error:
|
| 592 |
+
# ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time
|
| 593 |
+
# load_in_4bit=training_args.bits == 4,
|
| 594 |
+
# load_in_8bit=training_args.bits == 8,
|
| 595 |
+
quantization_config=BitsAndBytesConfig(
|
| 596 |
+
load_in_4bit=training_args.bits == 4,
|
| 597 |
+
load_in_8bit=training_args.bits == 8,
|
| 598 |
+
llm_int8_skip_modules=["mm_projector"],
|
| 599 |
+
llm_int8_threshold=6.0,
|
| 600 |
+
llm_int8_has_fp16_weight=False,
|
| 601 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
| 602 |
+
bnb_4bit_use_double_quant=training_args.double_quant,
|
| 603 |
+
bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
|
| 604 |
+
bnb_4bit_quant_storage=compute_dtype,
|
| 605 |
+
)
|
| 606 |
+
))
|
| 607 |
+
|
| 608 |
+
config = RynnecQwen2Config.from_pretrained(model_args.model_path)
|
| 609 |
+
|
| 610 |
+
config._attn_implementation = attn_implementation
|
| 611 |
+
# NOTE: active spatial_merge_size arguments
|
| 612 |
+
config.spatial_merge_size = model_args.spatial_merge_size
|
| 613 |
+
config.mm_max_length = model_args.mm_max_length
|
| 614 |
+
config.use_token_compression = model_args.use_token_compression
|
| 615 |
+
config.loss_reduction_scope = training_args.loss_reduction_scope
|
| 616 |
+
config.mask_decoder_model = model_args.mask_decoder_model
|
| 617 |
+
config.training = model_args.training
|
| 618 |
+
config.has_mask = model_args.has_mask
|
| 619 |
+
|
| 620 |
+
if model_args.vision_encoder is not None:
|
| 621 |
+
model = RynnecQwen2ForCausalLM.from_pretrained(
|
| 622 |
+
model_args.model_path,
|
| 623 |
+
config=config,
|
| 624 |
+
torch_dtype=compute_dtype,
|
| 625 |
+
do_sample=True,
|
| 626 |
+
**bnb_model_from_pretrained_args
|
| 627 |
+
)
|
| 628 |
+
if 'mixtral' in model_args.model_type:
|
| 629 |
+
import deepspeed
|
| 630 |
+
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
| 631 |
+
else:
|
| 632 |
+
model = transformers.LlamaForCausalLM.from_pretrained(
|
| 633 |
+
model_args.model_path,
|
| 634 |
+
config=config,
|
| 635 |
+
torch_dtype=compute_dtype,
|
| 636 |
+
do_sample=True,
|
| 637 |
+
**bnb_model_from_pretrained_args
|
| 638 |
+
)
|
| 639 |
+
model.config.use_cache = False
|
| 640 |
+
if model_args.freeze_backbone:
|
| 641 |
+
model.model.requires_grad_(False)
|
| 642 |
+
|
| 643 |
+
if training_args.bits in [4, 8]:
|
| 644 |
+
from peft import prepare_model_for_kbit_training
|
| 645 |
+
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
| 646 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
| 647 |
+
|
| 648 |
+
if training_args.gradient_checkpointing:
|
| 649 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 650 |
+
model.enable_input_require_grads()
|
| 651 |
+
else:
|
| 652 |
+
def make_inputs_require_grad(module, input, output):
|
| 653 |
+
output.requires_grad_(True)
|
| 654 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 655 |
+
|
| 656 |
+
if training_args.lora_enable:
|
| 657 |
+
from peft import LoraConfig, get_peft_model
|
| 658 |
+
lora_config = LoraConfig(
|
| 659 |
+
r=training_args.lora_r,
|
| 660 |
+
lora_alpha=training_args.lora_alpha,
|
| 661 |
+
target_modules=find_all_linear_names(model),
|
| 662 |
+
lora_dropout=training_args.lora_dropout,
|
| 663 |
+
bias=training_args.lora_bias,
|
| 664 |
+
task_type="CAUSAL_LM",
|
| 665 |
+
)
|
| 666 |
+
if training_args.bits == 16:
|
| 667 |
+
if training_args.bf16:
|
| 668 |
+
model.to(torch.bfloat16)
|
| 669 |
+
if training_args.fp16:
|
| 670 |
+
model.to(torch.float16)
|
| 671 |
+
rank0_print("Adding LoRA adapters...")
|
| 672 |
+
model = get_peft_model(model, lora_config)
|
| 673 |
+
|
| 674 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 675 |
+
model_args.model_path,
|
| 676 |
+
model_max_length=training_args.model_max_length,
|
| 677 |
+
padding_side="right",
|
| 678 |
+
use_fast=True,
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
if tokenizer.pad_token is None:
|
| 682 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 683 |
+
|
| 684 |
+
if model_args.vision_encoder is not None:
|
| 685 |
+
# initialize vision encoder + multi-modal projector
|
| 686 |
+
model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
|
| 687 |
+
|
| 688 |
+
if model_args.load_sam2_weight is True:
|
| 689 |
+
model.get_model().build_mask_decoder(model.get_model().config)
|
| 690 |
+
model.load_sam2_weights(model_args.mask_decoder_model)
|
| 691 |
+
|
| 692 |
+
vision_encoder = model.get_vision_encoder()
|
| 693 |
+
vision_encoder.to(dtype=compute_dtype, device=training_args.device)
|
| 694 |
+
|
| 695 |
+
vision_encoder.image_processor.max_tokens = model_args.mm_max_length
|
| 696 |
+
mm_projector = model.get_mm_projector()
|
| 697 |
+
mm_projector.to(dtype=compute_dtype if training_args.bf16 else torch.float16, device=training_args.device)
|
| 698 |
+
|
| 699 |
+
data_args.is_multimodal = True
|
| 700 |
+
|
| 701 |
+
model.config.tokenizer_padding_side = tokenizer.padding_side
|
| 702 |
+
model.config.tokenizer_model_max_length = tokenizer.model_max_length
|
| 703 |
+
|
| 704 |
+
if training_args.bits in [4, 8]:
|
| 705 |
+
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
|
| 706 |
+
|
| 707 |
+
# decoupled learning rate
|
| 708 |
+
model.config.llm_lr = training_args.llm_lr
|
| 709 |
+
model.config.vision_encoder_lr = training_args.vision_encoder_lr
|
| 710 |
+
model.config.mm_projector_lr = training_args.mm_projector_lr
|
| 711 |
+
model.config.region_encoder_lr = training_args.region_encoder_lr
|
| 712 |
+
model.config.sam_decoder_lr = training_args.sam_decoder_lr
|
| 713 |
+
model.config.sam_encoder_lr = training_args.sam_encoder_lr
|
| 714 |
+
model.config.dice_loss_weight = 0.5
|
| 715 |
+
model.config.bce_loss_weight = 2.0
|
| 716 |
+
|
| 717 |
+
if model.config.llm_lr is None:
|
| 718 |
+
for p in model.get_model().parameters():
|
| 719 |
+
p.requires_grad = False
|
| 720 |
+
for p in model.get_model().vision_encoder.parameters():
|
| 721 |
+
p.requires_grad = True
|
| 722 |
+
for p in model.get_model().mm_projector.parameters():
|
| 723 |
+
p.requires_grad = True
|
| 724 |
+
for p in model.get_model().region_encoder.parameters():
|
| 725 |
+
p.requires_grad = True
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
if model.config.vision_encoder_lr is None:
|
| 729 |
+
for p in model.get_model().vision_encoder.parameters():
|
| 730 |
+
p.requires_grad = False
|
| 731 |
+
|
| 732 |
+
if model.config.mm_projector_lr is None:
|
| 733 |
+
for p in model.get_model().mm_projector.parameters():
|
| 734 |
+
p.requires_grad = False
|
| 735 |
+
|
| 736 |
+
if model.config.region_encoder_lr is None:
|
| 737 |
+
for p in model.get_model().region_encoder.parameters():
|
| 738 |
+
p.requires_grad = False
|
| 739 |
+
|
| 740 |
+
if model.config.sam_decoder_lr is None:
|
| 741 |
+
for p in model.grounding_encoder.sam2_model.sam_mask_decoder.parameters():
|
| 742 |
+
p.requires_grad = False
|
| 743 |
+
else:
|
| 744 |
+
for p in model.grounding_encoder.sam2_model.sam_mask_decoder.parameters():
|
| 745 |
+
p.requires_grad = True
|
| 746 |
+
|
| 747 |
+
if model.config.sam_encoder_lr is None:
|
| 748 |
+
for p in model.grounding_encoder.sam2_model.image_encoder.parameters():
|
| 749 |
+
p.requires_grad = False
|
| 750 |
+
|
| 751 |
+
if training_args.lora_enable:
|
| 752 |
+
for n, p in model.named_parameters():
|
| 753 |
+
if any(
|
| 754 |
+
[
|
| 755 |
+
x in n
|
| 756 |
+
for x in ["lm_head", "embed_tokens", "text_hidden_fcs"]
|
| 757 |
+
]
|
| 758 |
+
):
|
| 759 |
+
# print(n)
|
| 760 |
+
p.requires_grad = True
|
| 761 |
+
|
| 762 |
+
model.config.max_frames = getattr(data_args, 'max_frames', NUM_FRAMES)
|
| 763 |
+
model.config.image_aspect_ratio = data_args.image_aspect_ratio if 'qwen2vl' not in model_args.vision_encoder else 'qwen2vl'
|
| 764 |
+
|
| 765 |
+
# NOTE: complement data_args via model hyperparameters
|
| 766 |
+
# 1. acquire image size
|
| 767 |
+
model.config.image_size = data_args.image_size = vision_encoder.image_size
|
| 768 |
+
# 2. calculate the number of tokens in the image
|
| 769 |
+
model.config.image_token_length = data_args.image_token_length = mm_projector.cal_proj_size(vision_encoder.num_patches_per_side)
|
| 770 |
+
# 3. check if alignment
|
| 771 |
+
model.config.is_alignment = training_args.is_alignment = data_args.is_alignment = (
|
| 772 |
+
model.config.mm_projector_lr is not None and
|
| 773 |
+
model.config.llm_lr is None and
|
| 774 |
+
model.config.vision_encoder_lr is None
|
| 775 |
+
)
|
| 776 |
+
# 4. set spatial merge size as default
|
| 777 |
+
model.config.spatial_merge_size = data_args.spatial_merge_size = model_args.spatial_merge_size
|
| 778 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN, STREAM_START_TOKEN, STREAM_END_TOKEN], special_tokens=True)
|
| 779 |
+
tokenizer.add_tokens([REGION_TOKEN], special_tokens=True)
|
| 780 |
+
num_new_tokens = tokenizer.add_tokens([SEG_TOKEN], special_tokens=True)
|
| 781 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 782 |
+
|
| 783 |
+
model.config.image_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
|
| 784 |
+
model.config.region_token_index = tokenizer.convert_tokens_to_ids(REGION_TOKEN)
|
| 785 |
+
model.config.seg_token_index = tokenizer.convert_tokens_to_ids(SEG_TOKEN)
|
| 786 |
+
|
| 787 |
+
vlprocessor = Videollama3Qwen2Processor(vision_encoder.image_processor, tokenizer)
|
| 788 |
+
|
| 789 |
+
if training_args.bits in [4, 8]:
|
| 790 |
+
from peft.tuners.lora import LoraLayer
|
| 791 |
+
for name, module in model.named_modules():
|
| 792 |
+
if isinstance(module, LoraLayer):
|
| 793 |
+
if training_args.bf16:
|
| 794 |
+
module = module.to(torch.bfloat16)
|
| 795 |
+
if 'norm' in name:
|
| 796 |
+
module = module.to(torch.float32)
|
| 797 |
+
if 'lm_head' in name or 'embed_tokens' in name:
|
| 798 |
+
if hasattr(module, 'weight'):
|
| 799 |
+
if training_args.bf16 and module.weight.dtype == torch.float32:
|
| 800 |
+
module = module.to(torch.bfloat16)
|
| 801 |
+
|
| 802 |
+
if local_rank == 0:
|
| 803 |
+
print("Current model:", model)
|
| 804 |
+
print("Model config:", model.config)
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
data_module = make_supervised_data_module(vlprocessor=vlprocessor, data_args=data_args)
|
| 808 |
+
|
| 809 |
+
# select a Trainer
|
| 810 |
+
trainer = RynnECTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
|
| 811 |
+
|
| 812 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 813 |
+
trainer.train(resume_from_checkpoint=True)
|
| 814 |
+
else:
|
| 815 |
+
trainer.train()
|
| 816 |
+
trainer.save_state()
|
| 817 |
+
|
| 818 |
+
model.config.use_cache = True
|
| 819 |
+
|
| 820 |
+
if training_args.lora_enable:
|
| 821 |
+
state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
|
| 822 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
|
| 823 |
+
if training_args.local_rank == 0 or training_args.local_rank == -1:
|
| 824 |
+
model.config.save_pretrained(training_args.output_dir)
|
| 825 |
+
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
| 826 |
+
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
|
| 827 |
+
else:
|
| 828 |
+
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
if __name__ == "__main__":
|
| 832 |
+
train(attn_implementation="flash_attention_2")
|
RynnEC/third_parts/sam2/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from hydra import initialize_config_module
|
| 8 |
+
|
| 9 |
+
initialize_config_module("third_parts.sam2.sam2_configs", version_base="1.2")
|
RynnEC/third_parts/sam2/automatic_mask_generator.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
| 13 |
+
|
| 14 |
+
from third_parts.sam2.modeling.sam2_base import SAM2Base
|
| 15 |
+
from third_parts.sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 16 |
+
from third_parts.sam2.utils.amg import (
|
| 17 |
+
area_from_rle,
|
| 18 |
+
batch_iterator,
|
| 19 |
+
batched_mask_to_box,
|
| 20 |
+
box_xyxy_to_xywh,
|
| 21 |
+
build_all_layer_point_grids,
|
| 22 |
+
calculate_stability_score,
|
| 23 |
+
coco_encode_rle,
|
| 24 |
+
generate_crop_boxes,
|
| 25 |
+
is_box_near_crop_edge,
|
| 26 |
+
mask_to_rle_pytorch,
|
| 27 |
+
MaskData,
|
| 28 |
+
remove_small_regions,
|
| 29 |
+
rle_to_mask,
|
| 30 |
+
uncrop_boxes_xyxy,
|
| 31 |
+
uncrop_masks,
|
| 32 |
+
uncrop_points,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SAM2AutomaticMaskGenerator:
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
model: SAM2Base,
|
| 40 |
+
points_per_side: Optional[int] = 32,
|
| 41 |
+
points_per_batch: int = 64,
|
| 42 |
+
pred_iou_thresh: float = 0.8,
|
| 43 |
+
stability_score_thresh: float = 0.95,
|
| 44 |
+
stability_score_offset: float = 1.0,
|
| 45 |
+
mask_threshold: float = 0.0,
|
| 46 |
+
box_nms_thresh: float = 0.7,
|
| 47 |
+
crop_n_layers: int = 0,
|
| 48 |
+
crop_nms_thresh: float = 0.7,
|
| 49 |
+
crop_overlap_ratio: float = 512 / 1500,
|
| 50 |
+
crop_n_points_downscale_factor: int = 1,
|
| 51 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
| 52 |
+
min_mask_region_area: int = 0,
|
| 53 |
+
output_mode: str = "binary_mask",
|
| 54 |
+
use_m2m: bool = False,
|
| 55 |
+
multimask_output: bool = True,
|
| 56 |
+
) -> None:
|
| 57 |
+
"""
|
| 58 |
+
Using a SAM 2 model, generates masks for the entire image.
|
| 59 |
+
Generates a grid of point prompts over the image, then filters
|
| 60 |
+
low quality and duplicate masks. The default settings are chosen
|
| 61 |
+
for SAM 2 with a HieraL backbone.
|
| 62 |
+
|
| 63 |
+
Arguments:
|
| 64 |
+
model (Sam): The SAM 2 model to use for mask prediction.
|
| 65 |
+
points_per_side (int or None): The number of points to be sampled
|
| 66 |
+
along one side of the image. The total number of points is
|
| 67 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
| 68 |
+
point sampling.
|
| 69 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
| 70 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
| 71 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
| 72 |
+
model's predicted mask quality.
|
| 73 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
| 74 |
+
the stability of the mask under changes to the cutoff used to binarize
|
| 75 |
+
the model's mask predictions.
|
| 76 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
| 77 |
+
calculated the stability score.
|
| 78 |
+
mask_threshold (float): Threshold for binarizing the mask logits
|
| 79 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 80 |
+
suppression to filter duplicate masks.
|
| 81 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
| 82 |
+
crops of the image. Sets the number of layers to run, where each
|
| 83 |
+
layer has 2**i_layer number of image crops.
|
| 84 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 85 |
+
suppression to filter duplicate masks between different crops.
|
| 86 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
| 87 |
+
In the first crop layer, crops will overlap by this fraction of
|
| 88 |
+
the image length. Later layers with more crops scale down this overlap.
|
| 89 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
| 90 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 91 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
| 92 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
| 93 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
| 94 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
| 95 |
+
to remove disconnected regions and holes in masks with area smaller
|
| 96 |
+
than min_mask_region_area. Requires opencv.
|
| 97 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
| 98 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
| 99 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
| 100 |
+
memory.
|
| 101 |
+
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
|
| 102 |
+
multimask_output (bool): Whether to output multimask at each point of the grid.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
assert (points_per_side is None) != (
|
| 106 |
+
point_grids is None
|
| 107 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
| 108 |
+
if points_per_side is not None:
|
| 109 |
+
self.point_grids = build_all_layer_point_grids(
|
| 110 |
+
points_per_side,
|
| 111 |
+
crop_n_layers,
|
| 112 |
+
crop_n_points_downscale_factor,
|
| 113 |
+
)
|
| 114 |
+
elif point_grids is not None:
|
| 115 |
+
self.point_grids = point_grids
|
| 116 |
+
else:
|
| 117 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
| 118 |
+
|
| 119 |
+
assert output_mode in [
|
| 120 |
+
"binary_mask",
|
| 121 |
+
"uncompressed_rle",
|
| 122 |
+
"coco_rle",
|
| 123 |
+
], f"Unknown output_mode {output_mode}."
|
| 124 |
+
if output_mode == "coco_rle":
|
| 125 |
+
try:
|
| 126 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
| 127 |
+
except ImportError as e:
|
| 128 |
+
print("Please install pycocotools")
|
| 129 |
+
raise e
|
| 130 |
+
|
| 131 |
+
self.predictor = SAM2ImagePredictor(
|
| 132 |
+
model,
|
| 133 |
+
max_hole_area=min_mask_region_area,
|
| 134 |
+
max_sprinkle_area=min_mask_region_area,
|
| 135 |
+
)
|
| 136 |
+
self.points_per_batch = points_per_batch
|
| 137 |
+
self.pred_iou_thresh = pred_iou_thresh
|
| 138 |
+
self.stability_score_thresh = stability_score_thresh
|
| 139 |
+
self.stability_score_offset = stability_score_offset
|
| 140 |
+
self.mask_threshold = mask_threshold
|
| 141 |
+
self.box_nms_thresh = box_nms_thresh
|
| 142 |
+
self.crop_n_layers = crop_n_layers
|
| 143 |
+
self.crop_nms_thresh = crop_nms_thresh
|
| 144 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
| 145 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
| 146 |
+
self.min_mask_region_area = min_mask_region_area
|
| 147 |
+
self.output_mode = output_mode
|
| 148 |
+
self.use_m2m = use_m2m
|
| 149 |
+
self.multimask_output = multimask_output
|
| 150 |
+
|
| 151 |
+
@torch.no_grad()
|
| 152 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
| 153 |
+
"""
|
| 154 |
+
Generates masks for the given image.
|
| 155 |
+
|
| 156 |
+
Arguments:
|
| 157 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
| 161 |
+
a dict containing the following keys:
|
| 162 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
| 163 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
| 164 |
+
is a dictionary containing the RLE.
|
| 165 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
| 166 |
+
area (int): The area in pixels of the mask.
|
| 167 |
+
predicted_iou (float): The model's own prediction of the mask's
|
| 168 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
| 169 |
+
point_coords (list(list(float))): The point coordinates input
|
| 170 |
+
to the model to generate this mask.
|
| 171 |
+
stability_score (float): A measure of the mask's quality. This
|
| 172 |
+
is filtered on using the stability_score_thresh parameter.
|
| 173 |
+
crop_box (list(float)): The crop of the image used to generate
|
| 174 |
+
the mask, given in XYWH format.
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
# Generate masks
|
| 178 |
+
mask_data = self._generate_masks(image)
|
| 179 |
+
|
| 180 |
+
# Encode masks
|
| 181 |
+
if self.output_mode == "coco_rle":
|
| 182 |
+
mask_data["segmentations"] = [
|
| 183 |
+
coco_encode_rle(rle) for rle in mask_data["rles"]
|
| 184 |
+
]
|
| 185 |
+
elif self.output_mode == "binary_mask":
|
| 186 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
| 187 |
+
else:
|
| 188 |
+
mask_data["segmentations"] = mask_data["rles"]
|
| 189 |
+
|
| 190 |
+
# Write mask records
|
| 191 |
+
curr_anns = []
|
| 192 |
+
for idx in range(len(mask_data["segmentations"])):
|
| 193 |
+
ann = {
|
| 194 |
+
"segmentation": mask_data["segmentations"][idx],
|
| 195 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
| 196 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
| 197 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
| 198 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
| 199 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
| 200 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
| 201 |
+
}
|
| 202 |
+
curr_anns.append(ann)
|
| 203 |
+
|
| 204 |
+
return curr_anns
|
| 205 |
+
|
| 206 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
| 207 |
+
orig_size = image.shape[:2]
|
| 208 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
| 209 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Iterate over image crops
|
| 213 |
+
data = MaskData()
|
| 214 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
| 215 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
| 216 |
+
data.cat(crop_data)
|
| 217 |
+
|
| 218 |
+
# Remove duplicate masks between crops
|
| 219 |
+
if len(crop_boxes) > 1:
|
| 220 |
+
# Prefer masks from smaller crops
|
| 221 |
+
scores = 1 / box_area(data["crop_boxes"])
|
| 222 |
+
scores = scores.to(data["boxes"].device)
|
| 223 |
+
keep_by_nms = batched_nms(
|
| 224 |
+
data["boxes"].float(),
|
| 225 |
+
scores,
|
| 226 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 227 |
+
iou_threshold=self.crop_nms_thresh,
|
| 228 |
+
)
|
| 229 |
+
data.filter(keep_by_nms)
|
| 230 |
+
data.to_numpy()
|
| 231 |
+
return data
|
| 232 |
+
|
| 233 |
+
def _process_crop(
|
| 234 |
+
self,
|
| 235 |
+
image: np.ndarray,
|
| 236 |
+
crop_box: List[int],
|
| 237 |
+
crop_layer_idx: int,
|
| 238 |
+
orig_size: Tuple[int, ...],
|
| 239 |
+
) -> MaskData:
|
| 240 |
+
# Crop the image and calculate embeddings
|
| 241 |
+
x0, y0, x1, y1 = crop_box
|
| 242 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
| 243 |
+
cropped_im_size = cropped_im.shape[:2]
|
| 244 |
+
self.predictor.set_image(cropped_im)
|
| 245 |
+
|
| 246 |
+
# Get points for this crop
|
| 247 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
| 248 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
| 249 |
+
|
| 250 |
+
# Generate masks for this crop in batches
|
| 251 |
+
data = MaskData()
|
| 252 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
| 253 |
+
batch_data = self._process_batch(
|
| 254 |
+
points, cropped_im_size, crop_box, orig_size, normalize=True
|
| 255 |
+
)
|
| 256 |
+
data.cat(batch_data)
|
| 257 |
+
del batch_data
|
| 258 |
+
self.predictor.reset_predictor()
|
| 259 |
+
|
| 260 |
+
# Remove duplicates within this crop.
|
| 261 |
+
keep_by_nms = batched_nms(
|
| 262 |
+
data["boxes"].float(),
|
| 263 |
+
data["iou_preds"],
|
| 264 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 265 |
+
iou_threshold=self.box_nms_thresh,
|
| 266 |
+
)
|
| 267 |
+
data.filter(keep_by_nms)
|
| 268 |
+
|
| 269 |
+
# Return to the original image frame
|
| 270 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
| 271 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
| 272 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
| 273 |
+
|
| 274 |
+
return data
|
| 275 |
+
|
| 276 |
+
def _process_batch(
|
| 277 |
+
self,
|
| 278 |
+
points: np.ndarray,
|
| 279 |
+
im_size: Tuple[int, ...],
|
| 280 |
+
crop_box: List[int],
|
| 281 |
+
orig_size: Tuple[int, ...],
|
| 282 |
+
normalize=False,
|
| 283 |
+
) -> MaskData:
|
| 284 |
+
orig_h, orig_w = orig_size
|
| 285 |
+
|
| 286 |
+
# Run model on this batch
|
| 287 |
+
points = torch.as_tensor(points, device=self.predictor.device)
|
| 288 |
+
in_points = self.predictor._transforms.transform_coords(
|
| 289 |
+
points, normalize=normalize, orig_hw=im_size
|
| 290 |
+
)
|
| 291 |
+
in_labels = torch.ones(
|
| 292 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 293 |
+
)
|
| 294 |
+
masks, iou_preds, low_res_masks = self.predictor._predict(
|
| 295 |
+
in_points[:, None, :],
|
| 296 |
+
in_labels[:, None],
|
| 297 |
+
multimask_output=self.multimask_output,
|
| 298 |
+
return_logits=True,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Serialize predictions and store in MaskData
|
| 302 |
+
data = MaskData(
|
| 303 |
+
masks=masks.flatten(0, 1),
|
| 304 |
+
iou_preds=iou_preds.flatten(0, 1),
|
| 305 |
+
points=points.repeat_interleave(masks.shape[1], dim=0),
|
| 306 |
+
low_res_masks=low_res_masks.flatten(0, 1),
|
| 307 |
+
)
|
| 308 |
+
del masks
|
| 309 |
+
|
| 310 |
+
if not self.use_m2m:
|
| 311 |
+
# Filter by predicted IoU
|
| 312 |
+
if self.pred_iou_thresh > 0.0:
|
| 313 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 314 |
+
data.filter(keep_mask)
|
| 315 |
+
|
| 316 |
+
# Calculate and filter by stability score
|
| 317 |
+
data["stability_score"] = calculate_stability_score(
|
| 318 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 319 |
+
)
|
| 320 |
+
if self.stability_score_thresh > 0.0:
|
| 321 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 322 |
+
data.filter(keep_mask)
|
| 323 |
+
else:
|
| 324 |
+
# One step refinement using previous mask predictions
|
| 325 |
+
in_points = self.predictor._transforms.transform_coords(
|
| 326 |
+
data["points"], normalize=normalize, orig_hw=im_size
|
| 327 |
+
)
|
| 328 |
+
labels = torch.ones(
|
| 329 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 330 |
+
)
|
| 331 |
+
masks, ious = self.refine_with_m2m(
|
| 332 |
+
in_points, labels, data["low_res_masks"], self.points_per_batch
|
| 333 |
+
)
|
| 334 |
+
data["masks"] = masks.squeeze(1)
|
| 335 |
+
data["iou_preds"] = ious.squeeze(1)
|
| 336 |
+
|
| 337 |
+
if self.pred_iou_thresh > 0.0:
|
| 338 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 339 |
+
data.filter(keep_mask)
|
| 340 |
+
|
| 341 |
+
data["stability_score"] = calculate_stability_score(
|
| 342 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 343 |
+
)
|
| 344 |
+
if self.stability_score_thresh > 0.0:
|
| 345 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 346 |
+
data.filter(keep_mask)
|
| 347 |
+
|
| 348 |
+
# Threshold masks and calculate boxes
|
| 349 |
+
data["masks"] = data["masks"] > self.mask_threshold
|
| 350 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
| 351 |
+
|
| 352 |
+
# Filter boxes that touch crop boundaries
|
| 353 |
+
keep_mask = ~is_box_near_crop_edge(
|
| 354 |
+
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
| 355 |
+
)
|
| 356 |
+
if not torch.all(keep_mask):
|
| 357 |
+
data.filter(keep_mask)
|
| 358 |
+
|
| 359 |
+
# Compress to RLE
|
| 360 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
| 361 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
| 362 |
+
del data["masks"]
|
| 363 |
+
|
| 364 |
+
return data
|
| 365 |
+
|
| 366 |
+
@staticmethod
|
| 367 |
+
def postprocess_small_regions(
|
| 368 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
| 369 |
+
) -> MaskData:
|
| 370 |
+
"""
|
| 371 |
+
Removes small disconnected regions and holes in masks, then reruns
|
| 372 |
+
box NMS to remove any new duplicates.
|
| 373 |
+
|
| 374 |
+
Edits mask_data in place.
|
| 375 |
+
|
| 376 |
+
Requires open-cv as a dependency.
|
| 377 |
+
"""
|
| 378 |
+
if len(mask_data["rles"]) == 0:
|
| 379 |
+
return mask_data
|
| 380 |
+
|
| 381 |
+
# Filter small disconnected regions and holes
|
| 382 |
+
new_masks = []
|
| 383 |
+
scores = []
|
| 384 |
+
for rle in mask_data["rles"]:
|
| 385 |
+
mask = rle_to_mask(rle)
|
| 386 |
+
|
| 387 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
| 388 |
+
unchanged = not changed
|
| 389 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
| 390 |
+
unchanged = unchanged and not changed
|
| 391 |
+
|
| 392 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
| 393 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
| 394 |
+
# so NMS will prefer ones that didn't need postprocessing
|
| 395 |
+
scores.append(float(unchanged))
|
| 396 |
+
|
| 397 |
+
# Recalculate boxes and remove any new duplicates
|
| 398 |
+
masks = torch.cat(new_masks, dim=0)
|
| 399 |
+
boxes = batched_mask_to_box(masks)
|
| 400 |
+
keep_by_nms = batched_nms(
|
| 401 |
+
boxes.float(),
|
| 402 |
+
torch.as_tensor(scores),
|
| 403 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
| 404 |
+
iou_threshold=nms_thresh,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
# Only recalculate RLEs for masks that have changed
|
| 408 |
+
for i_mask in keep_by_nms:
|
| 409 |
+
if scores[i_mask] == 0.0:
|
| 410 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
| 411 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
| 412 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
| 413 |
+
mask_data.filter(keep_by_nms)
|
| 414 |
+
|
| 415 |
+
return mask_data
|
| 416 |
+
|
| 417 |
+
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
|
| 418 |
+
new_masks = []
|
| 419 |
+
new_iou_preds = []
|
| 420 |
+
|
| 421 |
+
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
|
| 422 |
+
points_per_batch, points, point_labels, low_res_masks
|
| 423 |
+
):
|
| 424 |
+
best_masks, best_iou_preds, _ = self.predictor._predict(
|
| 425 |
+
cur_points[:, None, :],
|
| 426 |
+
cur_point_labels[:, None],
|
| 427 |
+
mask_input=low_res_mask[:, None, :],
|
| 428 |
+
multimask_output=False,
|
| 429 |
+
return_logits=True,
|
| 430 |
+
)
|
| 431 |
+
new_masks.append(best_masks)
|
| 432 |
+
new_iou_preds.append(best_iou_preds)
|
| 433 |
+
masks = torch.cat(new_masks, dim=0)
|
| 434 |
+
return masks, torch.cat(new_iou_preds, dim=0)
|
RynnEC/third_parts/sam2/build_sam.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from hydra import compose
|
| 11 |
+
from hydra.utils import instantiate
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def build_sam2(
|
| 16 |
+
config_file,
|
| 17 |
+
ckpt_path=None,
|
| 18 |
+
device="cuda",
|
| 19 |
+
mode="eval",
|
| 20 |
+
hydra_overrides_extra=[],
|
| 21 |
+
apply_postprocessing=True,
|
| 22 |
+
):
|
| 23 |
+
|
| 24 |
+
if apply_postprocessing:
|
| 25 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 26 |
+
hydra_overrides_extra += [
|
| 27 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
| 28 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
| 29 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 30 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 31 |
+
]
|
| 32 |
+
# Read config and init model
|
| 33 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
| 34 |
+
OmegaConf.resolve(cfg)
|
| 35 |
+
model = instantiate(cfg.model, _recursive_=True)
|
| 36 |
+
_load_checkpoint(model, ckpt_path)
|
| 37 |
+
model = model.to(device)
|
| 38 |
+
if mode == "eval":
|
| 39 |
+
model.eval()
|
| 40 |
+
return model
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def build_sam2_video_predictor(
|
| 44 |
+
config_file,
|
| 45 |
+
ckpt_path=None,
|
| 46 |
+
device="cuda",
|
| 47 |
+
mode="eval",
|
| 48 |
+
hydra_overrides_extra=[],
|
| 49 |
+
apply_postprocessing=True,
|
| 50 |
+
):
|
| 51 |
+
hydra_overrides = [
|
| 52 |
+
"++model._target_=third_parts.sam2.sam2_video_predictor.SAM2VideoPredictor",
|
| 53 |
+
]
|
| 54 |
+
if apply_postprocessing:
|
| 55 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 56 |
+
hydra_overrides_extra += [
|
| 57 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
| 58 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
| 59 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 60 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 61 |
+
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
|
| 62 |
+
"++model.binarize_mask_from_pts_for_mem_enc=true",
|
| 63 |
+
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
|
| 64 |
+
"++model.fill_hole_area=8",
|
| 65 |
+
]
|
| 66 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
| 67 |
+
|
| 68 |
+
# Read config and init model
|
| 69 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides)
|
| 70 |
+
OmegaConf.resolve(cfg)
|
| 71 |
+
model = instantiate(cfg.model, _recursive_=True)
|
| 72 |
+
_load_checkpoint(model, ckpt_path)
|
| 73 |
+
model = model.to(device)
|
| 74 |
+
if mode == "eval":
|
| 75 |
+
model.eval()
|
| 76 |
+
return model
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _load_checkpoint(model, ckpt_path):
|
| 80 |
+
if ckpt_path is not None:
|
| 81 |
+
sd = torch.load(ckpt_path, map_location="cpu")["model"]
|
| 82 |
+
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
| 83 |
+
if missing_keys:
|
| 84 |
+
logging.error(missing_keys)
|
| 85 |
+
raise RuntimeError()
|
| 86 |
+
if unexpected_keys:
|
| 87 |
+
logging.error(unexpected_keys)
|
| 88 |
+
raise RuntimeError()
|
| 89 |
+
logging.info("Loaded checkpoint sucessfully")
|
RynnEC/third_parts/sam2/csrc/connected_components.cu
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// adapted from https://github.com/zsef123/Connected_components_PyTorch
|
| 8 |
+
// with license found in the LICENSE_cctorch file in the root directory.
|
| 9 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 10 |
+
#include <cuda.h>
|
| 11 |
+
#include <cuda_runtime.h>
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
#include <torch/script.h>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
// 2d
|
| 17 |
+
#define BLOCK_ROWS 16
|
| 18 |
+
#define BLOCK_COLS 16
|
| 19 |
+
|
| 20 |
+
namespace cc2d {
|
| 21 |
+
|
| 22 |
+
template <typename T>
|
| 23 |
+
__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
|
| 24 |
+
return (bitmap >> pos) & 1;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
__device__ int32_t find(const int32_t* s_buf, int32_t n) {
|
| 28 |
+
while (s_buf[n] != n)
|
| 29 |
+
n = s_buf[n];
|
| 30 |
+
return n;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
|
| 34 |
+
const int32_t id = n;
|
| 35 |
+
while (s_buf[n] != n) {
|
| 36 |
+
n = s_buf[n];
|
| 37 |
+
s_buf[id] = n;
|
| 38 |
+
}
|
| 39 |
+
return n;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
|
| 43 |
+
bool done;
|
| 44 |
+
do {
|
| 45 |
+
a = find(s_buf, a);
|
| 46 |
+
b = find(s_buf, b);
|
| 47 |
+
|
| 48 |
+
if (a < b) {
|
| 49 |
+
int32_t old = atomicMin(s_buf + b, a);
|
| 50 |
+
done = (old == b);
|
| 51 |
+
b = old;
|
| 52 |
+
} else if (b < a) {
|
| 53 |
+
int32_t old = atomicMin(s_buf + a, b);
|
| 54 |
+
done = (old == a);
|
| 55 |
+
a = old;
|
| 56 |
+
} else
|
| 57 |
+
done = true;
|
| 58 |
+
|
| 59 |
+
} while (!done);
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
__global__ void
|
| 63 |
+
init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
|
| 64 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 65 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 66 |
+
const uint32_t idx = row * W + col;
|
| 67 |
+
|
| 68 |
+
if (row < H && col < W)
|
| 69 |
+
label[idx] = idx;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
__global__ void
|
| 73 |
+
merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
|
| 74 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 75 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 76 |
+
const uint32_t idx = row * W + col;
|
| 77 |
+
|
| 78 |
+
if (row >= H || col >= W)
|
| 79 |
+
return;
|
| 80 |
+
|
| 81 |
+
uint32_t P = 0;
|
| 82 |
+
|
| 83 |
+
if (img[idx])
|
| 84 |
+
P |= 0x777;
|
| 85 |
+
if (row + 1 < H && img[idx + W])
|
| 86 |
+
P |= 0x777 << 4;
|
| 87 |
+
if (col + 1 < W && img[idx + 1])
|
| 88 |
+
P |= 0x777 << 1;
|
| 89 |
+
|
| 90 |
+
if (col == 0)
|
| 91 |
+
P &= 0xEEEE;
|
| 92 |
+
if (col + 1 >= W)
|
| 93 |
+
P &= 0x3333;
|
| 94 |
+
else if (col + 2 >= W)
|
| 95 |
+
P &= 0x7777;
|
| 96 |
+
|
| 97 |
+
if (row == 0)
|
| 98 |
+
P &= 0xFFF0;
|
| 99 |
+
if (row + 1 >= H)
|
| 100 |
+
P &= 0xFF;
|
| 101 |
+
|
| 102 |
+
if (P > 0) {
|
| 103 |
+
// If need check about top-left pixel(if flag the first bit) and hit the
|
| 104 |
+
// top-left pixel
|
| 105 |
+
if (hasBit(P, 0) && img[idx - W - 1]) {
|
| 106 |
+
union_(label, idx, idx - 2 * W - 2); // top left block
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
|
| 110 |
+
union_(label, idx, idx - 2 * W); // top bottom block
|
| 111 |
+
|
| 112 |
+
if (hasBit(P, 3) && img[idx + 2 - W])
|
| 113 |
+
union_(label, idx, idx - 2 * W + 2); // top right block
|
| 114 |
+
|
| 115 |
+
if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
|
| 116 |
+
union_(label, idx, idx - 2); // just left block
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
|
| 121 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 122 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 123 |
+
const uint32_t idx = row * W + col;
|
| 124 |
+
|
| 125 |
+
if (row < H && col < W)
|
| 126 |
+
find_n_compress(label, idx);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
__global__ void final_labeling(
|
| 130 |
+
const uint8_t* img,
|
| 131 |
+
int32_t* label,
|
| 132 |
+
const int32_t W,
|
| 133 |
+
const int32_t H) {
|
| 134 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 135 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 136 |
+
const uint32_t idx = row * W + col;
|
| 137 |
+
|
| 138 |
+
if (row >= H || col >= W)
|
| 139 |
+
return;
|
| 140 |
+
|
| 141 |
+
int32_t y = label[idx] + 1;
|
| 142 |
+
|
| 143 |
+
if (img[idx])
|
| 144 |
+
label[idx] = y;
|
| 145 |
+
else
|
| 146 |
+
label[idx] = 0;
|
| 147 |
+
|
| 148 |
+
if (col + 1 < W) {
|
| 149 |
+
if (img[idx + 1])
|
| 150 |
+
label[idx + 1] = y;
|
| 151 |
+
else
|
| 152 |
+
label[idx + 1] = 0;
|
| 153 |
+
|
| 154 |
+
if (row + 1 < H) {
|
| 155 |
+
if (img[idx + W + 1])
|
| 156 |
+
label[idx + W + 1] = y;
|
| 157 |
+
else
|
| 158 |
+
label[idx + W + 1] = 0;
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
if (row + 1 < H) {
|
| 163 |
+
if (img[idx + W])
|
| 164 |
+
label[idx + W] = y;
|
| 165 |
+
else
|
| 166 |
+
label[idx + W] = 0;
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
__global__ void init_counting(
|
| 171 |
+
const int32_t* label,
|
| 172 |
+
int32_t* count_init,
|
| 173 |
+
const int32_t W,
|
| 174 |
+
const int32_t H) {
|
| 175 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
|
| 176 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
|
| 177 |
+
const uint32_t idx = row * W + col;
|
| 178 |
+
|
| 179 |
+
if (row >= H || col >= W)
|
| 180 |
+
return;
|
| 181 |
+
|
| 182 |
+
int32_t y = label[idx];
|
| 183 |
+
if (y > 0) {
|
| 184 |
+
int32_t count_idx = y - 1;
|
| 185 |
+
atomicAdd(count_init + count_idx, 1);
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
__global__ void final_counting(
|
| 190 |
+
const int32_t* label,
|
| 191 |
+
const int32_t* count_init,
|
| 192 |
+
int32_t* count_final,
|
| 193 |
+
const int32_t W,
|
| 194 |
+
const int32_t H) {
|
| 195 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
|
| 196 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
|
| 197 |
+
const uint32_t idx = row * W + col;
|
| 198 |
+
|
| 199 |
+
if (row >= H || col >= W)
|
| 200 |
+
return;
|
| 201 |
+
|
| 202 |
+
int32_t y = label[idx];
|
| 203 |
+
if (y > 0) {
|
| 204 |
+
int32_t count_idx = y - 1;
|
| 205 |
+
count_final[idx] = count_init[count_idx];
|
| 206 |
+
} else {
|
| 207 |
+
count_final[idx] = 0;
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
} // namespace cc2d
|
| 212 |
+
|
| 213 |
+
std::vector<torch::Tensor> get_connected_componnets(
|
| 214 |
+
const torch::Tensor& inputs) {
|
| 215 |
+
AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
|
| 216 |
+
AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
|
| 217 |
+
AT_ASSERTM(
|
| 218 |
+
inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
|
| 219 |
+
|
| 220 |
+
const uint32_t N = inputs.size(0);
|
| 221 |
+
const uint32_t C = inputs.size(1);
|
| 222 |
+
const uint32_t H = inputs.size(2);
|
| 223 |
+
const uint32_t W = inputs.size(3);
|
| 224 |
+
|
| 225 |
+
AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
|
| 226 |
+
AT_ASSERTM((H % 2) == 0, "height must be an even number");
|
| 227 |
+
AT_ASSERTM((W % 2) == 0, "width must be an even number");
|
| 228 |
+
|
| 229 |
+
// label must be uint32_t
|
| 230 |
+
auto label_options =
|
| 231 |
+
torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
|
| 232 |
+
torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
|
| 233 |
+
torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
|
| 234 |
+
torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
|
| 235 |
+
|
| 236 |
+
dim3 grid = dim3(
|
| 237 |
+
((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
|
| 238 |
+
((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
|
| 239 |
+
dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
|
| 240 |
+
dim3 grid_count =
|
| 241 |
+
dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
|
| 242 |
+
dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
|
| 243 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 244 |
+
|
| 245 |
+
for (int n = 0; n < N; n++) {
|
| 246 |
+
uint32_t offset = n * H * W;
|
| 247 |
+
|
| 248 |
+
cc2d::init_labeling<<<grid, block, 0, stream>>>(
|
| 249 |
+
labels.data_ptr<int32_t>() + offset, W, H);
|
| 250 |
+
cc2d::merge<<<grid, block, 0, stream>>>(
|
| 251 |
+
inputs.data_ptr<uint8_t>() + offset,
|
| 252 |
+
labels.data_ptr<int32_t>() + offset,
|
| 253 |
+
W,
|
| 254 |
+
H);
|
| 255 |
+
cc2d::compression<<<grid, block, 0, stream>>>(
|
| 256 |
+
labels.data_ptr<int32_t>() + offset, W, H);
|
| 257 |
+
cc2d::final_labeling<<<grid, block, 0, stream>>>(
|
| 258 |
+
inputs.data_ptr<uint8_t>() + offset,
|
| 259 |
+
labels.data_ptr<int32_t>() + offset,
|
| 260 |
+
W,
|
| 261 |
+
H);
|
| 262 |
+
|
| 263 |
+
// get the counting of each pixel
|
| 264 |
+
cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
|
| 265 |
+
labels.data_ptr<int32_t>() + offset,
|
| 266 |
+
counts_init.data_ptr<int32_t>() + offset,
|
| 267 |
+
W,
|
| 268 |
+
H);
|
| 269 |
+
cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
|
| 270 |
+
labels.data_ptr<int32_t>() + offset,
|
| 271 |
+
counts_init.data_ptr<int32_t>() + offset,
|
| 272 |
+
counts_final.data_ptr<int32_t>() + offset,
|
| 273 |
+
W,
|
| 274 |
+
H);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// returned values are [labels, counts]
|
| 278 |
+
std::vector<torch::Tensor> outputs;
|
| 279 |
+
outputs.push_back(labels);
|
| 280 |
+
outputs.push_back(counts_final);
|
| 281 |
+
return outputs;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 285 |
+
m.def(
|
| 286 |
+
"get_connected_componnets",
|
| 287 |
+
&get_connected_componnets,
|
| 288 |
+
"get_connected_componnets");
|
| 289 |
+
}
|
RynnEC/third_parts/sam2/modeling/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
RynnEC/third_parts/sam2/modeling/backbones/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
RynnEC/third_parts/sam2/modeling/backbones/hieradet.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import List, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from third_parts.sam2.modeling.backbones.utils import (
|
| 15 |
+
PatchEmbed,
|
| 16 |
+
window_partition,
|
| 17 |
+
window_unpartition,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from third_parts.sam2.modeling.sam2_utils import DropPath, MLP
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
| 24 |
+
if pool is None:
|
| 25 |
+
return x
|
| 26 |
+
# (B, H, W, C) -> (B, C, H, W)
|
| 27 |
+
x = x.permute(0, 3, 1, 2)
|
| 28 |
+
x = pool(x)
|
| 29 |
+
# (B, C, H', W') -> (B, H', W', C)
|
| 30 |
+
x = x.permute(0, 2, 3, 1)
|
| 31 |
+
if norm:
|
| 32 |
+
x = norm(x)
|
| 33 |
+
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MultiScaleAttention(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
dim: int,
|
| 41 |
+
dim_out: int,
|
| 42 |
+
num_heads: int,
|
| 43 |
+
q_pool: nn.Module = None,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
|
| 47 |
+
self.dim = dim
|
| 48 |
+
self.dim_out = dim_out
|
| 49 |
+
|
| 50 |
+
self.num_heads = num_heads
|
| 51 |
+
head_dim = dim_out // num_heads
|
| 52 |
+
self.scale = head_dim**-0.5
|
| 53 |
+
|
| 54 |
+
self.q_pool = q_pool
|
| 55 |
+
self.qkv = nn.Linear(dim, dim_out * 3)
|
| 56 |
+
self.proj = nn.Linear(dim_out, dim_out)
|
| 57 |
+
|
| 58 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 59 |
+
B, H, W, _ = x.shape
|
| 60 |
+
# qkv with shape (B, H * W, 3, nHead, C)
|
| 61 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
| 62 |
+
# q, k, v with shape (B, H * W, nheads, C)
|
| 63 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 64 |
+
|
| 65 |
+
# Q pooling (for downsample at stage changes)
|
| 66 |
+
if self.q_pool:
|
| 67 |
+
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
|
| 68 |
+
H, W = q.shape[1:3] # downsampled shape
|
| 69 |
+
q = q.reshape(B, H * W, self.num_heads, -1)
|
| 70 |
+
|
| 71 |
+
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
|
| 72 |
+
x = F.scaled_dot_product_attention(
|
| 73 |
+
q.transpose(1, 2),
|
| 74 |
+
k.transpose(1, 2),
|
| 75 |
+
v.transpose(1, 2),
|
| 76 |
+
)
|
| 77 |
+
# Transpose back
|
| 78 |
+
x = x.transpose(1, 2)
|
| 79 |
+
x = x.reshape(B, H, W, -1)
|
| 80 |
+
|
| 81 |
+
x = self.proj(x)
|
| 82 |
+
|
| 83 |
+
return x
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MultiScaleBlock(nn.Module):
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
dim: int,
|
| 90 |
+
dim_out: int,
|
| 91 |
+
num_heads: int,
|
| 92 |
+
mlp_ratio: float = 4.0,
|
| 93 |
+
drop_path: float = 0.0,
|
| 94 |
+
norm_layer: Union[nn.Module, str] = "LayerNorm",
|
| 95 |
+
q_stride: Tuple[int, int] = None,
|
| 96 |
+
act_layer: nn.Module = nn.GELU,
|
| 97 |
+
window_size: int = 0,
|
| 98 |
+
):
|
| 99 |
+
super().__init__()
|
| 100 |
+
|
| 101 |
+
if isinstance(norm_layer, str):
|
| 102 |
+
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
|
| 103 |
+
|
| 104 |
+
self.dim = dim
|
| 105 |
+
self.dim_out = dim_out
|
| 106 |
+
self.norm1 = norm_layer(dim)
|
| 107 |
+
|
| 108 |
+
self.window_size = window_size
|
| 109 |
+
|
| 110 |
+
self.pool, self.q_stride = None, q_stride
|
| 111 |
+
if self.q_stride:
|
| 112 |
+
self.pool = nn.MaxPool2d(
|
| 113 |
+
kernel_size=q_stride, stride=q_stride, ceil_mode=False
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
self.attn = MultiScaleAttention(
|
| 117 |
+
dim,
|
| 118 |
+
dim_out,
|
| 119 |
+
num_heads=num_heads,
|
| 120 |
+
q_pool=self.pool,
|
| 121 |
+
)
|
| 122 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 123 |
+
|
| 124 |
+
self.norm2 = norm_layer(dim_out)
|
| 125 |
+
self.mlp = MLP(
|
| 126 |
+
dim_out,
|
| 127 |
+
int(dim_out * mlp_ratio),
|
| 128 |
+
dim_out,
|
| 129 |
+
num_layers=2,
|
| 130 |
+
activation=act_layer,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
if dim != dim_out:
|
| 134 |
+
self.proj = nn.Linear(dim, dim_out)
|
| 135 |
+
|
| 136 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 137 |
+
shortcut = x # B, H, W, C
|
| 138 |
+
x = self.norm1(x)
|
| 139 |
+
|
| 140 |
+
# Skip connection
|
| 141 |
+
if self.dim != self.dim_out:
|
| 142 |
+
shortcut = do_pool(self.proj(x), self.pool)
|
| 143 |
+
|
| 144 |
+
# Window partition
|
| 145 |
+
window_size = self.window_size
|
| 146 |
+
if window_size > 0:
|
| 147 |
+
H, W = x.shape[1], x.shape[2]
|
| 148 |
+
x, pad_hw = window_partition(x, window_size)
|
| 149 |
+
|
| 150 |
+
# Window Attention + Q Pooling (if stage change)
|
| 151 |
+
x = self.attn(x)
|
| 152 |
+
if self.q_stride:
|
| 153 |
+
# Shapes have changed due to Q pooling
|
| 154 |
+
window_size = self.window_size // self.q_stride[0]
|
| 155 |
+
H, W = shortcut.shape[1:3]
|
| 156 |
+
|
| 157 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 158 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 159 |
+
pad_hw = (H + pad_h, W + pad_w)
|
| 160 |
+
|
| 161 |
+
# Reverse window partition
|
| 162 |
+
if self.window_size > 0:
|
| 163 |
+
x = window_unpartition(x, window_size, pad_hw, (H, W))
|
| 164 |
+
|
| 165 |
+
x = shortcut + self.drop_path(x)
|
| 166 |
+
# MLP
|
| 167 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 168 |
+
return x
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class Hiera(nn.Module):
|
| 172 |
+
"""
|
| 173 |
+
Reference: https://arxiv.org/abs/2306.00989
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
def __init__(
|
| 177 |
+
self,
|
| 178 |
+
embed_dim: int = 96, # initial embed dim
|
| 179 |
+
num_heads: int = 1, # initial number of heads
|
| 180 |
+
drop_path_rate: float = 0.0, # stochastic depth
|
| 181 |
+
q_pool: int = 3, # number of q_pool stages
|
| 182 |
+
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
| 183 |
+
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
| 184 |
+
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
| 185 |
+
head_mul: float = 2.0, # head_mul factor at stage shift
|
| 186 |
+
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
|
| 187 |
+
# window size per stage, when not using global att.
|
| 188 |
+
window_spec: Tuple[int, ...] = (
|
| 189 |
+
8,
|
| 190 |
+
4,
|
| 191 |
+
14,
|
| 192 |
+
7,
|
| 193 |
+
),
|
| 194 |
+
# global attn in these blocks
|
| 195 |
+
global_att_blocks: Tuple[int, ...] = (
|
| 196 |
+
12,
|
| 197 |
+
16,
|
| 198 |
+
20,
|
| 199 |
+
),
|
| 200 |
+
return_interm_layers=True, # return feats from every stage
|
| 201 |
+
):
|
| 202 |
+
super().__init__()
|
| 203 |
+
|
| 204 |
+
assert len(stages) == len(window_spec)
|
| 205 |
+
self.window_spec = window_spec
|
| 206 |
+
|
| 207 |
+
depth = sum(stages)
|
| 208 |
+
self.q_stride = q_stride
|
| 209 |
+
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
| 210 |
+
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
| 211 |
+
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
| 212 |
+
self.return_interm_layers = return_interm_layers
|
| 213 |
+
|
| 214 |
+
self.patch_embed = PatchEmbed(
|
| 215 |
+
embed_dim=embed_dim,
|
| 216 |
+
)
|
| 217 |
+
# Which blocks have global att?
|
| 218 |
+
self.global_att_blocks = global_att_blocks
|
| 219 |
+
|
| 220 |
+
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
| 221 |
+
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
| 222 |
+
self.pos_embed = nn.Parameter(
|
| 223 |
+
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
|
| 224 |
+
)
|
| 225 |
+
self.pos_embed_window = nn.Parameter(
|
| 226 |
+
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
dpr = [
|
| 230 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 231 |
+
] # stochastic depth decay rule
|
| 232 |
+
|
| 233 |
+
cur_stage = 1
|
| 234 |
+
self.blocks = nn.ModuleList()
|
| 235 |
+
|
| 236 |
+
for i in range(depth):
|
| 237 |
+
dim_out = embed_dim
|
| 238 |
+
# lags by a block, so first block of
|
| 239 |
+
# next stage uses an initial window size
|
| 240 |
+
# of previous stage and final window size of current stage
|
| 241 |
+
window_size = self.window_spec[cur_stage - 1]
|
| 242 |
+
|
| 243 |
+
if self.global_att_blocks is not None:
|
| 244 |
+
window_size = 0 if i in self.global_att_blocks else window_size
|
| 245 |
+
|
| 246 |
+
if i - 1 in self.stage_ends:
|
| 247 |
+
dim_out = int(embed_dim * dim_mul)
|
| 248 |
+
num_heads = int(num_heads * head_mul)
|
| 249 |
+
cur_stage += 1
|
| 250 |
+
|
| 251 |
+
block = MultiScaleBlock(
|
| 252 |
+
dim=embed_dim,
|
| 253 |
+
dim_out=dim_out,
|
| 254 |
+
num_heads=num_heads,
|
| 255 |
+
drop_path=dpr[i],
|
| 256 |
+
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
| 257 |
+
window_size=window_size,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
embed_dim = dim_out
|
| 261 |
+
self.blocks.append(block)
|
| 262 |
+
|
| 263 |
+
self.channel_list = (
|
| 264 |
+
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
| 265 |
+
if return_interm_layers
|
| 266 |
+
else [self.blocks[-1].dim_out]
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
| 270 |
+
h, w = hw
|
| 271 |
+
window_embed = self.pos_embed_window
|
| 272 |
+
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
| 273 |
+
pos_embed = pos_embed + window_embed.tile(
|
| 274 |
+
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
|
| 275 |
+
)
|
| 276 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
| 277 |
+
return pos_embed
|
| 278 |
+
|
| 279 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 280 |
+
x = self.patch_embed(x)
|
| 281 |
+
# x: (B, H, W, C)
|
| 282 |
+
|
| 283 |
+
# Add pos embed
|
| 284 |
+
x = x + self._get_pos_embed(x.shape[1:3])
|
| 285 |
+
|
| 286 |
+
outputs = []
|
| 287 |
+
for i, blk in enumerate(self.blocks):
|
| 288 |
+
x = blk(x)
|
| 289 |
+
if (i == self.stage_ends[-1]) or (
|
| 290 |
+
i in self.stage_ends and self.return_interm_layers
|
| 291 |
+
):
|
| 292 |
+
feats = x.permute(0, 3, 1, 2)
|
| 293 |
+
outputs.append(feats)
|
| 294 |
+
|
| 295 |
+
return outputs
|
RynnEC/third_parts/sam2/modeling/backbones/image_encoder.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ImageEncoder(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
trunk: nn.Module,
|
| 18 |
+
neck: nn.Module,
|
| 19 |
+
scalp: int = 0,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.trunk = trunk
|
| 23 |
+
self.neck = neck
|
| 24 |
+
self.scalp = scalp
|
| 25 |
+
assert (
|
| 26 |
+
self.trunk.channel_list == self.neck.backbone_channel_list
|
| 27 |
+
), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
|
| 28 |
+
|
| 29 |
+
def forward(self, sample: torch.Tensor):
|
| 30 |
+
# Forward through backbone
|
| 31 |
+
features, pos = self.neck(self.trunk(sample))
|
| 32 |
+
if self.scalp > 0:
|
| 33 |
+
# Discard the lowest resolution features
|
| 34 |
+
features, pos = features[: -self.scalp], pos[: -self.scalp]
|
| 35 |
+
|
| 36 |
+
src = features[-1]
|
| 37 |
+
output = {
|
| 38 |
+
"vision_features": src,
|
| 39 |
+
"vision_pos_enc": pos,
|
| 40 |
+
"backbone_fpn": features,
|
| 41 |
+
}
|
| 42 |
+
return output
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class FpnNeck(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
A modified variant of Feature Pyramid Network (FPN) neck
|
| 48 |
+
(we remove output conv and also do bicubic interpolation similar to ViT
|
| 49 |
+
pos embed interpolation)
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
position_encoding: nn.Module,
|
| 55 |
+
d_model: int,
|
| 56 |
+
backbone_channel_list: List[int],
|
| 57 |
+
kernel_size: int = 1,
|
| 58 |
+
stride: int = 1,
|
| 59 |
+
padding: int = 0,
|
| 60 |
+
fpn_interp_model: str = "bilinear",
|
| 61 |
+
fuse_type: str = "sum",
|
| 62 |
+
fpn_top_down_levels: Optional[List[int]] = None,
|
| 63 |
+
):
|
| 64 |
+
"""Initialize the neck
|
| 65 |
+
:param trunk: the backbone
|
| 66 |
+
:param position_encoding: the positional encoding to use
|
| 67 |
+
:param d_model: the dimension of the model
|
| 68 |
+
:param neck_norm: the normalization to use
|
| 69 |
+
"""
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.position_encoding = position_encoding
|
| 72 |
+
self.convs = nn.ModuleList()
|
| 73 |
+
self.backbone_channel_list = backbone_channel_list
|
| 74 |
+
for dim in backbone_channel_list:
|
| 75 |
+
current = nn.Sequential()
|
| 76 |
+
current.add_module(
|
| 77 |
+
"conv",
|
| 78 |
+
nn.Conv2d(
|
| 79 |
+
in_channels=dim,
|
| 80 |
+
out_channels=d_model,
|
| 81 |
+
kernel_size=kernel_size,
|
| 82 |
+
stride=stride,
|
| 83 |
+
padding=padding,
|
| 84 |
+
),
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
self.convs.append(current)
|
| 88 |
+
self.fpn_interp_model = fpn_interp_model
|
| 89 |
+
assert fuse_type in ["sum", "avg"]
|
| 90 |
+
self.fuse_type = fuse_type
|
| 91 |
+
|
| 92 |
+
# levels to have top-down features in its outputs
|
| 93 |
+
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
|
| 94 |
+
# have top-down propagation, while outputs of level 0 and level 1 have only
|
| 95 |
+
# lateral features from the same backbone level.
|
| 96 |
+
if fpn_top_down_levels is None:
|
| 97 |
+
# default is to have top-down features on all levels
|
| 98 |
+
fpn_top_down_levels = range(len(self.convs))
|
| 99 |
+
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
| 100 |
+
|
| 101 |
+
def forward(self, xs: List[torch.Tensor]):
|
| 102 |
+
|
| 103 |
+
out = [None] * len(self.convs)
|
| 104 |
+
pos = [None] * len(self.convs)
|
| 105 |
+
assert len(xs) == len(self.convs)
|
| 106 |
+
# fpn forward pass
|
| 107 |
+
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
|
| 108 |
+
prev_features = None
|
| 109 |
+
# forward in top-down order (from low to high resolution)
|
| 110 |
+
n = len(self.convs) - 1
|
| 111 |
+
for i in range(n, -1, -1):
|
| 112 |
+
x = xs[i]
|
| 113 |
+
lateral_features = self.convs[n - i](x)
|
| 114 |
+
if i in self.fpn_top_down_levels and prev_features is not None:
|
| 115 |
+
top_down_features = F.interpolate(
|
| 116 |
+
prev_features.to(dtype=torch.float32),
|
| 117 |
+
scale_factor=2.0,
|
| 118 |
+
mode=self.fpn_interp_model,
|
| 119 |
+
align_corners=(
|
| 120 |
+
None if self.fpn_interp_model == "nearest" else False
|
| 121 |
+
),
|
| 122 |
+
antialias=False,
|
| 123 |
+
)
|
| 124 |
+
prev_features = lateral_features + top_down_features
|
| 125 |
+
if self.fuse_type == "avg":
|
| 126 |
+
prev_features /= 2
|
| 127 |
+
else:
|
| 128 |
+
prev_features = lateral_features
|
| 129 |
+
x_out = prev_features
|
| 130 |
+
out[i] = x_out
|
| 131 |
+
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
|
| 132 |
+
|
| 133 |
+
return out, pos
|
RynnEC/third_parts/sam2/modeling/backbones/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Some utilities for backbones, in particular for windowing"""
|
| 8 |
+
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def window_partition(x, window_size):
|
| 17 |
+
"""
|
| 18 |
+
Partition into non-overlapping windows with padding if needed.
|
| 19 |
+
Args:
|
| 20 |
+
x (tensor): input tokens with [B, H, W, C].
|
| 21 |
+
window_size (int): window size.
|
| 22 |
+
Returns:
|
| 23 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
| 24 |
+
(Hp, Wp): padded height and width before partition
|
| 25 |
+
"""
|
| 26 |
+
B, H, W, C = x.shape
|
| 27 |
+
|
| 28 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 29 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 30 |
+
if pad_h > 0 or pad_w > 0:
|
| 31 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
| 32 |
+
Hp, Wp = H + pad_h, W + pad_w
|
| 33 |
+
|
| 34 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
| 35 |
+
windows = (
|
| 36 |
+
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 37 |
+
)
|
| 38 |
+
return windows, (Hp, Wp)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def window_unpartition(windows, window_size, pad_hw, hw):
|
| 42 |
+
"""
|
| 43 |
+
Window unpartition into original sequences and removing padding.
|
| 44 |
+
Args:
|
| 45 |
+
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
| 46 |
+
window_size (int): window size.
|
| 47 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
| 48 |
+
hw (Tuple): original height and width (H, W) before padding.
|
| 49 |
+
Returns:
|
| 50 |
+
x: unpartitioned sequences with [B, H, W, C].
|
| 51 |
+
"""
|
| 52 |
+
Hp, Wp = pad_hw
|
| 53 |
+
H, W = hw
|
| 54 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 55 |
+
x = windows.view(
|
| 56 |
+
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
| 57 |
+
)
|
| 58 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
| 59 |
+
|
| 60 |
+
if Hp > H or Wp > W:
|
| 61 |
+
x = x[:, :H, :W, :].contiguous()
|
| 62 |
+
return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class PatchEmbed(nn.Module):
|
| 66 |
+
"""
|
| 67 |
+
Image to Patch Embedding.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
kernel_size: Tuple[int, ...] = (7, 7),
|
| 73 |
+
stride: Tuple[int, ...] = (4, 4),
|
| 74 |
+
padding: Tuple[int, ...] = (3, 3),
|
| 75 |
+
in_chans: int = 3,
|
| 76 |
+
embed_dim: int = 768,
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
Args:
|
| 80 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
| 81 |
+
stride (Tuple): stride of the projection layer.
|
| 82 |
+
padding (Tuple): padding size of the projection layer.
|
| 83 |
+
in_chans (int): Number of input image channels.
|
| 84 |
+
embed_dim (int): embed_dim (int): Patch embedding dimension.
|
| 85 |
+
"""
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.proj = nn.Conv2d(
|
| 88 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 92 |
+
x = self.proj(x)
|
| 93 |
+
# B C H W -> B H W C
|
| 94 |
+
x = x.permute(0, 2, 3, 1)
|
| 95 |
+
return x
|
RynnEC/third_parts/sam2/modeling/memory_attention.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn, Tensor
|
| 11 |
+
|
| 12 |
+
from third_parts.sam2.modeling.sam.transformer import RoPEAttention
|
| 13 |
+
|
| 14 |
+
from third_parts.sam2.modeling.sam2_utils import get_activation_fn, get_clones
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MemoryAttentionLayer(nn.Module):
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
activation: str,
|
| 22 |
+
cross_attention: nn.Module,
|
| 23 |
+
d_model: int,
|
| 24 |
+
dim_feedforward: int,
|
| 25 |
+
dropout: float,
|
| 26 |
+
pos_enc_at_attn: bool,
|
| 27 |
+
pos_enc_at_cross_attn_keys: bool,
|
| 28 |
+
pos_enc_at_cross_attn_queries: bool,
|
| 29 |
+
self_attention: nn.Module,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.d_model = d_model
|
| 33 |
+
self.dim_feedforward = dim_feedforward
|
| 34 |
+
self.dropout_value = dropout
|
| 35 |
+
self.self_attn = self_attention
|
| 36 |
+
self.cross_attn_image = cross_attention
|
| 37 |
+
|
| 38 |
+
# Implementation of Feedforward model
|
| 39 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 40 |
+
self.dropout = nn.Dropout(dropout)
|
| 41 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 42 |
+
|
| 43 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 44 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 45 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 46 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 47 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 48 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 49 |
+
|
| 50 |
+
self.activation_str = activation
|
| 51 |
+
self.activation = get_activation_fn(activation)
|
| 52 |
+
|
| 53 |
+
# Where to add pos enc
|
| 54 |
+
self.pos_enc_at_attn = pos_enc_at_attn
|
| 55 |
+
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
|
| 56 |
+
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
|
| 57 |
+
|
| 58 |
+
def _forward_sa(self, tgt, query_pos):
|
| 59 |
+
# Self-Attention
|
| 60 |
+
tgt2 = self.norm1(tgt)
|
| 61 |
+
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
|
| 62 |
+
tgt2 = self.self_attn(q, k, v=tgt2)
|
| 63 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 64 |
+
return tgt
|
| 65 |
+
|
| 66 |
+
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
|
| 67 |
+
kwds = {}
|
| 68 |
+
if num_k_exclude_rope > 0:
|
| 69 |
+
assert isinstance(self.cross_attn_image, RoPEAttention)
|
| 70 |
+
kwds = {"num_k_exclude_rope": num_k_exclude_rope}
|
| 71 |
+
|
| 72 |
+
# Cross-Attention
|
| 73 |
+
tgt2 = self.norm2(tgt)
|
| 74 |
+
tgt2 = self.cross_attn_image(
|
| 75 |
+
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
|
| 76 |
+
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
|
| 77 |
+
v=memory,
|
| 78 |
+
**kwds,
|
| 79 |
+
)
|
| 80 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 81 |
+
return tgt
|
| 82 |
+
|
| 83 |
+
def forward(
|
| 84 |
+
self,
|
| 85 |
+
tgt,
|
| 86 |
+
memory,
|
| 87 |
+
pos: Optional[Tensor] = None,
|
| 88 |
+
query_pos: Optional[Tensor] = None,
|
| 89 |
+
num_k_exclude_rope: int = 0,
|
| 90 |
+
) -> torch.Tensor:
|
| 91 |
+
|
| 92 |
+
# Self-Attn, Cross-Attn
|
| 93 |
+
tgt = self._forward_sa(tgt, query_pos)
|
| 94 |
+
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
|
| 95 |
+
# MLP
|
| 96 |
+
tgt2 = self.norm3(tgt)
|
| 97 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 98 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 99 |
+
return tgt
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MemoryAttention(nn.Module):
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
d_model: int,
|
| 106 |
+
pos_enc_at_input: bool,
|
| 107 |
+
layer: nn.Module,
|
| 108 |
+
num_layers: int,
|
| 109 |
+
batch_first: bool = True, # Do layers expect batch first input?
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.d_model = d_model
|
| 113 |
+
self.layers = get_clones(layer, num_layers)
|
| 114 |
+
self.num_layers = num_layers
|
| 115 |
+
self.norm = nn.LayerNorm(d_model)
|
| 116 |
+
self.pos_enc_at_input = pos_enc_at_input
|
| 117 |
+
self.batch_first = batch_first
|
| 118 |
+
|
| 119 |
+
def forward(
|
| 120 |
+
self,
|
| 121 |
+
curr: torch.Tensor, # self-attention inputs
|
| 122 |
+
memory: torch.Tensor, # cross-attention inputs
|
| 123 |
+
curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
|
| 124 |
+
memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
|
| 125 |
+
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
|
| 126 |
+
):
|
| 127 |
+
if isinstance(curr, list):
|
| 128 |
+
assert isinstance(curr_pos, list)
|
| 129 |
+
assert len(curr) == len(curr_pos) == 1
|
| 130 |
+
curr, curr_pos = (
|
| 131 |
+
curr[0],
|
| 132 |
+
curr_pos[0],
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
assert (
|
| 136 |
+
curr.shape[1] == memory.shape[1]
|
| 137 |
+
), "Batch size must be the same for curr and memory"
|
| 138 |
+
|
| 139 |
+
output = curr
|
| 140 |
+
if self.pos_enc_at_input and curr_pos is not None:
|
| 141 |
+
output = output + 0.1 * curr_pos
|
| 142 |
+
|
| 143 |
+
if self.batch_first:
|
| 144 |
+
# Convert to batch first
|
| 145 |
+
output = output.transpose(0, 1)
|
| 146 |
+
curr_pos = curr_pos.transpose(0, 1)
|
| 147 |
+
memory = memory.transpose(0, 1)
|
| 148 |
+
memory_pos = memory_pos.transpose(0, 1)
|
| 149 |
+
|
| 150 |
+
for layer in self.layers:
|
| 151 |
+
kwds = {}
|
| 152 |
+
if isinstance(layer.cross_attn_image, RoPEAttention):
|
| 153 |
+
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
|
| 154 |
+
|
| 155 |
+
output = layer(
|
| 156 |
+
tgt=output,
|
| 157 |
+
memory=memory,
|
| 158 |
+
pos=memory_pos,
|
| 159 |
+
query_pos=curr_pos,
|
| 160 |
+
**kwds,
|
| 161 |
+
)
|
| 162 |
+
normed_output = self.norm(output)
|
| 163 |
+
|
| 164 |
+
if self.batch_first:
|
| 165 |
+
# Convert back to seq first
|
| 166 |
+
normed_output = normed_output.transpose(0, 1)
|
| 167 |
+
curr_pos = curr_pos.transpose(0, 1)
|
| 168 |
+
|
| 169 |
+
return normed_output
|
RynnEC/third_parts/sam2/modeling/memory_encoder.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from third_parts.sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MaskDownSampler(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Progressively downsample a mask by total_stride, each time by stride.
|
| 20 |
+
Note that LayerNorm is applied per *token*, like in ViT.
|
| 21 |
+
|
| 22 |
+
With each downsample (by a factor stride**2), channel capacity increases by the same factor.
|
| 23 |
+
In the end, we linearly project to embed_dim channels.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
embed_dim=256,
|
| 29 |
+
kernel_size=4,
|
| 30 |
+
stride=4,
|
| 31 |
+
padding=0,
|
| 32 |
+
total_stride=16,
|
| 33 |
+
activation=nn.GELU,
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
num_layers = int(math.log2(total_stride) // math.log2(stride))
|
| 37 |
+
assert stride**num_layers == total_stride
|
| 38 |
+
self.encoder = nn.Sequential()
|
| 39 |
+
mask_in_chans, mask_out_chans = 1, 1
|
| 40 |
+
for _ in range(num_layers):
|
| 41 |
+
mask_out_chans = mask_in_chans * (stride**2)
|
| 42 |
+
self.encoder.append(
|
| 43 |
+
nn.Conv2d(
|
| 44 |
+
mask_in_chans,
|
| 45 |
+
mask_out_chans,
|
| 46 |
+
kernel_size=kernel_size,
|
| 47 |
+
stride=stride,
|
| 48 |
+
padding=padding,
|
| 49 |
+
)
|
| 50 |
+
)
|
| 51 |
+
self.encoder.append(LayerNorm2d(mask_out_chans))
|
| 52 |
+
self.encoder.append(activation())
|
| 53 |
+
mask_in_chans = mask_out_chans
|
| 54 |
+
|
| 55 |
+
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
return self.encoder(x)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
|
| 62 |
+
class CXBlock(nn.Module):
|
| 63 |
+
r"""ConvNeXt Block. There are two equivalent implementations:
|
| 64 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 65 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 66 |
+
We use (2) as we find it slightly faster in PyTorch
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
dim (int): Number of input channels.
|
| 70 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 71 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
dim,
|
| 77 |
+
kernel_size=7,
|
| 78 |
+
padding=3,
|
| 79 |
+
drop_path=0.0,
|
| 80 |
+
layer_scale_init_value=1e-6,
|
| 81 |
+
use_dwconv=True,
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.dwconv = nn.Conv2d(
|
| 85 |
+
dim,
|
| 86 |
+
dim,
|
| 87 |
+
kernel_size=kernel_size,
|
| 88 |
+
padding=padding,
|
| 89 |
+
groups=dim if use_dwconv else 1,
|
| 90 |
+
) # depthwise conv
|
| 91 |
+
self.norm = LayerNorm2d(dim, eps=1e-6)
|
| 92 |
+
self.pwconv1 = nn.Linear(
|
| 93 |
+
dim, 4 * dim
|
| 94 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 95 |
+
self.act = nn.GELU()
|
| 96 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 97 |
+
self.gamma = (
|
| 98 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 99 |
+
if layer_scale_init_value > 0
|
| 100 |
+
else None
|
| 101 |
+
)
|
| 102 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
input = x
|
| 106 |
+
x = self.dwconv(x)
|
| 107 |
+
x = self.norm(x)
|
| 108 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 109 |
+
x = self.pwconv1(x)
|
| 110 |
+
x = self.act(x)
|
| 111 |
+
x = self.pwconv2(x)
|
| 112 |
+
if self.gamma is not None:
|
| 113 |
+
x = self.gamma * x
|
| 114 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 115 |
+
|
| 116 |
+
x = input + self.drop_path(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Fuser(nn.Module):
|
| 121 |
+
def __init__(self, layer, num_layers, dim=None, input_projection=False):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.proj = nn.Identity()
|
| 124 |
+
self.layers = get_clones(layer, num_layers)
|
| 125 |
+
|
| 126 |
+
if input_projection:
|
| 127 |
+
assert dim is not None
|
| 128 |
+
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
|
| 129 |
+
|
| 130 |
+
def forward(self, x):
|
| 131 |
+
# normally x: (N, C, H, W)
|
| 132 |
+
x = self.proj(x)
|
| 133 |
+
for layer in self.layers:
|
| 134 |
+
x = layer(x)
|
| 135 |
+
return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class MemoryEncoder(nn.Module):
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
out_dim,
|
| 142 |
+
mask_downsampler,
|
| 143 |
+
fuser,
|
| 144 |
+
position_encoding,
|
| 145 |
+
in_dim=256, # in_dim of pix_feats
|
| 146 |
+
):
|
| 147 |
+
super().__init__()
|
| 148 |
+
|
| 149 |
+
self.mask_downsampler = mask_downsampler
|
| 150 |
+
|
| 151 |
+
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
|
| 152 |
+
self.fuser = fuser
|
| 153 |
+
self.position_encoding = position_encoding
|
| 154 |
+
self.out_proj = nn.Identity()
|
| 155 |
+
if out_dim != in_dim:
|
| 156 |
+
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
| 157 |
+
|
| 158 |
+
def forward(
|
| 159 |
+
self,
|
| 160 |
+
pix_feat: torch.Tensor,
|
| 161 |
+
masks: torch.Tensor,
|
| 162 |
+
skip_mask_sigmoid: bool = False,
|
| 163 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 164 |
+
## Process masks
|
| 165 |
+
# sigmoid, so that less domain shift from gt masks which are bool
|
| 166 |
+
if not skip_mask_sigmoid:
|
| 167 |
+
masks = F.sigmoid(masks)
|
| 168 |
+
masks = self.mask_downsampler(masks)
|
| 169 |
+
|
| 170 |
+
## Fuse pix_feats and downsampled masks
|
| 171 |
+
# in case the visual features are on CPU, cast them to CUDA
|
| 172 |
+
pix_feat = pix_feat.to(masks.device)
|
| 173 |
+
|
| 174 |
+
x = self.pix_feat_proj(pix_feat)
|
| 175 |
+
x = x + masks
|
| 176 |
+
x = self.fuser(x)
|
| 177 |
+
x = self.out_proj(x)
|
| 178 |
+
|
| 179 |
+
pos = self.position_encoding(x).to(x.dtype)
|
| 180 |
+
|
| 181 |
+
return {"vision_features": x, "vision_pos_enc": [pos]}
|
RynnEC/third_parts/sam2/modeling/position_encoding.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Any, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PositionEmbeddingSine(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
This is a more standard version of the position embedding, very similar to the one
|
| 19 |
+
used by the Attention is all you need paper, generalized to work on images.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
num_pos_feats,
|
| 25 |
+
temperature: int = 10000,
|
| 26 |
+
normalize: bool = True,
|
| 27 |
+
scale: Optional[float] = None,
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
| 31 |
+
self.num_pos_feats = num_pos_feats // 2
|
| 32 |
+
self.temperature = temperature
|
| 33 |
+
self.normalize = normalize
|
| 34 |
+
if scale is not None and normalize is False:
|
| 35 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 36 |
+
if scale is None:
|
| 37 |
+
scale = 2 * math.pi
|
| 38 |
+
self.scale = scale
|
| 39 |
+
|
| 40 |
+
self.cache = {}
|
| 41 |
+
|
| 42 |
+
def _encode_xy(self, x, y):
|
| 43 |
+
# The positions are expected to be normalized
|
| 44 |
+
assert len(x) == len(y) and x.ndim == y.ndim == 1
|
| 45 |
+
x_embed = x * self.scale
|
| 46 |
+
y_embed = y * self.scale
|
| 47 |
+
|
| 48 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 49 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 50 |
+
|
| 51 |
+
pos_x = x_embed[:, None] / dim_t
|
| 52 |
+
pos_y = y_embed[:, None] / dim_t
|
| 53 |
+
pos_x = torch.stack(
|
| 54 |
+
(pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
|
| 55 |
+
).flatten(1)
|
| 56 |
+
pos_y = torch.stack(
|
| 57 |
+
(pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
|
| 58 |
+
).flatten(1)
|
| 59 |
+
return pos_x, pos_y
|
| 60 |
+
|
| 61 |
+
@torch.no_grad()
|
| 62 |
+
def encode_boxes(self, x, y, w, h):
|
| 63 |
+
pos_x, pos_y = self._encode_xy(x, y)
|
| 64 |
+
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
|
| 65 |
+
return pos
|
| 66 |
+
|
| 67 |
+
encode = encode_boxes # Backwards compatibility
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
|
| 70 |
+
def encode_points(self, x, y, labels):
|
| 71 |
+
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
|
| 72 |
+
assert bx == by and nx == ny and bx == bl and nx == nl
|
| 73 |
+
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
|
| 74 |
+
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
|
| 75 |
+
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
|
| 76 |
+
return pos
|
| 77 |
+
|
| 78 |
+
@torch.no_grad()
|
| 79 |
+
def forward(self, x: torch.Tensor):
|
| 80 |
+
cache_key = (x.shape[-2], x.shape[-1])
|
| 81 |
+
if cache_key in self.cache:
|
| 82 |
+
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
|
| 83 |
+
y_embed = (
|
| 84 |
+
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
|
| 85 |
+
.view(1, -1, 1)
|
| 86 |
+
.repeat(x.shape[0], 1, x.shape[-1])
|
| 87 |
+
)
|
| 88 |
+
x_embed = (
|
| 89 |
+
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
|
| 90 |
+
.view(1, 1, -1)
|
| 91 |
+
.repeat(x.shape[0], x.shape[-2], 1)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if self.normalize:
|
| 95 |
+
eps = 1e-6
|
| 96 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
| 97 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
| 98 |
+
|
| 99 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 100 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 101 |
+
|
| 102 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 103 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 104 |
+
pos_x = torch.stack(
|
| 105 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
| 106 |
+
).flatten(3)
|
| 107 |
+
pos_y = torch.stack(
|
| 108 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
| 109 |
+
).flatten(3)
|
| 110 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 111 |
+
self.cache[cache_key] = pos[0]
|
| 112 |
+
return pos
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class PositionEmbeddingRandom(nn.Module):
|
| 116 |
+
"""
|
| 117 |
+
Positional encoding using random spatial frequencies.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
| 121 |
+
super().__init__()
|
| 122 |
+
if scale is None or scale <= 0.0:
|
| 123 |
+
scale = 1.0
|
| 124 |
+
self.register_buffer(
|
| 125 |
+
"positional_encoding_gaussian_matrix",
|
| 126 |
+
scale * torch.randn((2, num_pos_feats)),
|
| 127 |
+
)
|
| 128 |
+
self.first = True
|
| 129 |
+
|
| 130 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
| 131 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 132 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 133 |
+
coords = 2 * coords - 1
|
| 134 |
+
coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
|
| 135 |
+
if self.first:
|
| 136 |
+
self.positional_encoding_gaussian_matrix = self.positional_encoding_gaussian_matrix.to(coords.device)
|
| 137 |
+
self.first = False
|
| 138 |
+
coords = coords @ self.positional_encoding_gaussian_matrix
|
| 139 |
+
coords = 2 * np.pi * coords
|
| 140 |
+
# outputs d_1 x ... x d_n x C shape
|
| 141 |
+
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
| 142 |
+
|
| 143 |
+
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
| 144 |
+
"""Generate positional encoding for a grid of the specified size."""
|
| 145 |
+
h, w = size
|
| 146 |
+
device: Any = self.positional_encoding_gaussian_matrix.device
|
| 147 |
+
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
| 148 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
| 149 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
| 150 |
+
y_embed = y_embed / h
|
| 151 |
+
x_embed = x_embed / w
|
| 152 |
+
|
| 153 |
+
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
| 154 |
+
return pe.permute(2, 0, 1) # C x H x W
|
| 155 |
+
|
| 156 |
+
def forward_with_coords(
|
| 157 |
+
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
| 158 |
+
) -> torch.Tensor:
|
| 159 |
+
"""Positionally encode points that are not normalized to [0,1]."""
|
| 160 |
+
coords = coords_input.clone()
|
| 161 |
+
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
| 162 |
+
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
| 163 |
+
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Rotary Positional Encoding, adapted from:
|
| 167 |
+
# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 168 |
+
# 2. https://github.com/naver-ai/rope-vit
|
| 169 |
+
# 3. https://github.com/lucidrains/rotary-embedding-torch
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def init_t_xy(end_x: int, end_y: int):
|
| 173 |
+
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
| 174 |
+
t_x = (t % end_x).float()
|
| 175 |
+
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
| 176 |
+
return t_x, t_y
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
| 180 |
+
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
| 181 |
+
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
| 182 |
+
|
| 183 |
+
t_x, t_y = init_t_xy(end_x, end_y)
|
| 184 |
+
freqs_x = torch.outer(t_x, freqs_x)
|
| 185 |
+
freqs_y = torch.outer(t_y, freqs_y)
|
| 186 |
+
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
|
| 187 |
+
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
|
| 188 |
+
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
| 192 |
+
ndim = x.ndim
|
| 193 |
+
assert 0 <= 1 < ndim
|
| 194 |
+
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
| 195 |
+
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
|
| 196 |
+
return freqs_cis.view(*shape)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def apply_rotary_enc(
|
| 200 |
+
xq: torch.Tensor,
|
| 201 |
+
xk: torch.Tensor,
|
| 202 |
+
freqs_cis: torch.Tensor,
|
| 203 |
+
repeat_freqs_k: bool = False,
|
| 204 |
+
):
|
| 205 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 206 |
+
xk_ = (
|
| 207 |
+
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 208 |
+
if xk.shape[-2] != 0
|
| 209 |
+
else None
|
| 210 |
+
)
|
| 211 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
| 212 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 213 |
+
if xk_ is None:
|
| 214 |
+
# no keys to rotate, due to dropout
|
| 215 |
+
return xq_out.type_as(xq).to(xq.device), xk
|
| 216 |
+
# repeat freqs along seq_len dim to match k seq_len
|
| 217 |
+
if repeat_freqs_k:
|
| 218 |
+
r = xk_.shape[-2] // xq_.shape[-2]
|
| 219 |
+
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
| 220 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
| 221 |
+
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
RynnEC/third_parts/sam2/modeling/sam/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
RynnEC/third_parts/sam2/modeling/sam/mask_decoder.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional, Tuple, Type
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from third_parts.sam2.modeling.sam2_utils import LayerNorm2d, MLP
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MaskDecoder(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
*,
|
| 19 |
+
transformer_dim: int,
|
| 20 |
+
transformer: nn.Module,
|
| 21 |
+
num_multimask_outputs: int = 3,
|
| 22 |
+
activation: Type[nn.Module] = nn.GELU,
|
| 23 |
+
iou_head_depth: int = 3,
|
| 24 |
+
iou_head_hidden_dim: int = 256,
|
| 25 |
+
use_high_res_features: bool = False,
|
| 26 |
+
iou_prediction_use_sigmoid=False,
|
| 27 |
+
dynamic_multimask_via_stability=False,
|
| 28 |
+
dynamic_multimask_stability_delta=0.05,
|
| 29 |
+
dynamic_multimask_stability_thresh=0.98,
|
| 30 |
+
pred_obj_scores: bool = False,
|
| 31 |
+
pred_obj_scores_mlp: bool = False,
|
| 32 |
+
use_multimask_token_for_obj_ptr: bool = False,
|
| 33 |
+
) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Predicts masks given an image and prompt embeddings, using a
|
| 36 |
+
transformer architecture.
|
| 37 |
+
|
| 38 |
+
Arguments:
|
| 39 |
+
transformer_dim (int): the channel dimension of the transformer
|
| 40 |
+
transformer (nn.Module): the transformer used to predict masks
|
| 41 |
+
num_multimask_outputs (int): the number of masks to predict
|
| 42 |
+
when disambiguating masks
|
| 43 |
+
activation (nn.Module): the type of activation to use when
|
| 44 |
+
upscaling masks
|
| 45 |
+
iou_head_depth (int): the depth of the MLP used to predict
|
| 46 |
+
mask quality
|
| 47 |
+
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
| 48 |
+
used to predict mask quality
|
| 49 |
+
"""
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.transformer_dim = transformer_dim
|
| 52 |
+
self.transformer = transformer
|
| 53 |
+
|
| 54 |
+
self.num_multimask_outputs = num_multimask_outputs
|
| 55 |
+
|
| 56 |
+
self.iou_token = nn.Embedding(1, transformer_dim)
|
| 57 |
+
self.num_mask_tokens = num_multimask_outputs + 1
|
| 58 |
+
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
| 59 |
+
|
| 60 |
+
self.pred_obj_scores = pred_obj_scores
|
| 61 |
+
if self.pred_obj_scores:
|
| 62 |
+
self.obj_score_token = nn.Embedding(1, transformer_dim)
|
| 63 |
+
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
| 64 |
+
|
| 65 |
+
self.output_upscaling = nn.Sequential(
|
| 66 |
+
nn.ConvTranspose2d(
|
| 67 |
+
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
| 68 |
+
),
|
| 69 |
+
LayerNorm2d(transformer_dim // 4),
|
| 70 |
+
activation(),
|
| 71 |
+
nn.ConvTranspose2d(
|
| 72 |
+
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
| 73 |
+
),
|
| 74 |
+
activation(),
|
| 75 |
+
)
|
| 76 |
+
self.use_high_res_features = use_high_res_features
|
| 77 |
+
if use_high_res_features:
|
| 78 |
+
self.conv_s0 = nn.Conv2d(
|
| 79 |
+
transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
|
| 80 |
+
)
|
| 81 |
+
self.conv_s1 = nn.Conv2d(
|
| 82 |
+
transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
self.output_hypernetworks_mlps = nn.ModuleList(
|
| 86 |
+
[
|
| 87 |
+
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
| 88 |
+
for i in range(self.num_mask_tokens)
|
| 89 |
+
]
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.iou_prediction_head = MLP(
|
| 93 |
+
transformer_dim,
|
| 94 |
+
iou_head_hidden_dim,
|
| 95 |
+
self.num_mask_tokens,
|
| 96 |
+
iou_head_depth,
|
| 97 |
+
sigmoid_output=iou_prediction_use_sigmoid,
|
| 98 |
+
)
|
| 99 |
+
if self.pred_obj_scores:
|
| 100 |
+
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
|
| 101 |
+
if pred_obj_scores_mlp:
|
| 102 |
+
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
|
| 103 |
+
|
| 104 |
+
# When outputting a single mask, optionally we can dynamically fall back to the best
|
| 105 |
+
# multimask output token if the single mask output token gives low stability scores.
|
| 106 |
+
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
|
| 107 |
+
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
|
| 108 |
+
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
|
| 109 |
+
|
| 110 |
+
def forward(
|
| 111 |
+
self,
|
| 112 |
+
image_embeddings: torch.Tensor,
|
| 113 |
+
image_pe: torch.Tensor,
|
| 114 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 115 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 116 |
+
multimask_output: bool,
|
| 117 |
+
repeat_image: bool,
|
| 118 |
+
high_res_features: Optional[List[torch.Tensor]] = None,
|
| 119 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 120 |
+
"""
|
| 121 |
+
Predict masks given image and prompt embeddings.
|
| 122 |
+
|
| 123 |
+
Arguments:
|
| 124 |
+
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
| 125 |
+
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
| 126 |
+
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
| 127 |
+
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
| 128 |
+
multimask_output (bool): Whether to return multiple masks or a single
|
| 129 |
+
mask.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
torch.Tensor: batched predicted masks
|
| 133 |
+
torch.Tensor: batched predictions of mask quality
|
| 134 |
+
torch.Tensor: batched SAM token for mask output
|
| 135 |
+
"""
|
| 136 |
+
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
| 137 |
+
image_embeddings=image_embeddings,
|
| 138 |
+
image_pe=image_pe,
|
| 139 |
+
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
| 140 |
+
dense_prompt_embeddings=dense_prompt_embeddings,
|
| 141 |
+
repeat_image=repeat_image,
|
| 142 |
+
high_res_features=high_res_features,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Select the correct mask or masks for output
|
| 146 |
+
if multimask_output:
|
| 147 |
+
masks = masks[:, 1:, :, :]
|
| 148 |
+
iou_pred = iou_pred[:, 1:]
|
| 149 |
+
elif self.dynamic_multimask_via_stability and not self.training:
|
| 150 |
+
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
|
| 151 |
+
else:
|
| 152 |
+
masks = masks[:, 0:1, :, :]
|
| 153 |
+
iou_pred = iou_pred[:, 0:1]
|
| 154 |
+
|
| 155 |
+
if multimask_output and self.use_multimask_token_for_obj_ptr:
|
| 156 |
+
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
|
| 157 |
+
else:
|
| 158 |
+
# Take the mask output token. Here we *always* use the token for single mask output.
|
| 159 |
+
# At test time, even if we track after 1-click (and using multimask_output=True),
|
| 160 |
+
# we still take the single mask token here. The rationale is that we always track
|
| 161 |
+
# after multiple clicks during training, so the past tokens seen during training
|
| 162 |
+
# are always the single mask token (and we'll let it be the object-memory token).
|
| 163 |
+
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
|
| 164 |
+
|
| 165 |
+
# Prepare output
|
| 166 |
+
return masks, iou_pred, sam_tokens_out, object_score_logits
|
| 167 |
+
|
| 168 |
+
def predict_masks(
|
| 169 |
+
self,
|
| 170 |
+
image_embeddings: torch.Tensor,
|
| 171 |
+
image_pe: torch.Tensor,
|
| 172 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 173 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 174 |
+
repeat_image: bool,
|
| 175 |
+
high_res_features: Optional[List[torch.Tensor]] = None,
|
| 176 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 177 |
+
"""Predicts masks. See 'forward' for more details."""
|
| 178 |
+
# Concatenate output tokens
|
| 179 |
+
s = 0
|
| 180 |
+
if self.pred_obj_scores:
|
| 181 |
+
output_tokens = torch.cat(
|
| 182 |
+
[
|
| 183 |
+
self.obj_score_token.weight,
|
| 184 |
+
self.iou_token.weight,
|
| 185 |
+
self.mask_tokens.weight,
|
| 186 |
+
],
|
| 187 |
+
dim=0,
|
| 188 |
+
)
|
| 189 |
+
s = 1
|
| 190 |
+
else:
|
| 191 |
+
output_tokens = torch.cat(
|
| 192 |
+
[self.iou_token.weight, self.mask_tokens.weight], dim=0
|
| 193 |
+
)
|
| 194 |
+
output_tokens = output_tokens.unsqueeze(0).expand(
|
| 195 |
+
sparse_prompt_embeddings.size(0), -1, -1
|
| 196 |
+
)
|
| 197 |
+
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
| 198 |
+
|
| 199 |
+
# Expand per-image data in batch direction to be per-mask
|
| 200 |
+
if repeat_image:
|
| 201 |
+
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
| 202 |
+
else:
|
| 203 |
+
assert image_embeddings.shape[0] == tokens.shape[0]
|
| 204 |
+
src = image_embeddings
|
| 205 |
+
src = src + dense_prompt_embeddings
|
| 206 |
+
assert (
|
| 207 |
+
image_pe.size(0) == 1
|
| 208 |
+
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
| 209 |
+
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
| 210 |
+
b, c, h, w = src.shape
|
| 211 |
+
|
| 212 |
+
# Run the transformer
|
| 213 |
+
# print('src: ', src.dtype, 'pos_src:', pos_src.dtype, 'tokens:', tokens.dtype)
|
| 214 |
+
_dtype = pos_src.dtype
|
| 215 |
+
src = src.to(_dtype)
|
| 216 |
+
tokens = tokens.to(_dtype)
|
| 217 |
+
hs, src = self.transformer(src, pos_src, tokens)
|
| 218 |
+
iou_token_out = hs[:, s, :]
|
| 219 |
+
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
|
| 220 |
+
|
| 221 |
+
# Upscale mask embeddings and predict masks using the mask tokens
|
| 222 |
+
src = src.transpose(1, 2).view(b, c, h, w)
|
| 223 |
+
if not self.use_high_res_features:
|
| 224 |
+
upscaled_embedding = self.output_upscaling(src)
|
| 225 |
+
else:
|
| 226 |
+
dc1, ln1, act1, dc2, act2 = self.output_upscaling
|
| 227 |
+
feat_s0, feat_s1 = high_res_features
|
| 228 |
+
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
|
| 229 |
+
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
|
| 230 |
+
|
| 231 |
+
hyper_in_list: List[torch.Tensor] = []
|
| 232 |
+
for i in range(self.num_mask_tokens):
|
| 233 |
+
hyper_in_list.append(
|
| 234 |
+
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
| 235 |
+
)
|
| 236 |
+
hyper_in = torch.stack(hyper_in_list, dim=1)
|
| 237 |
+
b, c, h, w = upscaled_embedding.shape
|
| 238 |
+
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
| 239 |
+
|
| 240 |
+
# Generate mask quality predictions
|
| 241 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
| 242 |
+
if self.pred_obj_scores:
|
| 243 |
+
assert s == 1
|
| 244 |
+
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
|
| 245 |
+
else:
|
| 246 |
+
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
|
| 247 |
+
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
|
| 248 |
+
|
| 249 |
+
return masks, iou_pred, mask_tokens_out, object_score_logits
|
| 250 |
+
|
| 251 |
+
def _get_stability_scores(self, mask_logits):
|
| 252 |
+
"""
|
| 253 |
+
Compute stability scores of the mask logits based on the IoU between upper and
|
| 254 |
+
lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
|
| 255 |
+
"""
|
| 256 |
+
mask_logits = mask_logits.flatten(-2)
|
| 257 |
+
stability_delta = self.dynamic_multimask_stability_delta
|
| 258 |
+
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
|
| 259 |
+
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
|
| 260 |
+
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
|
| 261 |
+
return stability_scores
|
| 262 |
+
|
| 263 |
+
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
|
| 264 |
+
"""
|
| 265 |
+
When outputting a single mask, if the stability score from the current single-mask
|
| 266 |
+
output (based on output token 0) falls below a threshold, we instead select from
|
| 267 |
+
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
|
| 268 |
+
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
|
| 269 |
+
"""
|
| 270 |
+
# The best mask from multimask output tokens (1~3)
|
| 271 |
+
multimask_logits = all_mask_logits[:, 1:, :, :]
|
| 272 |
+
multimask_iou_scores = all_iou_scores[:, 1:]
|
| 273 |
+
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
| 274 |
+
batch_inds = torch.arange(
|
| 275 |
+
multimask_iou_scores.size(0), device=all_iou_scores.device
|
| 276 |
+
)
|
| 277 |
+
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
| 278 |
+
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
| 279 |
+
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
| 280 |
+
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
|
| 281 |
+
|
| 282 |
+
# The mask from singlemask output token 0 and its stability score
|
| 283 |
+
singlemask_logits = all_mask_logits[:, 0:1, :, :]
|
| 284 |
+
singlemask_iou_scores = all_iou_scores[:, 0:1]
|
| 285 |
+
stability_scores = self._get_stability_scores(singlemask_logits)
|
| 286 |
+
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
|
| 287 |
+
|
| 288 |
+
# Dynamically fall back to best multimask output upon low stability scores.
|
| 289 |
+
mask_logits_out = torch.where(
|
| 290 |
+
is_stable[..., None, None].expand_as(singlemask_logits),
|
| 291 |
+
singlemask_logits,
|
| 292 |
+
best_multimask_logits,
|
| 293 |
+
)
|
| 294 |
+
iou_scores_out = torch.where(
|
| 295 |
+
is_stable.expand_as(singlemask_iou_scores),
|
| 296 |
+
singlemask_iou_scores,
|
| 297 |
+
best_multimask_iou_scores,
|
| 298 |
+
)
|
| 299 |
+
return mask_logits_out, iou_scores_out
|
RynnEC/third_parts/sam2/modeling/sam/prompt_encoder.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Tuple, Type
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from third_parts.sam2.modeling.position_encoding import PositionEmbeddingRandom
|
| 13 |
+
|
| 14 |
+
from third_parts.sam2.modeling.sam2_utils import LayerNorm2d
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PromptEncoder(nn.Module):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
embed_dim: int,
|
| 21 |
+
image_embedding_size: Tuple[int, int],
|
| 22 |
+
input_image_size: Tuple[int, int],
|
| 23 |
+
mask_in_chans: int,
|
| 24 |
+
activation: Type[nn.Module] = nn.GELU,
|
| 25 |
+
) -> None:
|
| 26 |
+
"""
|
| 27 |
+
Encodes prompts for input to SAM's mask decoder.
|
| 28 |
+
|
| 29 |
+
Arguments:
|
| 30 |
+
embed_dim (int): The prompts' embedding dimension
|
| 31 |
+
image_embedding_size (tuple(int, int)): The spatial size of the
|
| 32 |
+
image embedding, as (H, W).
|
| 33 |
+
input_image_size (int): The padded size of the image as input
|
| 34 |
+
to the image encoder, as (H, W).
|
| 35 |
+
mask_in_chans (int): The number of hidden channels used for
|
| 36 |
+
encoding input masks.
|
| 37 |
+
activation (nn.Module): The activation to use when encoding
|
| 38 |
+
input masks.
|
| 39 |
+
"""
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.embed_dim = embed_dim
|
| 42 |
+
self.input_image_size = input_image_size
|
| 43 |
+
self.image_embedding_size = image_embedding_size
|
| 44 |
+
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 45 |
+
|
| 46 |
+
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
| 47 |
+
point_embeddings = [
|
| 48 |
+
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
|
| 49 |
+
]
|
| 50 |
+
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 51 |
+
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 52 |
+
|
| 53 |
+
self.mask_input_size = (
|
| 54 |
+
4 * image_embedding_size[0],
|
| 55 |
+
4 * image_embedding_size[1],
|
| 56 |
+
)
|
| 57 |
+
self.mask_downscaling = nn.Sequential(
|
| 58 |
+
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
| 59 |
+
LayerNorm2d(mask_in_chans // 4),
|
| 60 |
+
activation(),
|
| 61 |
+
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
| 62 |
+
LayerNorm2d(mask_in_chans),
|
| 63 |
+
activation(),
|
| 64 |
+
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
| 65 |
+
)
|
| 66 |
+
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 67 |
+
|
| 68 |
+
def get_dense_pe(self) -> torch.Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Returns the positional encoding used to encode point prompts,
|
| 71 |
+
applied to a dense set of points the shape of the image encoding.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
torch.Tensor: Positional encoding with shape
|
| 75 |
+
1x(embed_dim)x(embedding_h)x(embedding_w)
|
| 76 |
+
"""
|
| 77 |
+
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
| 78 |
+
|
| 79 |
+
def _embed_points(
|
| 80 |
+
self,
|
| 81 |
+
points: torch.Tensor,
|
| 82 |
+
labels: torch.Tensor,
|
| 83 |
+
pad: bool,
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
"""Embeds point prompts."""
|
| 86 |
+
points = points + 0.5 # Shift to center of pixel
|
| 87 |
+
if pad:
|
| 88 |
+
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
| 89 |
+
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 90 |
+
points = torch.cat([points, padding_point], dim=1)
|
| 91 |
+
labels = torch.cat([labels, padding_label], dim=1)
|
| 92 |
+
point_embedding = self.pe_layer.forward_with_coords(
|
| 93 |
+
points, self.input_image_size
|
| 94 |
+
)
|
| 95 |
+
point_embedding[labels == -1] = 0.0
|
| 96 |
+
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
| 97 |
+
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
| 98 |
+
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
| 99 |
+
point_embedding[labels == 2] += self.point_embeddings[2].weight
|
| 100 |
+
point_embedding[labels == 3] += self.point_embeddings[3].weight
|
| 101 |
+
return point_embedding
|
| 102 |
+
|
| 103 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
| 104 |
+
"""Embeds box prompts."""
|
| 105 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
| 106 |
+
coords = boxes.reshape(-1, 2, 2)
|
| 107 |
+
corner_embedding = self.pe_layer.forward_with_coords(
|
| 108 |
+
coords, self.input_image_size
|
| 109 |
+
)
|
| 110 |
+
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 111 |
+
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 112 |
+
return corner_embedding
|
| 113 |
+
|
| 114 |
+
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
"""Embeds mask inputs."""
|
| 116 |
+
mask_embedding = self.mask_downscaling(masks)
|
| 117 |
+
return mask_embedding
|
| 118 |
+
|
| 119 |
+
def _get_batch_size(
|
| 120 |
+
self,
|
| 121 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 122 |
+
boxes: Optional[torch.Tensor],
|
| 123 |
+
masks: Optional[torch.Tensor],
|
| 124 |
+
) -> int:
|
| 125 |
+
"""
|
| 126 |
+
Gets the batch size of the output given the batch size of the input prompts.
|
| 127 |
+
"""
|
| 128 |
+
if points is not None:
|
| 129 |
+
return points[0].shape[0]
|
| 130 |
+
elif boxes is not None:
|
| 131 |
+
return boxes.shape[0]
|
| 132 |
+
elif masks is not None:
|
| 133 |
+
return masks.shape[0]
|
| 134 |
+
else:
|
| 135 |
+
return 1
|
| 136 |
+
|
| 137 |
+
def _get_device(self) -> torch.device:
|
| 138 |
+
return self.point_embeddings[0].weight.device
|
| 139 |
+
|
| 140 |
+
def forward(
|
| 141 |
+
self,
|
| 142 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 143 |
+
boxes: Optional[torch.Tensor],
|
| 144 |
+
masks: Optional[torch.Tensor],
|
| 145 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 146 |
+
"""
|
| 147 |
+
Embeds different types of prompts, returning both sparse and dense
|
| 148 |
+
embeddings.
|
| 149 |
+
|
| 150 |
+
Arguments:
|
| 151 |
+
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
| 152 |
+
and labels to embed.
|
| 153 |
+
boxes (torch.Tensor or none): boxes to embed
|
| 154 |
+
masks (torch.Tensor or none): masks to embed
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
| 158 |
+
BxNx(embed_dim), where N is determined by the number of input points
|
| 159 |
+
and boxes.
|
| 160 |
+
torch.Tensor: dense embeddings for the masks, in the shape
|
| 161 |
+
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 162 |
+
"""
|
| 163 |
+
bs = self._get_batch_size(points, boxes, masks)
|
| 164 |
+
sparse_embeddings = torch.empty(
|
| 165 |
+
(bs, 0, self.embed_dim), device=self._get_device()
|
| 166 |
+
)
|
| 167 |
+
if points is not None:
|
| 168 |
+
coords, labels = points
|
| 169 |
+
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
| 170 |
+
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
| 171 |
+
if boxes is not None:
|
| 172 |
+
box_embeddings = self._embed_boxes(boxes)
|
| 173 |
+
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
| 174 |
+
|
| 175 |
+
if masks is not None:
|
| 176 |
+
dense_embeddings = self._embed_masks(masks)
|
| 177 |
+
else:
|
| 178 |
+
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
| 179 |
+
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
return sparse_embeddings, dense_embeddings
|
RynnEC/third_parts/sam2/modeling/sam/transformer.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import warnings
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import Tuple, Type
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch import nn, Tensor
|
| 15 |
+
|
| 16 |
+
from third_parts.sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 17 |
+
|
| 18 |
+
from third_parts.sam2.modeling.sam2_utils import MLP
|
| 19 |
+
from third_parts.sam2.utils.misc import get_sdpa_settings
|
| 20 |
+
|
| 21 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 22 |
+
# OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
| 23 |
+
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TwoWayTransformer(nn.Module):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
depth: int,
|
| 30 |
+
embedding_dim: int,
|
| 31 |
+
num_heads: int,
|
| 32 |
+
mlp_dim: int,
|
| 33 |
+
activation: Type[nn.Module] = nn.ReLU,
|
| 34 |
+
attention_downsample_rate: int = 2,
|
| 35 |
+
) -> None:
|
| 36 |
+
"""
|
| 37 |
+
A transformer decoder that attends to an input image using
|
| 38 |
+
queries whose positional embedding is supplied.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
depth (int): number of layers in the transformer
|
| 42 |
+
embedding_dim (int): the channel dimension for the input embeddings
|
| 43 |
+
num_heads (int): the number of heads for multihead attention. Must
|
| 44 |
+
divide embedding_dim
|
| 45 |
+
mlp_dim (int): the channel dimension internal to the MLP block
|
| 46 |
+
activation (nn.Module): the activation to use in the MLP block
|
| 47 |
+
"""
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.depth = depth
|
| 50 |
+
self.embedding_dim = embedding_dim
|
| 51 |
+
self.num_heads = num_heads
|
| 52 |
+
self.mlp_dim = mlp_dim
|
| 53 |
+
self.layers = nn.ModuleList()
|
| 54 |
+
|
| 55 |
+
for i in range(depth):
|
| 56 |
+
self.layers.append(
|
| 57 |
+
TwoWayAttentionBlock(
|
| 58 |
+
embedding_dim=embedding_dim,
|
| 59 |
+
num_heads=num_heads,
|
| 60 |
+
mlp_dim=mlp_dim,
|
| 61 |
+
activation=activation,
|
| 62 |
+
attention_downsample_rate=attention_downsample_rate,
|
| 63 |
+
skip_first_layer_pe=(i == 0),
|
| 64 |
+
)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
self.final_attn_token_to_image = Attention(
|
| 68 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 69 |
+
)
|
| 70 |
+
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self,
|
| 74 |
+
image_embedding: Tensor,
|
| 75 |
+
image_pe: Tensor,
|
| 76 |
+
point_embedding: Tensor,
|
| 77 |
+
) -> Tuple[Tensor, Tensor]:
|
| 78 |
+
"""
|
| 79 |
+
Args:
|
| 80 |
+
image_embedding (torch.Tensor): image to attend to. Should be shape
|
| 81 |
+
B x embedding_dim x h x w for any h and w.
|
| 82 |
+
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
| 83 |
+
have the same shape as image_embedding.
|
| 84 |
+
point_embedding (torch.Tensor): the embedding to add to the query points.
|
| 85 |
+
Must have shape B x N_points x embedding_dim for any N_points.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
torch.Tensor: the processed point_embedding
|
| 89 |
+
torch.Tensor: the processed image_embedding
|
| 90 |
+
"""
|
| 91 |
+
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
| 92 |
+
bs, c, h, w = image_embedding.shape
|
| 93 |
+
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
| 94 |
+
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
| 95 |
+
|
| 96 |
+
# Prepare queries
|
| 97 |
+
queries = point_embedding
|
| 98 |
+
keys = image_embedding
|
| 99 |
+
|
| 100 |
+
# Apply transformer blocks and final layernorm
|
| 101 |
+
for layer in self.layers:
|
| 102 |
+
queries, keys = layer(
|
| 103 |
+
queries=queries,
|
| 104 |
+
keys=keys,
|
| 105 |
+
query_pe=point_embedding,
|
| 106 |
+
key_pe=image_pe,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Apply the final attention layer from the points to the image
|
| 110 |
+
q = queries + point_embedding
|
| 111 |
+
k = keys + image_pe
|
| 112 |
+
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
| 113 |
+
queries = queries + attn_out
|
| 114 |
+
queries = self.norm_final_attn(queries)
|
| 115 |
+
|
| 116 |
+
return queries, keys
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TwoWayAttentionBlock(nn.Module):
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
embedding_dim: int,
|
| 123 |
+
num_heads: int,
|
| 124 |
+
mlp_dim: int = 2048,
|
| 125 |
+
activation: Type[nn.Module] = nn.ReLU,
|
| 126 |
+
attention_downsample_rate: int = 2,
|
| 127 |
+
skip_first_layer_pe: bool = False,
|
| 128 |
+
) -> None:
|
| 129 |
+
"""
|
| 130 |
+
A transformer block with four layers: (1) self-attention of sparse
|
| 131 |
+
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
| 132 |
+
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
| 133 |
+
inputs.
|
| 134 |
+
|
| 135 |
+
Arguments:
|
| 136 |
+
embedding_dim (int): the channel dimension of the embeddings
|
| 137 |
+
num_heads (int): the number of heads in the attention layers
|
| 138 |
+
mlp_dim (int): the hidden dimension of the mlp block
|
| 139 |
+
activation (nn.Module): the activation of the mlp block
|
| 140 |
+
skip_first_layer_pe (bool): skip the PE on the first layer
|
| 141 |
+
"""
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.self_attn = Attention(embedding_dim, num_heads)
|
| 144 |
+
self.norm1 = nn.LayerNorm(embedding_dim)
|
| 145 |
+
|
| 146 |
+
self.cross_attn_token_to_image = Attention(
|
| 147 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 148 |
+
)
|
| 149 |
+
self.norm2 = nn.LayerNorm(embedding_dim)
|
| 150 |
+
|
| 151 |
+
self.mlp = MLP(
|
| 152 |
+
embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
|
| 153 |
+
)
|
| 154 |
+
self.norm3 = nn.LayerNorm(embedding_dim)
|
| 155 |
+
|
| 156 |
+
self.norm4 = nn.LayerNorm(embedding_dim)
|
| 157 |
+
self.cross_attn_image_to_token = Attention(
|
| 158 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
self.skip_first_layer_pe = skip_first_layer_pe
|
| 162 |
+
|
| 163 |
+
def forward(
|
| 164 |
+
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
|
| 165 |
+
) -> Tuple[Tensor, Tensor]:
|
| 166 |
+
# Self attention block
|
| 167 |
+
if self.skip_first_layer_pe:
|
| 168 |
+
queries = self.self_attn(q=queries, k=queries, v=queries)
|
| 169 |
+
else:
|
| 170 |
+
q = queries + query_pe
|
| 171 |
+
attn_out = self.self_attn(q=q, k=q, v=queries)
|
| 172 |
+
queries = queries + attn_out
|
| 173 |
+
queries = self.norm1(queries)
|
| 174 |
+
|
| 175 |
+
# Cross attention block, tokens attending to image embedding
|
| 176 |
+
q = queries + query_pe
|
| 177 |
+
k = keys + key_pe
|
| 178 |
+
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
| 179 |
+
queries = queries + attn_out
|
| 180 |
+
queries = self.norm2(queries)
|
| 181 |
+
|
| 182 |
+
# MLP block
|
| 183 |
+
mlp_out = self.mlp(queries)
|
| 184 |
+
queries = queries + mlp_out
|
| 185 |
+
queries = self.norm3(queries)
|
| 186 |
+
|
| 187 |
+
# Cross attention block, image embedding attending to tokens
|
| 188 |
+
q = queries + query_pe
|
| 189 |
+
k = keys + key_pe
|
| 190 |
+
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
| 191 |
+
keys = keys + attn_out
|
| 192 |
+
keys = self.norm4(keys)
|
| 193 |
+
|
| 194 |
+
return queries, keys
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class Attention(nn.Module):
|
| 198 |
+
"""
|
| 199 |
+
An attention layer that allows for downscaling the size of the embedding
|
| 200 |
+
after projection to queries, keys, and values.
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
def __init__(
|
| 204 |
+
self,
|
| 205 |
+
embedding_dim: int,
|
| 206 |
+
num_heads: int,
|
| 207 |
+
downsample_rate: int = 1,
|
| 208 |
+
dropout: float = 0.0,
|
| 209 |
+
kv_in_dim: int = None,
|
| 210 |
+
) -> None:
|
| 211 |
+
super().__init__()
|
| 212 |
+
self.embedding_dim = embedding_dim
|
| 213 |
+
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
|
| 214 |
+
self.internal_dim = embedding_dim // downsample_rate
|
| 215 |
+
self.num_heads = num_heads
|
| 216 |
+
assert (
|
| 217 |
+
self.internal_dim % num_heads == 0
|
| 218 |
+
), "num_heads must divide embedding_dim."
|
| 219 |
+
|
| 220 |
+
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
| 221 |
+
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
| 222 |
+
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
| 223 |
+
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
| 224 |
+
|
| 225 |
+
self.dropout_p = dropout
|
| 226 |
+
|
| 227 |
+
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
| 228 |
+
b, n, c = x.shape
|
| 229 |
+
x = x.reshape(b, n, num_heads, c // num_heads)
|
| 230 |
+
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
| 231 |
+
|
| 232 |
+
def _recombine_heads(self, x: Tensor) -> Tensor:
|
| 233 |
+
b, n_heads, n_tokens, c_per_head = x.shape
|
| 234 |
+
x = x.transpose(1, 2)
|
| 235 |
+
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
| 236 |
+
|
| 237 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
| 238 |
+
# Input projections
|
| 239 |
+
q = self.q_proj(q)
|
| 240 |
+
k = self.k_proj(k)
|
| 241 |
+
v = self.v_proj(v)
|
| 242 |
+
|
| 243 |
+
# Separate into heads
|
| 244 |
+
q = self._separate_heads(q, self.num_heads)
|
| 245 |
+
k = self._separate_heads(k, self.num_heads)
|
| 246 |
+
v = self._separate_heads(v, self.num_heads)
|
| 247 |
+
|
| 248 |
+
dropout_p = self.dropout_p if self.training else 0.0
|
| 249 |
+
# Attention
|
| 250 |
+
with torch.backends.cuda.sdp_kernel(
|
| 251 |
+
enable_flash=USE_FLASH_ATTN,
|
| 252 |
+
# if Flash attention kernel is off, then math kernel needs to be enabled
|
| 253 |
+
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
| 254 |
+
enable_mem_efficient=OLD_GPU,
|
| 255 |
+
):
|
| 256 |
+
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 257 |
+
|
| 258 |
+
out = self._recombine_heads(out)
|
| 259 |
+
out = self.out_proj(out)
|
| 260 |
+
|
| 261 |
+
return out
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class RoPEAttention(Attention):
|
| 265 |
+
"""Attention with rotary position encoding."""
|
| 266 |
+
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
*args,
|
| 270 |
+
rope_theta=10000.0,
|
| 271 |
+
# whether to repeat q rope to match k length
|
| 272 |
+
# this is needed for cross-attention to memories
|
| 273 |
+
rope_k_repeat=False,
|
| 274 |
+
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
|
| 275 |
+
**kwargs,
|
| 276 |
+
):
|
| 277 |
+
super().__init__(*args, **kwargs)
|
| 278 |
+
|
| 279 |
+
self.compute_cis = partial(
|
| 280 |
+
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
|
| 281 |
+
)
|
| 282 |
+
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
| 283 |
+
self.freqs_cis = freqs_cis
|
| 284 |
+
self.rope_k_repeat = rope_k_repeat
|
| 285 |
+
|
| 286 |
+
def forward(
|
| 287 |
+
self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
|
| 288 |
+
) -> Tensor:
|
| 289 |
+
# Input projections
|
| 290 |
+
q = self.q_proj(q)
|
| 291 |
+
k = self.k_proj(k)
|
| 292 |
+
v = self.v_proj(v)
|
| 293 |
+
|
| 294 |
+
# Separate into heads
|
| 295 |
+
q = self._separate_heads(q, self.num_heads)
|
| 296 |
+
k = self._separate_heads(k, self.num_heads)
|
| 297 |
+
v = self._separate_heads(v, self.num_heads)
|
| 298 |
+
|
| 299 |
+
# Apply rotary position encoding
|
| 300 |
+
w = h = math.sqrt(q.shape[-2])
|
| 301 |
+
self.freqs_cis = self.freqs_cis.to(q.device)
|
| 302 |
+
if self.freqs_cis.shape[0] != q.shape[-2]:
|
| 303 |
+
self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
|
| 304 |
+
if q.shape[-2] != k.shape[-2]:
|
| 305 |
+
assert self.rope_k_repeat
|
| 306 |
+
|
| 307 |
+
num_k_rope = k.size(-2) - num_k_exclude_rope
|
| 308 |
+
q, k[:, :, :num_k_rope] = apply_rotary_enc(
|
| 309 |
+
q,
|
| 310 |
+
k[:, :, :num_k_rope],
|
| 311 |
+
freqs_cis=self.freqs_cis,
|
| 312 |
+
repeat_freqs_k=self.rope_k_repeat,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
dropout_p = self.dropout_p if self.training else 0.0
|
| 316 |
+
# Attention
|
| 317 |
+
with torch.backends.cuda.sdp_kernel(
|
| 318 |
+
enable_flash=USE_FLASH_ATTN,
|
| 319 |
+
# if Flash attention kernel is off, then math kernel needs to be enabled
|
| 320 |
+
enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
|
| 321 |
+
enable_mem_efficient=OLD_GPU,
|
| 322 |
+
):
|
| 323 |
+
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
|
| 324 |
+
|
| 325 |
+
out = self._recombine_heads(out)
|
| 326 |
+
out = self.out_proj(out)
|
| 327 |
+
|
| 328 |
+
return out
|
RynnEC/third_parts/sam2/modeling/sam2_base.py
ADDED
|
@@ -0,0 +1,830 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from torch.nn.init import trunc_normal_
|
| 12 |
+
|
| 13 |
+
from third_parts.sam2.modeling.sam.mask_decoder import MaskDecoder
|
| 14 |
+
from third_parts.sam2.modeling.sam.prompt_encoder import PromptEncoder
|
| 15 |
+
from third_parts.sam2.modeling.sam.transformer import TwoWayTransformer
|
| 16 |
+
from third_parts.sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
|
| 17 |
+
|
| 18 |
+
# a large negative value as a placeholder score for missing objects
|
| 19 |
+
NO_OBJ_SCORE = -1024.0
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SAM2Base(torch.nn.Module):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
image_encoder,
|
| 26 |
+
memory_attention,
|
| 27 |
+
memory_encoder,
|
| 28 |
+
num_maskmem=7, # default 1 input frame + 6 previous frames
|
| 29 |
+
image_size=512,
|
| 30 |
+
backbone_stride=16, # stride of the image backbone output
|
| 31 |
+
sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
|
| 32 |
+
sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
|
| 33 |
+
# During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
|
| 34 |
+
binarize_mask_from_pts_for_mem_enc=False,
|
| 35 |
+
use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
|
| 36 |
+
# The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
|
| 37 |
+
# we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
|
| 38 |
+
# a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
|
| 39 |
+
max_cond_frames_in_attn=-1,
|
| 40 |
+
# on the first frame, whether to directly add the no-memory embedding to the image feature
|
| 41 |
+
# (instead of using the transformer encoder)
|
| 42 |
+
directly_add_no_mem_embed=False,
|
| 43 |
+
# whether to use high-resolution feature maps in the SAM mask decoder
|
| 44 |
+
use_high_res_features_in_sam=False,
|
| 45 |
+
# whether to output multiple (3) masks for the first click on initial conditioning frames
|
| 46 |
+
multimask_output_in_sam=False,
|
| 47 |
+
# the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
|
| 48 |
+
# default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
|
| 49 |
+
multimask_min_pt_num=1,
|
| 50 |
+
multimask_max_pt_num=1,
|
| 51 |
+
# whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
|
| 52 |
+
multimask_output_for_tracking=False,
|
| 53 |
+
# Whether to use multimask tokens for obj ptr; Only relevant when both
|
| 54 |
+
# use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
|
| 55 |
+
use_multimask_token_for_obj_ptr: bool = False,
|
| 56 |
+
# whether to use sigmoid to restrict ious prediction to [0-1]
|
| 57 |
+
iou_prediction_use_sigmoid=False,
|
| 58 |
+
# The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
|
| 59 |
+
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
|
| 60 |
+
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
|
| 61 |
+
memory_temporal_stride_for_eval=1,
|
| 62 |
+
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
| 63 |
+
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
| 64 |
+
add_all_frames_to_correct_as_cond=False,
|
| 65 |
+
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
|
| 66 |
+
non_overlap_masks_for_mem_enc=False,
|
| 67 |
+
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 68 |
+
use_obj_ptrs_in_encoder=False,
|
| 69 |
+
# the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
|
| 70 |
+
max_obj_ptrs_in_encoder=16,
|
| 71 |
+
# whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
|
| 72 |
+
add_tpos_enc_to_obj_ptrs=True,
|
| 73 |
+
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
|
| 74 |
+
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
|
| 75 |
+
proj_tpos_enc_in_obj_ptrs=False,
|
| 76 |
+
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
|
| 77 |
+
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
|
| 78 |
+
only_obj_ptrs_in_the_past_for_eval=False,
|
| 79 |
+
# Whether to predict if there is an object in the frame
|
| 80 |
+
pred_obj_scores: bool = False,
|
| 81 |
+
# Whether to use an MLP to predict object scores
|
| 82 |
+
pred_obj_scores_mlp: bool = False,
|
| 83 |
+
# Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
|
| 84 |
+
# Whether to have a fixed no obj pointer when there is no object present
|
| 85 |
+
# or to use it as an additive embedding with obj_ptr produced by decoder
|
| 86 |
+
fixed_no_obj_ptr: bool = False,
|
| 87 |
+
# Soft no object, i.e. mix in no_obj_ptr softly,
|
| 88 |
+
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors
|
| 89 |
+
soft_no_obj_ptr: bool = False,
|
| 90 |
+
use_mlp_for_obj_ptr_proj: bool = False,
|
| 91 |
+
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
|
| 92 |
+
sam_mask_decoder_extra_args=None,
|
| 93 |
+
compile_image_encoder: bool = False,
|
| 94 |
+
):
|
| 95 |
+
super().__init__()
|
| 96 |
+
|
| 97 |
+
# Part 1: the image backbone
|
| 98 |
+
self.image_encoder = image_encoder
|
| 99 |
+
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
|
| 100 |
+
self.use_high_res_features_in_sam = use_high_res_features_in_sam
|
| 101 |
+
self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
|
| 102 |
+
self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
|
| 103 |
+
self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
|
| 104 |
+
if use_obj_ptrs_in_encoder:
|
| 105 |
+
# A conv layer to downsample the mask prompt to stride 4 (the same stride as
|
| 106 |
+
# low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
|
| 107 |
+
# so that it can be fed into the SAM mask decoder to generate a pointer.
|
| 108 |
+
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
|
| 109 |
+
self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
|
| 110 |
+
if proj_tpos_enc_in_obj_ptrs:
|
| 111 |
+
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
|
| 112 |
+
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
|
| 113 |
+
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
|
| 114 |
+
|
| 115 |
+
# Part 2: memory attention to condition current frame's visual features
|
| 116 |
+
# with memories (and obj ptrs) from past frames
|
| 117 |
+
self.memory_attention = memory_attention
|
| 118 |
+
self.hidden_dim = memory_attention.d_model
|
| 119 |
+
|
| 120 |
+
# Part 3: memory encoder for the previous frame's outputs
|
| 121 |
+
self.memory_encoder = memory_encoder
|
| 122 |
+
self.mem_dim = self.hidden_dim
|
| 123 |
+
if hasattr(self.memory_encoder, "out_proj") and hasattr(
|
| 124 |
+
self.memory_encoder.out_proj, "weight"
|
| 125 |
+
):
|
| 126 |
+
# if there is compression of memories along channel dim
|
| 127 |
+
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
|
| 128 |
+
self.num_maskmem = num_maskmem # Number of memories accessible
|
| 129 |
+
# Temporal encoding of the memories
|
| 130 |
+
self.maskmem_tpos_enc = torch.nn.Parameter(
|
| 131 |
+
torch.zeros(num_maskmem, 1, 1, self.mem_dim)
|
| 132 |
+
)
|
| 133 |
+
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
|
| 134 |
+
# a single token to indicate no memory embedding from previous frames
|
| 135 |
+
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
| 136 |
+
self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
| 137 |
+
trunc_normal_(self.no_mem_embed, std=0.02)
|
| 138 |
+
trunc_normal_(self.no_mem_pos_enc, std=0.02)
|
| 139 |
+
self.directly_add_no_mem_embed = directly_add_no_mem_embed
|
| 140 |
+
# Apply sigmoid to the output raw mask logits (to turn them from
|
| 141 |
+
# range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
|
| 142 |
+
self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
|
| 143 |
+
self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
|
| 144 |
+
self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
|
| 145 |
+
self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
|
| 146 |
+
self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
|
| 147 |
+
# On frames with mask input, whether to directly output the input mask without
|
| 148 |
+
# using a SAM prompt encoder + mask decoder
|
| 149 |
+
self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
|
| 150 |
+
self.multimask_output_in_sam = multimask_output_in_sam
|
| 151 |
+
self.multimask_min_pt_num = multimask_min_pt_num
|
| 152 |
+
self.multimask_max_pt_num = multimask_max_pt_num
|
| 153 |
+
self.multimask_output_for_tracking = multimask_output_for_tracking
|
| 154 |
+
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
| 155 |
+
self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
|
| 156 |
+
|
| 157 |
+
# Part 4: SAM-style prompt encoder (for both mask and point inputs)
|
| 158 |
+
# and SAM-style mask decoder for the final mask output
|
| 159 |
+
self.image_size = image_size
|
| 160 |
+
self.backbone_stride = backbone_stride
|
| 161 |
+
self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
|
| 162 |
+
self.pred_obj_scores = pred_obj_scores
|
| 163 |
+
self.pred_obj_scores_mlp = pred_obj_scores_mlp
|
| 164 |
+
self.fixed_no_obj_ptr = fixed_no_obj_ptr
|
| 165 |
+
self.soft_no_obj_ptr = soft_no_obj_ptr
|
| 166 |
+
if self.fixed_no_obj_ptr:
|
| 167 |
+
assert self.pred_obj_scores
|
| 168 |
+
assert self.use_obj_ptrs_in_encoder
|
| 169 |
+
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
|
| 170 |
+
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
|
| 171 |
+
trunc_normal_(self.no_obj_ptr, std=0.02)
|
| 172 |
+
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
|
| 173 |
+
|
| 174 |
+
self._build_sam_heads()
|
| 175 |
+
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
| 176 |
+
self.max_cond_frames_in_attn = max_cond_frames_in_attn
|
| 177 |
+
|
| 178 |
+
# Model compilation
|
| 179 |
+
if compile_image_encoder:
|
| 180 |
+
# Compile the forward function (not the full module) to allow loading checkpoints.
|
| 181 |
+
print(
|
| 182 |
+
"Image encoder compilation is enabled. First forward pass will be slow."
|
| 183 |
+
)
|
| 184 |
+
self.image_encoder.forward = torch.compile(
|
| 185 |
+
self.image_encoder.forward,
|
| 186 |
+
mode="max-autotune",
|
| 187 |
+
fullgraph=True,
|
| 188 |
+
dynamic=False,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
@property
|
| 192 |
+
def device(self):
|
| 193 |
+
return next(self.parameters()).device
|
| 194 |
+
|
| 195 |
+
def forward(self, *args, **kwargs):
|
| 196 |
+
raise NotImplementedError(
|
| 197 |
+
"Please use the corresponding methods in SAM2VideoPredictor for inference."
|
| 198 |
+
"See notebooks/video_predictor_example.ipynb for an example."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def _build_sam_heads(self):
|
| 202 |
+
"""Build SAM-style prompt encoder and mask decoder."""
|
| 203 |
+
self.sam_prompt_embed_dim = self.hidden_dim
|
| 204 |
+
self.sam_image_embedding_size = self.image_size // self.backbone_stride
|
| 205 |
+
|
| 206 |
+
# build PromptEncoder and MaskDecoder from SAM
|
| 207 |
+
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
|
| 208 |
+
self.sam_prompt_encoder = PromptEncoder(
|
| 209 |
+
embed_dim=self.sam_prompt_embed_dim,
|
| 210 |
+
image_embedding_size=(
|
| 211 |
+
self.sam_image_embedding_size,
|
| 212 |
+
self.sam_image_embedding_size,
|
| 213 |
+
),
|
| 214 |
+
input_image_size=(self.image_size, self.image_size),
|
| 215 |
+
mask_in_chans=16,
|
| 216 |
+
)
|
| 217 |
+
self.sam_mask_decoder = MaskDecoder(
|
| 218 |
+
num_multimask_outputs=3,
|
| 219 |
+
transformer=TwoWayTransformer(
|
| 220 |
+
depth=2,
|
| 221 |
+
embedding_dim=self.sam_prompt_embed_dim,
|
| 222 |
+
mlp_dim=2048,
|
| 223 |
+
num_heads=8,
|
| 224 |
+
),
|
| 225 |
+
transformer_dim=self.sam_prompt_embed_dim,
|
| 226 |
+
iou_head_depth=3,
|
| 227 |
+
iou_head_hidden_dim=256,
|
| 228 |
+
use_high_res_features=self.use_high_res_features_in_sam,
|
| 229 |
+
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
|
| 230 |
+
pred_obj_scores=self.pred_obj_scores,
|
| 231 |
+
pred_obj_scores_mlp=self.pred_obj_scores_mlp,
|
| 232 |
+
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
|
| 233 |
+
**(self.sam_mask_decoder_extra_args or {}),
|
| 234 |
+
)
|
| 235 |
+
if self.use_obj_ptrs_in_encoder:
|
| 236 |
+
# a linear projection on SAM output tokens to turn them into object pointers
|
| 237 |
+
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
|
| 238 |
+
if self.use_mlp_for_obj_ptr_proj:
|
| 239 |
+
self.obj_ptr_proj = MLP(
|
| 240 |
+
self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
self.obj_ptr_proj = torch.nn.Identity()
|
| 244 |
+
if self.proj_tpos_enc_in_obj_ptrs:
|
| 245 |
+
# a linear projection on temporal positional encoding in object pointers to
|
| 246 |
+
# avoid potential interference with spatial positional encoding
|
| 247 |
+
self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
|
| 248 |
+
else:
|
| 249 |
+
self.obj_ptr_tpos_proj = torch.nn.Identity()
|
| 250 |
+
|
| 251 |
+
def _forward_sam_heads(
|
| 252 |
+
self,
|
| 253 |
+
backbone_features,
|
| 254 |
+
point_inputs=None,
|
| 255 |
+
mask_inputs=None,
|
| 256 |
+
high_res_features=None,
|
| 257 |
+
multimask_output=False,
|
| 258 |
+
):
|
| 259 |
+
"""
|
| 260 |
+
Forward SAM prompt encoders and mask heads.
|
| 261 |
+
|
| 262 |
+
Inputs:
|
| 263 |
+
- backbone_features: image features of [B, C, H, W] shape
|
| 264 |
+
- point_inputs: a dictionary with "point_coords" and "point_labels", where
|
| 265 |
+
1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
|
| 266 |
+
absolute pixel-unit coordinate in (x, y) format of the P input points
|
| 267 |
+
2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
|
| 268 |
+
positive clicks, 0 means negative clicks, and -1 means padding
|
| 269 |
+
- mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
|
| 270 |
+
same spatial size as the image.
|
| 271 |
+
- high_res_features: either 1) None or 2) or a list of length 2 containing
|
| 272 |
+
two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
|
| 273 |
+
which will be used as high-resolution feature maps for SAM decoder.
|
| 274 |
+
- multimask_output: if it's True, we output 3 candidate masks and their 3
|
| 275 |
+
corresponding IoU estimates, and if it's False, we output only 1 mask and
|
| 276 |
+
its corresponding IoU estimate.
|
| 277 |
+
|
| 278 |
+
Outputs:
|
| 279 |
+
- low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
|
| 280 |
+
`multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
|
| 281 |
+
output mask logits (before sigmoid) for the low-resolution masks, with 4x
|
| 282 |
+
the resolution (1/4 stride) of the input backbone_features.
|
| 283 |
+
- high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
|
| 284 |
+
if `multimask_output=True` and M = 1 if `multimask_output=False`),
|
| 285 |
+
upsampled from the low-resolution masks, with shape size as the image
|
| 286 |
+
(stride is 1 pixel).
|
| 287 |
+
- ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
|
| 288 |
+
if `multimask_output=False`), the estimated IoU of each output mask.
|
| 289 |
+
- low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
|
| 290 |
+
If `multimask_output=True`, it's the mask with the highest IoU estimate.
|
| 291 |
+
If `multimask_output=False`, it's the same as `low_res_multimasks`.
|
| 292 |
+
- high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
|
| 293 |
+
If `multimask_output=True`, it's the mask with the highest IoU estimate.
|
| 294 |
+
If `multimask_output=False`, it's the same as `high_res_multimasks`.
|
| 295 |
+
- obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
|
| 296 |
+
based on the output token from the SAM mask decoder.
|
| 297 |
+
"""
|
| 298 |
+
B = backbone_features.size(0)
|
| 299 |
+
device = backbone_features.device
|
| 300 |
+
assert backbone_features.size(1) == self.sam_prompt_embed_dim
|
| 301 |
+
assert backbone_features.size(2) == self.sam_image_embedding_size
|
| 302 |
+
assert backbone_features.size(3) == self.sam_image_embedding_size
|
| 303 |
+
|
| 304 |
+
# a) Handle point prompts
|
| 305 |
+
if point_inputs is not None:
|
| 306 |
+
sam_point_coords = point_inputs["point_coords"]
|
| 307 |
+
sam_point_labels = point_inputs["point_labels"]
|
| 308 |
+
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
|
| 309 |
+
else:
|
| 310 |
+
# If no points are provide, pad with an empty point (with label -1)
|
| 311 |
+
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
| 312 |
+
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
| 313 |
+
|
| 314 |
+
# b) Handle mask prompts
|
| 315 |
+
if mask_inputs is not None:
|
| 316 |
+
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
| 317 |
+
# and feed it as a dense mask prompt into the SAM mask encoder
|
| 318 |
+
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
|
| 319 |
+
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
|
| 320 |
+
sam_mask_prompt = F.interpolate(
|
| 321 |
+
mask_inputs.float(),
|
| 322 |
+
size=self.sam_prompt_encoder.mask_input_size,
|
| 323 |
+
align_corners=False,
|
| 324 |
+
mode="bilinear",
|
| 325 |
+
antialias=True, # use antialias for downsampling
|
| 326 |
+
)
|
| 327 |
+
else:
|
| 328 |
+
sam_mask_prompt = mask_inputs
|
| 329 |
+
else:
|
| 330 |
+
# Otherwise, simply feed None (and SAM's prompt encoder will add
|
| 331 |
+
# a learned `no_mask_embed` to indicate no mask input in this case).
|
| 332 |
+
sam_mask_prompt = None
|
| 333 |
+
|
| 334 |
+
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
| 335 |
+
points=(sam_point_coords, sam_point_labels),
|
| 336 |
+
boxes=None,
|
| 337 |
+
masks=sam_mask_prompt,
|
| 338 |
+
)
|
| 339 |
+
(
|
| 340 |
+
low_res_multimasks,
|
| 341 |
+
ious,
|
| 342 |
+
sam_output_tokens,
|
| 343 |
+
object_score_logits,
|
| 344 |
+
) = self.sam_mask_decoder(
|
| 345 |
+
image_embeddings=backbone_features,
|
| 346 |
+
image_pe=self.sam_prompt_encoder.get_dense_pe(),
|
| 347 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 348 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 349 |
+
multimask_output=multimask_output,
|
| 350 |
+
repeat_image=False, # the image is already batched
|
| 351 |
+
high_res_features=high_res_features,
|
| 352 |
+
)
|
| 353 |
+
if self.pred_obj_scores:
|
| 354 |
+
is_obj_appearing = object_score_logits > 0
|
| 355 |
+
|
| 356 |
+
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
| 357 |
+
# consistent with the actual mask prediction
|
| 358 |
+
low_res_multimasks = torch.where(
|
| 359 |
+
is_obj_appearing[:, None, None],
|
| 360 |
+
low_res_multimasks,
|
| 361 |
+
NO_OBJ_SCORE,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# convert masks from possibly bfloat16 (or float16) to float32
|
| 365 |
+
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
|
| 366 |
+
_dtype = low_res_multimasks.dtype
|
| 367 |
+
# low_res_multimasks = low_res_multimasks.float()
|
| 368 |
+
high_res_multimasks = F.interpolate(
|
| 369 |
+
low_res_multimasks.float(),
|
| 370 |
+
size=(self.image_size, self.image_size),
|
| 371 |
+
mode="bilinear",
|
| 372 |
+
align_corners=False,
|
| 373 |
+
).to(_dtype)
|
| 374 |
+
|
| 375 |
+
sam_output_token = sam_output_tokens[:, 0]
|
| 376 |
+
if multimask_output:
|
| 377 |
+
# take the best mask prediction (with the highest IoU estimation)
|
| 378 |
+
best_iou_inds = torch.argmax(ious, dim=-1)
|
| 379 |
+
batch_inds = torch.arange(B, device=device)
|
| 380 |
+
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 381 |
+
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
| 382 |
+
if sam_output_tokens.size(1) > 1:
|
| 383 |
+
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
| 384 |
+
else:
|
| 385 |
+
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
| 386 |
+
|
| 387 |
+
# Extract object pointer from the SAM output token (with occlusion handling)
|
| 388 |
+
obj_ptr = self.obj_ptr_proj(sam_output_token)
|
| 389 |
+
if self.pred_obj_scores:
|
| 390 |
+
# Allow *soft* no obj ptr, unlike for masks
|
| 391 |
+
if self.soft_no_obj_ptr:
|
| 392 |
+
# Only hard possible with gt
|
| 393 |
+
assert not self.teacher_force_obj_scores_for_mem
|
| 394 |
+
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
| 395 |
+
else:
|
| 396 |
+
lambda_is_obj_appearing = is_obj_appearing.float()
|
| 397 |
+
|
| 398 |
+
if self.fixed_no_obj_ptr:
|
| 399 |
+
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
| 400 |
+
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
| 401 |
+
|
| 402 |
+
return (
|
| 403 |
+
low_res_multimasks,
|
| 404 |
+
high_res_multimasks,
|
| 405 |
+
ious,
|
| 406 |
+
low_res_masks,
|
| 407 |
+
high_res_masks,
|
| 408 |
+
obj_ptr,
|
| 409 |
+
object_score_logits,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
| 413 |
+
"""
|
| 414 |
+
Directly turn binary `mask_inputs` into a output mask logits without using SAM.
|
| 415 |
+
(same input and output shapes as in _forward_sam_heads above).
|
| 416 |
+
"""
|
| 417 |
+
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
|
| 418 |
+
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
|
| 419 |
+
mask_inputs_float = mask_inputs.float()
|
| 420 |
+
high_res_masks = mask_inputs_float * out_scale + out_bias
|
| 421 |
+
low_res_masks = F.interpolate(
|
| 422 |
+
high_res_masks,
|
| 423 |
+
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
|
| 424 |
+
align_corners=False,
|
| 425 |
+
mode="bilinear",
|
| 426 |
+
antialias=True, # use antialias for downsampling
|
| 427 |
+
)
|
| 428 |
+
# a dummy IoU prediction of all 1's under mask input
|
| 429 |
+
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
| 430 |
+
if not self.use_obj_ptrs_in_encoder:
|
| 431 |
+
# all zeros as a dummy object pointer (of shape [B, C])
|
| 432 |
+
obj_ptr = torch.zeros(
|
| 433 |
+
mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
|
| 434 |
+
)
|
| 435 |
+
else:
|
| 436 |
+
# produce an object pointer using the SAM decoder from the mask input
|
| 437 |
+
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
|
| 438 |
+
backbone_features=backbone_features,
|
| 439 |
+
mask_inputs=self.mask_downsample(mask_inputs_float),
|
| 440 |
+
high_res_features=high_res_features,
|
| 441 |
+
)
|
| 442 |
+
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
|
| 443 |
+
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
|
| 444 |
+
# on the object_scores from the SAM decoder.
|
| 445 |
+
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
|
| 446 |
+
is_obj_appearing = is_obj_appearing[..., None]
|
| 447 |
+
lambda_is_obj_appearing = is_obj_appearing.float()
|
| 448 |
+
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
|
| 449 |
+
if self.pred_obj_scores:
|
| 450 |
+
if self.fixed_no_obj_ptr:
|
| 451 |
+
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
| 452 |
+
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
| 453 |
+
|
| 454 |
+
return (
|
| 455 |
+
low_res_masks,
|
| 456 |
+
high_res_masks,
|
| 457 |
+
ious,
|
| 458 |
+
low_res_masks,
|
| 459 |
+
high_res_masks,
|
| 460 |
+
obj_ptr,
|
| 461 |
+
object_score_logits,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
def forward_image(self, img_batch: torch.Tensor):
|
| 465 |
+
"""Get the image feature on the input batch."""
|
| 466 |
+
backbone_out = self.image_encoder(img_batch)
|
| 467 |
+
if self.use_high_res_features_in_sam:
|
| 468 |
+
# precompute projected level 0 and level 1 features in SAM decoder
|
| 469 |
+
# to avoid running it again on every SAM click
|
| 470 |
+
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
|
| 471 |
+
backbone_out["backbone_fpn"][0]
|
| 472 |
+
)
|
| 473 |
+
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
|
| 474 |
+
backbone_out["backbone_fpn"][1]
|
| 475 |
+
)
|
| 476 |
+
return backbone_out
|
| 477 |
+
|
| 478 |
+
def _prepare_backbone_features(self, backbone_out):
|
| 479 |
+
"""Prepare and flatten visual features."""
|
| 480 |
+
backbone_out = backbone_out.copy()
|
| 481 |
+
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
|
| 482 |
+
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
|
| 483 |
+
|
| 484 |
+
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
|
| 485 |
+
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
|
| 486 |
+
|
| 487 |
+
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
|
| 488 |
+
# flatten NxCxHxW to HWxNxC
|
| 489 |
+
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
|
| 490 |
+
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
|
| 491 |
+
|
| 492 |
+
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
|
| 493 |
+
|
| 494 |
+
def _prepare_memory_conditioned_features(
|
| 495 |
+
self,
|
| 496 |
+
frame_idx,
|
| 497 |
+
is_init_cond_frame,
|
| 498 |
+
current_vision_feats,
|
| 499 |
+
current_vision_pos_embeds,
|
| 500 |
+
feat_sizes,
|
| 501 |
+
output_dict,
|
| 502 |
+
num_frames,
|
| 503 |
+
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
| 504 |
+
):
|
| 505 |
+
"""Fuse the current frame's visual feature map with previous memory."""
|
| 506 |
+
B = current_vision_feats[-1].size(1) # batch size on this frame
|
| 507 |
+
C = self.hidden_dim
|
| 508 |
+
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
|
| 509 |
+
device = current_vision_feats[-1].device
|
| 510 |
+
# The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
|
| 511 |
+
# In this case, we skip the fusion with any memory.
|
| 512 |
+
if self.num_maskmem == 0: # Disable memory and skip fusion
|
| 513 |
+
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
|
| 514 |
+
return pix_feat
|
| 515 |
+
|
| 516 |
+
num_obj_ptr_tokens = 0
|
| 517 |
+
# Step 1: condition the visual features of the current frame on previous memories
|
| 518 |
+
if not is_init_cond_frame:
|
| 519 |
+
# Retrieve the memories encoded with the maskmem backbone
|
| 520 |
+
to_cat_memory, to_cat_memory_pos_embed = [], []
|
| 521 |
+
# Add conditioning frames's output first (all cond frames have t_pos=0 for
|
| 522 |
+
# when getting temporal positional embedding below)
|
| 523 |
+
assert len(output_dict["cond_frame_outputs"]) > 0
|
| 524 |
+
# Select a maximum number of temporally closest cond frames for cross attention
|
| 525 |
+
cond_outputs = output_dict["cond_frame_outputs"]
|
| 526 |
+
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
|
| 527 |
+
frame_idx, cond_outputs, self.max_cond_frames_in_attn
|
| 528 |
+
)
|
| 529 |
+
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
|
| 530 |
+
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
|
| 531 |
+
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
|
| 532 |
+
# We also allow taking the memory frame non-consecutively (with r>1), in which case
|
| 533 |
+
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
|
| 534 |
+
r = self.memory_temporal_stride_for_eval
|
| 535 |
+
for t_pos in range(1, self.num_maskmem):
|
| 536 |
+
t_rel = self.num_maskmem - t_pos # how many frames before current frame
|
| 537 |
+
if t_rel == 1:
|
| 538 |
+
# for t_rel == 1, we take the last frame (regardless of r)
|
| 539 |
+
if not track_in_reverse:
|
| 540 |
+
# the frame immediately before this frame (i.e. frame_idx - 1)
|
| 541 |
+
prev_frame_idx = frame_idx - t_rel
|
| 542 |
+
else:
|
| 543 |
+
# the frame immediately after this frame (i.e. frame_idx + 1)
|
| 544 |
+
prev_frame_idx = frame_idx + t_rel
|
| 545 |
+
else:
|
| 546 |
+
# for t_rel >= 2, we take the memory frame from every r-th frames
|
| 547 |
+
if not track_in_reverse:
|
| 548 |
+
# first find the nearest frame among every r-th frames before this frame
|
| 549 |
+
# for r=1, this would be (frame_idx - 2)
|
| 550 |
+
prev_frame_idx = ((frame_idx - 2) // r) * r
|
| 551 |
+
# then seek further among every r-th frames
|
| 552 |
+
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
|
| 553 |
+
else:
|
| 554 |
+
# first find the nearest frame among every r-th frames after this frame
|
| 555 |
+
# for r=1, this would be (frame_idx + 2)
|
| 556 |
+
prev_frame_idx = -(-(frame_idx + 2) // r) * r
|
| 557 |
+
# then seek further among every r-th frames
|
| 558 |
+
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
|
| 559 |
+
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
|
| 560 |
+
if out is None:
|
| 561 |
+
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
|
| 562 |
+
# frames, we still attend to it as if it's a non-conditioning frame.
|
| 563 |
+
out = unselected_cond_outputs.get(prev_frame_idx, None)
|
| 564 |
+
t_pos_and_prevs.append((t_pos, out))
|
| 565 |
+
|
| 566 |
+
for t_pos, prev in t_pos_and_prevs:
|
| 567 |
+
if prev is None:
|
| 568 |
+
continue # skip padding frames
|
| 569 |
+
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
| 570 |
+
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
| 571 |
+
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
| 572 |
+
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
| 573 |
+
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
| 574 |
+
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
| 575 |
+
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
| 576 |
+
# Temporal positional encoding
|
| 577 |
+
maskmem_enc = (
|
| 578 |
+
maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
|
| 579 |
+
)
|
| 580 |
+
to_cat_memory_pos_embed.append(maskmem_enc)
|
| 581 |
+
|
| 582 |
+
# Construct the list of past object pointers
|
| 583 |
+
if self.use_obj_ptrs_in_encoder:
|
| 584 |
+
max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
|
| 585 |
+
# First add those object pointers from selected conditioning frames
|
| 586 |
+
# (optionally, only include object pointers in the past during evaluation)
|
| 587 |
+
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
|
| 588 |
+
ptr_cond_outputs = {
|
| 589 |
+
t: out
|
| 590 |
+
for t, out in selected_cond_outputs.items()
|
| 591 |
+
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
|
| 592 |
+
}
|
| 593 |
+
else:
|
| 594 |
+
ptr_cond_outputs = selected_cond_outputs
|
| 595 |
+
pos_and_ptrs = [
|
| 596 |
+
# Temporal pos encoding contains how far away each pointer is from current frame
|
| 597 |
+
(abs(frame_idx - t), out["obj_ptr"])
|
| 598 |
+
for t, out in ptr_cond_outputs.items()
|
| 599 |
+
]
|
| 600 |
+
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
|
| 601 |
+
for t_diff in range(1, max_obj_ptrs_in_encoder):
|
| 602 |
+
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
|
| 603 |
+
if t < 0 or (num_frames is not None and t >= num_frames):
|
| 604 |
+
break
|
| 605 |
+
out = output_dict["non_cond_frame_outputs"].get(
|
| 606 |
+
t, unselected_cond_outputs.get(t, None)
|
| 607 |
+
)
|
| 608 |
+
if out is not None:
|
| 609 |
+
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
| 610 |
+
# If we have at least one object pointer, add them to the across attention
|
| 611 |
+
if len(pos_and_ptrs) > 0:
|
| 612 |
+
pos_list, ptrs_list = zip(*pos_and_ptrs)
|
| 613 |
+
# stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
|
| 614 |
+
obj_ptrs = torch.stack(ptrs_list, dim=0)
|
| 615 |
+
# a temporal positional embedding based on how far each object pointer is from
|
| 616 |
+
# the current frame (sine embedding normalized by the max pointer num).
|
| 617 |
+
if self.add_tpos_enc_to_obj_ptrs:
|
| 618 |
+
t_diff_max = max_obj_ptrs_in_encoder - 1
|
| 619 |
+
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
| 620 |
+
obj_pos = torch.tensor(pos_list, device=device)
|
| 621 |
+
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
| 622 |
+
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
| 623 |
+
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
| 624 |
+
else:
|
| 625 |
+
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
|
| 626 |
+
if self.mem_dim < C:
|
| 627 |
+
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
|
| 628 |
+
obj_ptrs = obj_ptrs.reshape(
|
| 629 |
+
-1, B, C // self.mem_dim, self.mem_dim
|
| 630 |
+
)
|
| 631 |
+
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
|
| 632 |
+
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
|
| 633 |
+
to_cat_memory.append(obj_ptrs)
|
| 634 |
+
to_cat_memory_pos_embed.append(obj_pos)
|
| 635 |
+
num_obj_ptr_tokens = obj_ptrs.shape[0]
|
| 636 |
+
else:
|
| 637 |
+
num_obj_ptr_tokens = 0
|
| 638 |
+
else:
|
| 639 |
+
# for initial conditioning frames, encode them without using any previous memory
|
| 640 |
+
if self.directly_add_no_mem_embed:
|
| 641 |
+
# directly add no-mem embedding (instead of using the transformer encoder)
|
| 642 |
+
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
|
| 643 |
+
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
| 644 |
+
return pix_feat_with_mem
|
| 645 |
+
|
| 646 |
+
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
|
| 647 |
+
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
| 648 |
+
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
| 649 |
+
|
| 650 |
+
# Step 2: Concatenate the memories and forward through the transformer encoder
|
| 651 |
+
memory = torch.cat(to_cat_memory, dim=0)
|
| 652 |
+
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
|
| 653 |
+
|
| 654 |
+
pix_feat_with_mem = self.memory_attention(
|
| 655 |
+
curr=current_vision_feats,
|
| 656 |
+
curr_pos=current_vision_pos_embeds,
|
| 657 |
+
memory=memory,
|
| 658 |
+
memory_pos=memory_pos_embed,
|
| 659 |
+
num_obj_ptr_tokens=num_obj_ptr_tokens,
|
| 660 |
+
)
|
| 661 |
+
# reshape the output (HW)BC => BCHW
|
| 662 |
+
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
| 663 |
+
return pix_feat_with_mem
|
| 664 |
+
|
| 665 |
+
def _encode_new_memory(
|
| 666 |
+
self,
|
| 667 |
+
current_vision_feats,
|
| 668 |
+
feat_sizes,
|
| 669 |
+
pred_masks_high_res,
|
| 670 |
+
is_mask_from_pts,
|
| 671 |
+
):
|
| 672 |
+
"""Encode the current image and its prediction into a memory feature."""
|
| 673 |
+
B = current_vision_feats[-1].size(1) # batch size on this frame
|
| 674 |
+
C = self.hidden_dim
|
| 675 |
+
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
|
| 676 |
+
# top-level feature, (HW)BC => BCHW
|
| 677 |
+
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
|
| 678 |
+
if self.non_overlap_masks_for_mem_enc and not self.training:
|
| 679 |
+
# optionally, apply non-overlapping constraints to the masks (it's applied
|
| 680 |
+
# in the batch dimension and should only be used during eval, where all
|
| 681 |
+
# the objects come from the same video under batch size 1).
|
| 682 |
+
pred_masks_high_res = self._apply_non_overlapping_constraints(
|
| 683 |
+
pred_masks_high_res
|
| 684 |
+
)
|
| 685 |
+
# scale the raw mask logits with a temperature before applying sigmoid
|
| 686 |
+
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
| 687 |
+
if binarize and not self.training:
|
| 688 |
+
mask_for_mem = (pred_masks_high_res > 0).float()
|
| 689 |
+
else:
|
| 690 |
+
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
|
| 691 |
+
mask_for_mem = torch.sigmoid(pred_masks_high_res)
|
| 692 |
+
# apply scale and bias terms to the sigmoid probabilities
|
| 693 |
+
if self.sigmoid_scale_for_mem_enc != 1.0:
|
| 694 |
+
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
| 695 |
+
if self.sigmoid_bias_for_mem_enc != 0.0:
|
| 696 |
+
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
| 697 |
+
maskmem_out = self.memory_encoder(
|
| 698 |
+
pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
|
| 699 |
+
)
|
| 700 |
+
maskmem_features = maskmem_out["vision_features"]
|
| 701 |
+
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
| 702 |
+
|
| 703 |
+
return maskmem_features, maskmem_pos_enc
|
| 704 |
+
|
| 705 |
+
def track_step(
|
| 706 |
+
self,
|
| 707 |
+
frame_idx,
|
| 708 |
+
is_init_cond_frame,
|
| 709 |
+
current_vision_feats,
|
| 710 |
+
current_vision_pos_embeds,
|
| 711 |
+
feat_sizes,
|
| 712 |
+
point_inputs,
|
| 713 |
+
mask_inputs,
|
| 714 |
+
output_dict,
|
| 715 |
+
num_frames,
|
| 716 |
+
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
| 717 |
+
# Whether to run the memory encoder on the predicted masks. Sometimes we might want
|
| 718 |
+
# to skip the memory encoder with `run_mem_encoder=False`. For example,
|
| 719 |
+
# in demo we might call `track_step` multiple times for each user click,
|
| 720 |
+
# and only encode the memory when the user finalizes their clicks. And in ablation
|
| 721 |
+
# settings like SAM training on static images, we don't need the memory encoder.
|
| 722 |
+
run_mem_encoder=True,
|
| 723 |
+
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
|
| 724 |
+
prev_sam_mask_logits=None,
|
| 725 |
+
):
|
| 726 |
+
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
|
| 727 |
+
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
| 728 |
+
if len(current_vision_feats) > 1:
|
| 729 |
+
high_res_features = [
|
| 730 |
+
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
| 731 |
+
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
| 732 |
+
]
|
| 733 |
+
else:
|
| 734 |
+
high_res_features = None
|
| 735 |
+
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
|
| 736 |
+
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
|
| 737 |
+
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
| 738 |
+
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
| 739 |
+
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
| 740 |
+
sam_outputs = self._use_mask_as_output(
|
| 741 |
+
pix_feat, high_res_features, mask_inputs
|
| 742 |
+
)
|
| 743 |
+
else:
|
| 744 |
+
# fused the visual feature with previous memory features in the memory bank
|
| 745 |
+
pix_feat_with_mem = self._prepare_memory_conditioned_features(
|
| 746 |
+
frame_idx=frame_idx,
|
| 747 |
+
is_init_cond_frame=is_init_cond_frame,
|
| 748 |
+
current_vision_feats=current_vision_feats[-1:],
|
| 749 |
+
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
|
| 750 |
+
feat_sizes=feat_sizes[-1:],
|
| 751 |
+
output_dict=output_dict,
|
| 752 |
+
num_frames=num_frames,
|
| 753 |
+
track_in_reverse=track_in_reverse,
|
| 754 |
+
)
|
| 755 |
+
# apply SAM-style segmentation head
|
| 756 |
+
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
|
| 757 |
+
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
|
| 758 |
+
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
|
| 759 |
+
if prev_sam_mask_logits is not None:
|
| 760 |
+
assert point_inputs is not None and mask_inputs is None
|
| 761 |
+
mask_inputs = prev_sam_mask_logits
|
| 762 |
+
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
|
| 763 |
+
sam_outputs = self._forward_sam_heads(
|
| 764 |
+
backbone_features=pix_feat_with_mem,
|
| 765 |
+
point_inputs=point_inputs,
|
| 766 |
+
mask_inputs=mask_inputs,
|
| 767 |
+
high_res_features=high_res_features,
|
| 768 |
+
multimask_output=multimask_output,
|
| 769 |
+
)
|
| 770 |
+
(
|
| 771 |
+
_,
|
| 772 |
+
_,
|
| 773 |
+
_,
|
| 774 |
+
low_res_masks,
|
| 775 |
+
high_res_masks,
|
| 776 |
+
obj_ptr,
|
| 777 |
+
_,
|
| 778 |
+
) = sam_outputs
|
| 779 |
+
|
| 780 |
+
current_out["pred_masks"] = low_res_masks
|
| 781 |
+
current_out["pred_masks_high_res"] = high_res_masks
|
| 782 |
+
current_out["obj_ptr"] = obj_ptr
|
| 783 |
+
|
| 784 |
+
# Finally run the memory encoder on the predicted mask to encode
|
| 785 |
+
# it into a new memory feature (that can be used in future frames)
|
| 786 |
+
if run_mem_encoder and self.num_maskmem > 0:
|
| 787 |
+
high_res_masks_for_mem_enc = high_res_masks
|
| 788 |
+
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
| 789 |
+
current_vision_feats=current_vision_feats,
|
| 790 |
+
feat_sizes=feat_sizes,
|
| 791 |
+
pred_masks_high_res=high_res_masks_for_mem_enc,
|
| 792 |
+
is_mask_from_pts=(point_inputs is not None),
|
| 793 |
+
)
|
| 794 |
+
current_out["maskmem_features"] = maskmem_features
|
| 795 |
+
current_out["maskmem_pos_enc"] = maskmem_pos_enc
|
| 796 |
+
else:
|
| 797 |
+
current_out["maskmem_features"] = None
|
| 798 |
+
current_out["maskmem_pos_enc"] = None
|
| 799 |
+
|
| 800 |
+
return current_out
|
| 801 |
+
|
| 802 |
+
def _use_multimask(self, is_init_cond_frame, point_inputs):
|
| 803 |
+
"""Whether to use multimask output in the SAM head."""
|
| 804 |
+
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
|
| 805 |
+
multimask_output = (
|
| 806 |
+
self.multimask_output_in_sam
|
| 807 |
+
and (is_init_cond_frame or self.multimask_output_for_tracking)
|
| 808 |
+
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
|
| 809 |
+
)
|
| 810 |
+
return multimask_output
|
| 811 |
+
|
| 812 |
+
def _apply_non_overlapping_constraints(self, pred_masks):
|
| 813 |
+
"""
|
| 814 |
+
Apply non-overlapping constraints to the object scores in pred_masks. Here we
|
| 815 |
+
keep only the highest scoring object at each spatial location in pred_masks.
|
| 816 |
+
"""
|
| 817 |
+
batch_size = pred_masks.size(0)
|
| 818 |
+
if batch_size == 1:
|
| 819 |
+
return pred_masks
|
| 820 |
+
|
| 821 |
+
device = pred_masks.device
|
| 822 |
+
# "max_obj_inds": object index of the object with the highest score at each location
|
| 823 |
+
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
|
| 824 |
+
# "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
|
| 825 |
+
batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
|
| 826 |
+
keep = max_obj_inds == batch_obj_inds
|
| 827 |
+
# suppress overlapping regions' scores below -10.0 so that the foreground regions
|
| 828 |
+
# don't overlap (here sigmoid(-10.0)=4.5398e-05)
|
| 829 |
+
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
|
| 830 |
+
return pred_masks
|
RynnEC/third_parts/sam2/modeling/sam2_utils.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import copy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
|
| 16 |
+
"""
|
| 17 |
+
Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
|
| 18 |
+
that are temporally closest to the current frame at `frame_idx`. Here, we take
|
| 19 |
+
- a) the closest conditioning frame before `frame_idx` (if any);
|
| 20 |
+
- b) the closest conditioning frame after `frame_idx` (if any);
|
| 21 |
+
- c) any other temporally closest conditioning frames until reaching a total
|
| 22 |
+
of `max_cond_frame_num` conditioning frames.
|
| 23 |
+
|
| 24 |
+
Outputs:
|
| 25 |
+
- selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
|
| 26 |
+
- unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
|
| 27 |
+
"""
|
| 28 |
+
if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
|
| 29 |
+
selected_outputs = cond_frame_outputs
|
| 30 |
+
unselected_outputs = {}
|
| 31 |
+
else:
|
| 32 |
+
assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
|
| 33 |
+
selected_outputs = {}
|
| 34 |
+
|
| 35 |
+
# the closest conditioning frame before `frame_idx` (if any)
|
| 36 |
+
idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
|
| 37 |
+
if idx_before is not None:
|
| 38 |
+
selected_outputs[idx_before] = cond_frame_outputs[idx_before]
|
| 39 |
+
|
| 40 |
+
# the closest conditioning frame after `frame_idx` (if any)
|
| 41 |
+
idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
|
| 42 |
+
if idx_after is not None:
|
| 43 |
+
selected_outputs[idx_after] = cond_frame_outputs[idx_after]
|
| 44 |
+
|
| 45 |
+
# add other temporally closest conditioning frames until reaching a total
|
| 46 |
+
# of `max_cond_frame_num` conditioning frames.
|
| 47 |
+
num_remain = max_cond_frame_num - len(selected_outputs)
|
| 48 |
+
inds_remain = sorted(
|
| 49 |
+
(t for t in cond_frame_outputs if t not in selected_outputs),
|
| 50 |
+
key=lambda x: abs(x - frame_idx),
|
| 51 |
+
)[:num_remain]
|
| 52 |
+
selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
|
| 53 |
+
unselected_outputs = {
|
| 54 |
+
t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
return selected_outputs, unselected_outputs
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
|
| 61 |
+
"""
|
| 62 |
+
Get 1D sine positional embedding as in the original Transformer paper.
|
| 63 |
+
"""
|
| 64 |
+
pe_dim = dim // 2
|
| 65 |
+
dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
|
| 66 |
+
dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
|
| 67 |
+
|
| 68 |
+
pos_embed = pos_inds.unsqueeze(-1) / dim_t
|
| 69 |
+
pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
|
| 70 |
+
return pos_embed
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_activation_fn(activation):
|
| 74 |
+
"""Return an activation function given a string"""
|
| 75 |
+
if activation == "relu":
|
| 76 |
+
return F.relu
|
| 77 |
+
if activation == "gelu":
|
| 78 |
+
return F.gelu
|
| 79 |
+
if activation == "glu":
|
| 80 |
+
return F.glu
|
| 81 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_clones(module, N):
|
| 85 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class DropPath(nn.Module):
|
| 89 |
+
# adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
|
| 90 |
+
def __init__(self, drop_prob=0.0, scale_by_keep=True):
|
| 91 |
+
super(DropPath, self).__init__()
|
| 92 |
+
self.drop_prob = drop_prob
|
| 93 |
+
self.scale_by_keep = scale_by_keep
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
if self.drop_prob == 0.0 or not self.training:
|
| 97 |
+
return x
|
| 98 |
+
keep_prob = 1 - self.drop_prob
|
| 99 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
| 100 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 101 |
+
if keep_prob > 0.0 and self.scale_by_keep:
|
| 102 |
+
random_tensor.div_(keep_prob)
|
| 103 |
+
return x * random_tensor
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Lightly adapted from
|
| 107 |
+
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
| 108 |
+
class MLP(nn.Module):
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
input_dim: int,
|
| 112 |
+
hidden_dim: int,
|
| 113 |
+
output_dim: int,
|
| 114 |
+
num_layers: int,
|
| 115 |
+
activation: nn.Module = nn.ReLU,
|
| 116 |
+
sigmoid_output: bool = False,
|
| 117 |
+
) -> None:
|
| 118 |
+
super().__init__()
|
| 119 |
+
self.num_layers = num_layers
|
| 120 |
+
h = [hidden_dim] * (num_layers - 1)
|
| 121 |
+
self.layers = nn.ModuleList(
|
| 122 |
+
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
| 123 |
+
)
|
| 124 |
+
self.sigmoid_output = sigmoid_output
|
| 125 |
+
self.act = activation()
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
for i, layer in enumerate(self.layers):
|
| 129 |
+
x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
|
| 130 |
+
if self.sigmoid_output:
|
| 131 |
+
x = F.sigmoid(x)
|
| 132 |
+
return x
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
| 136 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
| 137 |
+
class LayerNorm2d(nn.Module):
|
| 138 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
| 141 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
| 142 |
+
self.eps = eps
|
| 143 |
+
|
| 144 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 145 |
+
u = x.mean(1, keepdim=True)
|
| 146 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 147 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 148 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 149 |
+
return x
|
RynnEC/third_parts/sam2/sam2_configs/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_b+.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: third_parts.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 112
|
| 12 |
+
num_heads: 2
|
| 13 |
+
neck:
|
| 14 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
+
position_encoding:
|
| 16 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
+
num_pos_feats: 256
|
| 18 |
+
normalize: true
|
| 19 |
+
scale: null
|
| 20 |
+
temperature: 10000
|
| 21 |
+
d_model: 256
|
| 22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 24 |
+
fpn_interp_model: nearest
|
| 25 |
+
|
| 26 |
+
memory_attention:
|
| 27 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
+
d_model: 256
|
| 29 |
+
pos_enc_at_input: true
|
| 30 |
+
layer:
|
| 31 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
+
activation: relu
|
| 33 |
+
dim_feedforward: 2048
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
pos_enc_at_attn: false
|
| 36 |
+
self_attention:
|
| 37 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
+
rope_theta: 10000.0
|
| 39 |
+
feat_sizes: [32, 32]
|
| 40 |
+
embedding_dim: 256
|
| 41 |
+
num_heads: 1
|
| 42 |
+
downsample_rate: 1
|
| 43 |
+
dropout: 0.1
|
| 44 |
+
d_model: 256
|
| 45 |
+
pos_enc_at_cross_attn_keys: true
|
| 46 |
+
pos_enc_at_cross_attn_queries: false
|
| 47 |
+
cross_attention:
|
| 48 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
+
rope_theta: 10000.0
|
| 50 |
+
feat_sizes: [32, 32]
|
| 51 |
+
rope_k_repeat: True
|
| 52 |
+
embedding_dim: 256
|
| 53 |
+
num_heads: 1
|
| 54 |
+
downsample_rate: 1
|
| 55 |
+
dropout: 0.1
|
| 56 |
+
kv_in_dim: 64
|
| 57 |
+
num_layers: 4
|
| 58 |
+
|
| 59 |
+
memory_encoder:
|
| 60 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
+
out_dim: 64
|
| 62 |
+
position_encoding:
|
| 63 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
+
num_pos_feats: 64
|
| 65 |
+
normalize: true
|
| 66 |
+
scale: null
|
| 67 |
+
temperature: 10000
|
| 68 |
+
mask_downsampler:
|
| 69 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
+
kernel_size: 3
|
| 71 |
+
stride: 2
|
| 72 |
+
padding: 1
|
| 73 |
+
fuser:
|
| 74 |
+
_target_: third_parts.sam2.modeling.memory_encoder.Fuser
|
| 75 |
+
layer:
|
| 76 |
+
_target_: third_parts.sam2.modeling.memory_encoder.CXBlock
|
| 77 |
+
dim: 256
|
| 78 |
+
kernel_size: 7
|
| 79 |
+
padding: 3
|
| 80 |
+
layer_scale_init_value: 1e-6
|
| 81 |
+
use_dwconv: True # depth-wise convs
|
| 82 |
+
num_layers: 2
|
| 83 |
+
|
| 84 |
+
num_maskmem: 7
|
| 85 |
+
image_size: 1024
|
| 86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 89 |
+
use_mask_input_as_output_without_sam: true
|
| 90 |
+
# Memory
|
| 91 |
+
directly_add_no_mem_embed: true
|
| 92 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 93 |
+
use_high_res_features_in_sam: true
|
| 94 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 95 |
+
multimask_output_in_sam: true
|
| 96 |
+
# SAM heads
|
| 97 |
+
iou_prediction_use_sigmoid: True
|
| 98 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 99 |
+
use_obj_ptrs_in_encoder: true
|
| 100 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 101 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 102 |
+
# object occlusion prediction
|
| 103 |
+
pred_obj_scores: true
|
| 104 |
+
pred_obj_scores_mlp: true
|
| 105 |
+
fixed_no_obj_ptr: true
|
| 106 |
+
# multimask tracking settings
|
| 107 |
+
multimask_output_for_tracking: true
|
| 108 |
+
use_multimask_token_for_obj_ptr: true
|
| 109 |
+
multimask_min_pt_num: 0
|
| 110 |
+
multimask_max_pt_num: 1
|
| 111 |
+
use_mlp_for_obj_ptr_proj: true
|
| 112 |
+
# Compilation flag
|
| 113 |
+
compile_image_encoder: False
|
RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_l.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: third_parts.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 144
|
| 12 |
+
num_heads: 2
|
| 13 |
+
stages: [2, 6, 36, 4]
|
| 14 |
+
global_att_blocks: [23, 33, 43]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
window_spec: [8, 4, 16, 8]
|
| 17 |
+
neck:
|
| 18 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
+
position_encoding:
|
| 20 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
+
num_pos_feats: 256
|
| 22 |
+
normalize: true
|
| 23 |
+
scale: null
|
| 24 |
+
temperature: 10000
|
| 25 |
+
d_model: 256
|
| 26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
| 27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 28 |
+
fpn_interp_model: nearest
|
| 29 |
+
|
| 30 |
+
memory_attention:
|
| 31 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
+
d_model: 256
|
| 33 |
+
pos_enc_at_input: true
|
| 34 |
+
layer:
|
| 35 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
+
activation: relu
|
| 37 |
+
dim_feedforward: 2048
|
| 38 |
+
dropout: 0.1
|
| 39 |
+
pos_enc_at_attn: false
|
| 40 |
+
self_attention:
|
| 41 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
+
rope_theta: 10000.0
|
| 43 |
+
feat_sizes: [32, 32]
|
| 44 |
+
embedding_dim: 256
|
| 45 |
+
num_heads: 1
|
| 46 |
+
downsample_rate: 1
|
| 47 |
+
dropout: 0.1
|
| 48 |
+
d_model: 256
|
| 49 |
+
pos_enc_at_cross_attn_keys: true
|
| 50 |
+
pos_enc_at_cross_attn_queries: false
|
| 51 |
+
cross_attention:
|
| 52 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
+
rope_theta: 10000.0
|
| 54 |
+
feat_sizes: [32, 32]
|
| 55 |
+
rope_k_repeat: True
|
| 56 |
+
embedding_dim: 256
|
| 57 |
+
num_heads: 1
|
| 58 |
+
downsample_rate: 1
|
| 59 |
+
dropout: 0.1
|
| 60 |
+
kv_in_dim: 64
|
| 61 |
+
num_layers: 4
|
| 62 |
+
|
| 63 |
+
memory_encoder:
|
| 64 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
+
out_dim: 64
|
| 66 |
+
position_encoding:
|
| 67 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
+
num_pos_feats: 64
|
| 69 |
+
normalize: true
|
| 70 |
+
scale: null
|
| 71 |
+
temperature: 10000
|
| 72 |
+
mask_downsampler:
|
| 73 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
+
kernel_size: 3
|
| 75 |
+
stride: 2
|
| 76 |
+
padding: 1
|
| 77 |
+
fuser:
|
| 78 |
+
_target_: third_parts.sam2.modeling.memory_encoder.Fuser
|
| 79 |
+
layer:
|
| 80 |
+
_target_: third_parts.sam2.modeling.memory_encoder.CXBlock
|
| 81 |
+
dim: 256
|
| 82 |
+
kernel_size: 7
|
| 83 |
+
padding: 3
|
| 84 |
+
layer_scale_init_value: 1e-6
|
| 85 |
+
use_dwconv: True # depth-wise convs
|
| 86 |
+
num_layers: 2
|
| 87 |
+
|
| 88 |
+
num_maskmem: 7
|
| 89 |
+
image_size: 1024
|
| 90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
compile_image_encoder: False
|
RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_s.yaml
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: third_parts.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 11, 2]
|
| 14 |
+
global_att_blocks: [7, 10, 13]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [32, 32]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [32, 32]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: third_parts.sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: third_parts.sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 92 |
+
use_mask_input_as_output_without_sam: true
|
| 93 |
+
# Memory
|
| 94 |
+
directly_add_no_mem_embed: true
|
| 95 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 96 |
+
use_high_res_features_in_sam: true
|
| 97 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 98 |
+
multimask_output_in_sam: true
|
| 99 |
+
# SAM heads
|
| 100 |
+
iou_prediction_use_sigmoid: True
|
| 101 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 102 |
+
use_obj_ptrs_in_encoder: true
|
| 103 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 105 |
+
# object occlusion prediction
|
| 106 |
+
pred_obj_scores: true
|
| 107 |
+
pred_obj_scores_mlp: true
|
| 108 |
+
fixed_no_obj_ptr: true
|
| 109 |
+
# multimask tracking settings
|
| 110 |
+
multimask_output_for_tracking: true
|
| 111 |
+
use_multimask_token_for_obj_ptr: true
|
| 112 |
+
multimask_min_pt_num: 0
|
| 113 |
+
multimask_max_pt_num: 1
|
| 114 |
+
use_mlp_for_obj_ptr_proj: true
|
| 115 |
+
# Compilation flag
|
| 116 |
+
compile_image_encoder: False
|
RynnEC/third_parts/sam2/sam2_configs/sam2_hiera_t.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: third_parts.sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: third_parts.sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 7, 2]
|
| 14 |
+
global_att_blocks: [5, 7, 9]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: third_parts.sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: third_parts.sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [32, 32]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: third_parts.sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [32, 32]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: third_parts.sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: third_parts.sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: third_parts.sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: third_parts.sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
# SAM decoder
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
# HieraT does not currently support compilation, should always be set to False
|
| 118 |
+
compile_image_encoder: False
|
RynnEC/third_parts/sam2/sam2_image_predictor.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
from typing import List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from PIL.Image import Image
|
| 14 |
+
|
| 15 |
+
from third_parts.sam2.modeling.sam2_base import SAM2Base
|
| 16 |
+
|
| 17 |
+
from third_parts.sam2.utils.transforms import SAM2Transforms
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SAM2ImagePredictor:
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
sam_model: SAM2Base,
|
| 24 |
+
mask_threshold=0.0,
|
| 25 |
+
max_hole_area=0.0,
|
| 26 |
+
max_sprinkle_area=0.0,
|
| 27 |
+
) -> None:
|
| 28 |
+
"""
|
| 29 |
+
Uses SAM-2 to calculate the image embedding for an image, and then
|
| 30 |
+
allow repeated, efficient mask prediction given prompts.
|
| 31 |
+
|
| 32 |
+
Arguments:
|
| 33 |
+
sam_model (Sam-2): The model to use for mask prediction.
|
| 34 |
+
mask_threshold (float): The threshold to use when converting mask logits
|
| 35 |
+
to binary masks. Masks are thresholded at 0 by default.
|
| 36 |
+
fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
|
| 37 |
+
the maximum area of fill_hole_area in low_res_masks.
|
| 38 |
+
"""
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.model = sam_model
|
| 41 |
+
self._transforms = SAM2Transforms(
|
| 42 |
+
resolution=self.model.image_size,
|
| 43 |
+
mask_threshold=mask_threshold,
|
| 44 |
+
max_hole_area=max_hole_area,
|
| 45 |
+
max_sprinkle_area=max_sprinkle_area,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Predictor state
|
| 49 |
+
self._is_image_set = False
|
| 50 |
+
self._features = None
|
| 51 |
+
self._orig_hw = None
|
| 52 |
+
# Whether the predictor is set for single image or a batch of images
|
| 53 |
+
self._is_batch = False
|
| 54 |
+
|
| 55 |
+
# Predictor config
|
| 56 |
+
self.mask_threshold = mask_threshold
|
| 57 |
+
|
| 58 |
+
# Spatial dim for backbone feature maps
|
| 59 |
+
self._bb_feat_sizes = [
|
| 60 |
+
(256, 256),
|
| 61 |
+
(128, 128),
|
| 62 |
+
(64, 64),
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
@torch.no_grad()
|
| 66 |
+
def set_image(
|
| 67 |
+
self,
|
| 68 |
+
image: Union[np.ndarray, Image],
|
| 69 |
+
) -> None:
|
| 70 |
+
"""
|
| 71 |
+
Calculates the image embeddings for the provided image, allowing
|
| 72 |
+
masks to be predicted with the 'predict' method.
|
| 73 |
+
|
| 74 |
+
Arguments:
|
| 75 |
+
image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
|
| 76 |
+
with pixel values in [0, 255].
|
| 77 |
+
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
| 78 |
+
"""
|
| 79 |
+
self.reset_predictor()
|
| 80 |
+
# Transform the image to the form expected by the model
|
| 81 |
+
if isinstance(image, np.ndarray):
|
| 82 |
+
logging.info("For numpy array image, we assume (HxWxC) format")
|
| 83 |
+
self._orig_hw = [image.shape[:2]]
|
| 84 |
+
elif isinstance(image, Image):
|
| 85 |
+
w, h = image.size
|
| 86 |
+
self._orig_hw = [(h, w)]
|
| 87 |
+
else:
|
| 88 |
+
raise NotImplementedError("Image format not supported")
|
| 89 |
+
|
| 90 |
+
input_image = self._transforms(image)
|
| 91 |
+
input_image = input_image[None, ...].to(self.device)
|
| 92 |
+
|
| 93 |
+
assert (
|
| 94 |
+
len(input_image.shape) == 4 and input_image.shape[1] == 3
|
| 95 |
+
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
| 96 |
+
logging.info("Computing image embeddings for the provided image...")
|
| 97 |
+
backbone_out = self.model.forward_image(input_image)
|
| 98 |
+
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
| 99 |
+
# Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
|
| 100 |
+
if self.model.directly_add_no_mem_embed:
|
| 101 |
+
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
| 102 |
+
|
| 103 |
+
feats = [
|
| 104 |
+
feat.permute(1, 2, 0).view(1, -1, *feat_size)
|
| 105 |
+
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
| 106 |
+
][::-1]
|
| 107 |
+
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
| 108 |
+
self._is_image_set = True
|
| 109 |
+
logging.info("Image embeddings computed.")
|
| 110 |
+
|
| 111 |
+
@torch.no_grad()
|
| 112 |
+
def set_image_batch(
|
| 113 |
+
self,
|
| 114 |
+
image_list: List[Union[np.ndarray]],
|
| 115 |
+
) -> None:
|
| 116 |
+
"""
|
| 117 |
+
Calculates the image embeddings for the provided image batch, allowing
|
| 118 |
+
masks to be predicted with the 'predict_batch' method.
|
| 119 |
+
|
| 120 |
+
Arguments:
|
| 121 |
+
image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
|
| 122 |
+
with pixel values in [0, 255].
|
| 123 |
+
"""
|
| 124 |
+
self.reset_predictor()
|
| 125 |
+
assert isinstance(image_list, list)
|
| 126 |
+
self._orig_hw = []
|
| 127 |
+
for image in image_list:
|
| 128 |
+
assert isinstance(
|
| 129 |
+
image, np.ndarray
|
| 130 |
+
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
| 131 |
+
self._orig_hw.append(image.shape[:2])
|
| 132 |
+
# Transform the image to the form expected by the model
|
| 133 |
+
img_batch = self._transforms.forward_batch(image_list)
|
| 134 |
+
img_batch = img_batch.to(self.device)
|
| 135 |
+
batch_size = img_batch.shape[0]
|
| 136 |
+
assert (
|
| 137 |
+
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
|
| 138 |
+
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
| 139 |
+
logging.info("Computing image embeddings for the provided images...")
|
| 140 |
+
backbone_out = self.model.forward_image(img_batch)
|
| 141 |
+
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
| 142 |
+
# Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
|
| 143 |
+
if self.model.directly_add_no_mem_embed:
|
| 144 |
+
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
| 145 |
+
|
| 146 |
+
feats = [
|
| 147 |
+
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
|
| 148 |
+
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
| 149 |
+
][::-1]
|
| 150 |
+
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
| 151 |
+
self._is_image_set = True
|
| 152 |
+
self._is_batch = True
|
| 153 |
+
logging.info("Image embeddings computed.")
|
| 154 |
+
|
| 155 |
+
def predict_batch(
|
| 156 |
+
self,
|
| 157 |
+
point_coords_batch: List[np.ndarray] = None,
|
| 158 |
+
point_labels_batch: List[np.ndarray] = None,
|
| 159 |
+
box_batch: List[np.ndarray] = None,
|
| 160 |
+
mask_input_batch: List[np.ndarray] = None,
|
| 161 |
+
multimask_output: bool = True,
|
| 162 |
+
return_logits: bool = False,
|
| 163 |
+
normalize_coords=True,
|
| 164 |
+
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
|
| 165 |
+
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
|
| 166 |
+
It returns a tupele of lists of masks, ious, and low_res_masks_logits.
|
| 167 |
+
"""
|
| 168 |
+
assert self._is_batch, "This function should only be used when in batched mode"
|
| 169 |
+
if not self._is_image_set:
|
| 170 |
+
raise RuntimeError(
|
| 171 |
+
"An image must be set with .set_image_batch(...) before mask prediction."
|
| 172 |
+
)
|
| 173 |
+
num_images = len(self._features["image_embed"])
|
| 174 |
+
all_masks = []
|
| 175 |
+
all_ious = []
|
| 176 |
+
all_low_res_masks = []
|
| 177 |
+
for img_idx in range(num_images):
|
| 178 |
+
# Transform input prompts
|
| 179 |
+
point_coords = (
|
| 180 |
+
point_coords_batch[img_idx] if point_coords_batch is not None else None
|
| 181 |
+
)
|
| 182 |
+
point_labels = (
|
| 183 |
+
point_labels_batch[img_idx] if point_labels_batch is not None else None
|
| 184 |
+
)
|
| 185 |
+
box = box_batch[img_idx] if box_batch is not None else None
|
| 186 |
+
mask_input = (
|
| 187 |
+
mask_input_batch[img_idx] if mask_input_batch is not None else None
|
| 188 |
+
)
|
| 189 |
+
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
|
| 190 |
+
point_coords,
|
| 191 |
+
point_labels,
|
| 192 |
+
box,
|
| 193 |
+
mask_input,
|
| 194 |
+
normalize_coords,
|
| 195 |
+
img_idx=img_idx,
|
| 196 |
+
)
|
| 197 |
+
masks, iou_predictions, low_res_masks = self._predict(
|
| 198 |
+
unnorm_coords,
|
| 199 |
+
labels,
|
| 200 |
+
unnorm_box,
|
| 201 |
+
mask_input,
|
| 202 |
+
multimask_output,
|
| 203 |
+
return_logits=return_logits,
|
| 204 |
+
img_idx=img_idx,
|
| 205 |
+
)
|
| 206 |
+
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
|
| 207 |
+
iou_predictions_np = (
|
| 208 |
+
iou_predictions.squeeze(0).float().detach().cpu().numpy()
|
| 209 |
+
)
|
| 210 |
+
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
| 211 |
+
all_masks.append(masks_np)
|
| 212 |
+
all_ious.append(iou_predictions_np)
|
| 213 |
+
all_low_res_masks.append(low_res_masks_np)
|
| 214 |
+
|
| 215 |
+
return all_masks, all_ious, all_low_res_masks
|
| 216 |
+
|
| 217 |
+
def predict(
|
| 218 |
+
self,
|
| 219 |
+
point_coords: Optional[np.ndarray] = None,
|
| 220 |
+
point_labels: Optional[np.ndarray] = None,
|
| 221 |
+
box: Optional[np.ndarray] = None,
|
| 222 |
+
mask_input: Optional[np.ndarray] = None,
|
| 223 |
+
multimask_output: bool = True,
|
| 224 |
+
return_logits: bool = False,
|
| 225 |
+
normalize_coords=True,
|
| 226 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 227 |
+
"""
|
| 228 |
+
Predict masks for the given input prompts, using the currently set image.
|
| 229 |
+
|
| 230 |
+
Arguments:
|
| 231 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
| 232 |
+
model. Each point is in (X,Y) in pixels.
|
| 233 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
| 234 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
| 235 |
+
background point.
|
| 236 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
| 237 |
+
model, in XYXY format.
|
| 238 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
| 239 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
| 240 |
+
for SAM, H=W=256.
|
| 241 |
+
multimask_output (bool): If true, the model will return three masks.
|
| 242 |
+
For ambiguous input prompts (such as a single click), this will often
|
| 243 |
+
produce better masks than a single prediction. If only a single
|
| 244 |
+
mask is needed, the model's predicted quality score can be used
|
| 245 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
| 246 |
+
input prompts, multimask_output=False can give better results.
|
| 247 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
| 248 |
+
instead of a binary mask.
|
| 249 |
+
normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
| 253 |
+
number of masks, and (H, W) is the original image size.
|
| 254 |
+
(np.ndarray): An array of length C containing the model's
|
| 255 |
+
predictions for the quality of each mask.
|
| 256 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
| 257 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
| 258 |
+
a subsequent iteration as mask input.
|
| 259 |
+
"""
|
| 260 |
+
if not self._is_image_set:
|
| 261 |
+
raise RuntimeError(
|
| 262 |
+
"An image must be set with .set_image(...) before mask prediction."
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Transform input prompts
|
| 266 |
+
|
| 267 |
+
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
|
| 268 |
+
point_coords, point_labels, box, mask_input, normalize_coords
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
masks, iou_predictions, low_res_masks = self._predict(
|
| 272 |
+
unnorm_coords,
|
| 273 |
+
labels,
|
| 274 |
+
unnorm_box,
|
| 275 |
+
mask_input,
|
| 276 |
+
multimask_output,
|
| 277 |
+
return_logits=return_logits,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
|
| 281 |
+
iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
|
| 282 |
+
low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
|
| 283 |
+
return masks_np, iou_predictions_np, low_res_masks_np
|
| 284 |
+
|
| 285 |
+
def _prep_prompts(
|
| 286 |
+
self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
|
| 287 |
+
):
|
| 288 |
+
|
| 289 |
+
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
| 290 |
+
if point_coords is not None:
|
| 291 |
+
assert (
|
| 292 |
+
point_labels is not None
|
| 293 |
+
), "point_labels must be supplied if point_coords is supplied."
|
| 294 |
+
point_coords = torch.as_tensor(
|
| 295 |
+
point_coords, dtype=torch.float, device=self.device
|
| 296 |
+
)
|
| 297 |
+
unnorm_coords = self._transforms.transform_coords(
|
| 298 |
+
point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
|
| 299 |
+
)
|
| 300 |
+
labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
| 301 |
+
if len(unnorm_coords.shape) == 2:
|
| 302 |
+
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
|
| 303 |
+
if box is not None:
|
| 304 |
+
box = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
| 305 |
+
unnorm_box = self._transforms.transform_boxes(
|
| 306 |
+
box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
|
| 307 |
+
) # Bx2x2
|
| 308 |
+
if mask_logits is not None:
|
| 309 |
+
mask_input = torch.as_tensor(
|
| 310 |
+
mask_logits, dtype=torch.float, device=self.device
|
| 311 |
+
)
|
| 312 |
+
if len(mask_input.shape) == 3:
|
| 313 |
+
mask_input = mask_input[None, :, :, :]
|
| 314 |
+
return mask_input, unnorm_coords, labels, unnorm_box
|
| 315 |
+
|
| 316 |
+
@torch.no_grad()
|
| 317 |
+
def _predict(
|
| 318 |
+
self,
|
| 319 |
+
point_coords: Optional[torch.Tensor],
|
| 320 |
+
point_labels: Optional[torch.Tensor],
|
| 321 |
+
boxes: Optional[torch.Tensor] = None,
|
| 322 |
+
mask_input: Optional[torch.Tensor] = None,
|
| 323 |
+
multimask_output: bool = True,
|
| 324 |
+
return_logits: bool = False,
|
| 325 |
+
img_idx: int = -1,
|
| 326 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 327 |
+
"""
|
| 328 |
+
Predict masks for the given input prompts, using the currently set image.
|
| 329 |
+
Input prompts are batched torch tensors and are expected to already be
|
| 330 |
+
transformed to the input frame using SAM2Transforms.
|
| 331 |
+
|
| 332 |
+
Arguments:
|
| 333 |
+
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
|
| 334 |
+
model. Each point is in (X,Y) in pixels.
|
| 335 |
+
point_labels (torch.Tensor or None): A BxN array of labels for the
|
| 336 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
| 337 |
+
background point.
|
| 338 |
+
boxes (np.ndarray or None): A Bx4 array given a box prompt to the
|
| 339 |
+
model, in XYXY format.
|
| 340 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
| 341 |
+
coming from a previous prediction iteration. Has form Bx1xHxW, where
|
| 342 |
+
for SAM, H=W=256. Masks returned by a previous iteration of the
|
| 343 |
+
predict method do not need further transformation.
|
| 344 |
+
multimask_output (bool): If true, the model will return three masks.
|
| 345 |
+
For ambiguous input prompts (such as a single click), this will often
|
| 346 |
+
produce better masks than a single prediction. If only a single
|
| 347 |
+
mask is needed, the model's predicted quality score can be used
|
| 348 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
| 349 |
+
input prompts, multimask_output=False can give better results.
|
| 350 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
| 351 |
+
instead of a binary mask.
|
| 352 |
+
|
| 353 |
+
Returns:
|
| 354 |
+
(torch.Tensor): The output masks in BxCxHxW format, where C is the
|
| 355 |
+
number of masks, and (H, W) is the original image size.
|
| 356 |
+
(torch.Tensor): An array of shape BxC containing the model's
|
| 357 |
+
predictions for the quality of each mask.
|
| 358 |
+
(torch.Tensor): An array of shape BxCxHxW, where C is the number
|
| 359 |
+
of masks and H=W=256. These low res logits can be passed to
|
| 360 |
+
a subsequent iteration as mask input.
|
| 361 |
+
"""
|
| 362 |
+
if not self._is_image_set:
|
| 363 |
+
raise RuntimeError(
|
| 364 |
+
"An image must be set with .set_image(...) before mask prediction."
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
if point_coords is not None:
|
| 368 |
+
concat_points = (point_coords, point_labels)
|
| 369 |
+
else:
|
| 370 |
+
concat_points = None
|
| 371 |
+
|
| 372 |
+
# Embed prompts
|
| 373 |
+
if boxes is not None:
|
| 374 |
+
box_coords = boxes.reshape(-1, 2, 2)
|
| 375 |
+
box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
|
| 376 |
+
box_labels = box_labels.repeat(boxes.size(0), 1)
|
| 377 |
+
# we merge "boxes" and "points" into a single "concat_points" input (where
|
| 378 |
+
# boxes are added at the beginning) to sam_prompt_encoder
|
| 379 |
+
if concat_points is not None:
|
| 380 |
+
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
|
| 381 |
+
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
|
| 382 |
+
concat_points = (concat_coords, concat_labels)
|
| 383 |
+
else:
|
| 384 |
+
concat_points = (box_coords, box_labels)
|
| 385 |
+
|
| 386 |
+
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
| 387 |
+
points=concat_points,
|
| 388 |
+
boxes=None,
|
| 389 |
+
masks=mask_input,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
# Predict masks
|
| 393 |
+
batched_mode = (
|
| 394 |
+
concat_points is not None and concat_points[0].shape[0] > 1
|
| 395 |
+
) # multi object prediction
|
| 396 |
+
high_res_features = [
|
| 397 |
+
feat_level[img_idx].unsqueeze(0)
|
| 398 |
+
for feat_level in self._features["high_res_feats"]
|
| 399 |
+
]
|
| 400 |
+
low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
|
| 401 |
+
image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
|
| 402 |
+
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
|
| 403 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 404 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 405 |
+
multimask_output=multimask_output,
|
| 406 |
+
repeat_image=batched_mode,
|
| 407 |
+
high_res_features=high_res_features,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# Upscale the masks to the original image resolution
|
| 411 |
+
masks = self._transforms.postprocess_masks(
|
| 412 |
+
low_res_masks, self._orig_hw[img_idx]
|
| 413 |
+
)
|
| 414 |
+
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
|
| 415 |
+
if not return_logits:
|
| 416 |
+
masks = masks > self.mask_threshold
|
| 417 |
+
|
| 418 |
+
return masks, iou_predictions, low_res_masks
|
| 419 |
+
|
| 420 |
+
def get_image_embedding(self) -> torch.Tensor:
|
| 421 |
+
"""
|
| 422 |
+
Returns the image embeddings for the currently set image, with
|
| 423 |
+
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
|
| 424 |
+
the embedding spatial dimension of SAM (typically C=256, H=W=64).
|
| 425 |
+
"""
|
| 426 |
+
if not self._is_image_set:
|
| 427 |
+
raise RuntimeError(
|
| 428 |
+
"An image must be set with .set_image(...) to generate an embedding."
|
| 429 |
+
)
|
| 430 |
+
assert (
|
| 431 |
+
self._features is not None
|
| 432 |
+
), "Features must exist if an image has been set."
|
| 433 |
+
return self._features["image_embed"]
|
| 434 |
+
|
| 435 |
+
@property
|
| 436 |
+
def device(self) -> torch.device:
|
| 437 |
+
return self.model.device
|
| 438 |
+
|
| 439 |
+
def reset_predictor(self) -> None:
|
| 440 |
+
"""
|
| 441 |
+
Resets the image embeddings and other state variables.
|
| 442 |
+
"""
|
| 443 |
+
self._is_image_set = False
|
| 444 |
+
self._features = None
|
| 445 |
+
self._orig_hw = None
|
| 446 |
+
self._is_batch = False
|