"""DiffusionVL model implementation.""" |
|
|
|
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Callable, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers import PreTrainedModel |
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput |
|
|
from transformers.utils import logging |
|
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
|
|
from transformers.modeling_layers import GradientCheckpointingLayer |
|
|
|
|
|
from .configuration_diffusionvl_qwen2_5_vl import DiffusionVL_Qwen2_5_VL_Config, DiffusionVL_Qwen2_5_VL_VisionConfig |
|
|
|
|
|
IMAGE_TOKEN_INDEX = -200 |
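# LLaVA-style placeholder: prompts mark image positions with this negative id, and the
# corresponding slots are replaced by projected vision features before the LLM runs.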
|
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Rotates half the hidden dims of the input for rotary position embedding. |
|
|
|
|
|
Args: |
|
|
x: Input tensor of shape (..., head_dim). |
|
|
|
|
|
Returns: |
|
|
Rotated tensor of the same shape. |
|
|
""" |
|
|
x1 = x[..., : x.shape[-1] // 2] |
|
|
x2 = x[..., x.shape[-1] // 2 :] |
|
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb_vision( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
cos: torch.Tensor, |
|
|
sin: torch.Tensor, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Apply rotary position embedding for vision encoder. |
|
|
|
|
|
Args: |
|
|
q: Query tensor. |
|
|
k: Key tensor. |
|
|
cos: Cosine part of rotary embedding. |
|
|
sin: Sine part of rotary embedding. |
|
|
|
|
|
Returns: |
|
|
Tuple of (rotated_q, rotated_k). |
|
|
""" |
|
|
orig_q_dtype = q.dtype |
|
|
orig_k_dtype = k.dtype |
|
|
q, k = q.float(), k.float() |
|
|
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() |
|
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
|
return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype) |
|
|
|
|
|
|
|
|
def apply_multimodal_rotary_pos_emb( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
cos: torch.Tensor, |
|
|
sin: torch.Tensor, |
|
|
mrope_section: List[int], |
|
|
unsqueeze_dim: int = 1, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Apply multimodal rotary position embedding (M-RoPE) for 3D position encoding. |
|
|
|
|
|
Args: |
|
|
q: Query tensor of shape (batch, heads, seq_len, head_dim). |
|
|
k: Key tensor of shape (batch, heads, seq_len, head_dim). |
|
|
cos: Cosine tensor of shape (3, batch, seq_len, head_dim). |
|
|
sin: Sine tensor of shape (3, batch, seq_len, head_dim). |
|
|
mrope_section: List of 3 ints defining section sizes [temporal, height, width]. |
|
|
For example, [16, 24, 24] for head_dim=128. |
|
|
unsqueeze_dim: Dimension to unsqueeze for broadcasting. |
|
|
|
|
|
Returns: |
|
|
Tuple of (rotated_q, rotated_k) with M-RoPE applied. |
|
|
""" |
|
|
|
|
|
|
|
|
mrope_section = mrope_section * 2 |
|
|
|
|
|
|
|
|
|
|
|
cos = torch.cat( |
|
|
[m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1 |
|
|
).unsqueeze(unsqueeze_dim) |
|
|
sin = torch.cat( |
|
|
[m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1 |
|
|
).unsqueeze(unsqueeze_dim) |
|
|
|
|
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
|
return q_embed, k_embed |
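# Illustration (not executed): with head_dim=128 and mrope_section=[16, 24, 24], the doubled
# list [16, 24, 24, 16, 24, 24] splits the 128 rotary channels into six chunks; chunks 0 and 3
# take their angles from the temporal row of the (3, batch, seq, head_dim) cos/sin tensors,
# chunks 1 and 4 from the height row, and chunks 2 and 5 from the width row (i % 3 above).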
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_RMSNorm(nn.Module): |
|
|
def __init__(self, hidden_size, eps=1e-6): |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.variance_epsilon = eps |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
input_dtype = hidden_states.dtype |
|
|
hidden_states = hidden_states.to(torch.float32) |
|
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
return self.weight * hidden_states.to(input_dtype) |
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
""" |
|
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
|
""" |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
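# Example: with num_key_value_heads=4 and n_rep=7 (28 query heads), a (B, 4, S, D) key/value
# tensor becomes (B, 28, S, D); each KV head is shared by its group of 7 query heads (GQA).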
|
|
|
|
|
|
|
|
def eager_attention_forward( |
|
|
module: nn.Module, |
|
|
query: torch.Tensor, |
|
|
key: torch.Tensor, |
|
|
value: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
scaling: float, |
|
|
dropout: float = 0.0, |
|
|
**kwargs, |
|
|
): |
|
|
"""Eager attention implementation.""" |
|
|
key_states = repeat_kv(key, module.num_key_value_groups) |
|
|
value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
|
|
|
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
|
|
if attention_mask is not None: |
|
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
|
|
attn_weights = attn_weights + causal_mask |
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
|
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
|
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
|
|
|
return attn_output, attn_weights |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionMLP(nn.Module): |
|
|
def __init__(self, config, bias: bool = False): |
|
|
super().__init__() |
|
|
self.hidden_size = config.hidden_size |
|
|
self.intermediate_size = config.intermediate_size |
|
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) |
|
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) |
|
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) |
|
|
self.act_fn = ACT2FN[config.hidden_act] |
|
|
|
|
|
def forward(self, hidden_state): |
|
|
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionPatchEmbed(nn.Module): |
|
|
def __init__(self, patch_size=14, temporal_patch_size=2, in_channels=3, embed_dim=1152): |
|
|
super().__init__() |
|
|
self.patch_size = patch_size |
|
|
self.temporal_patch_size = temporal_patch_size |
|
|
self.in_channels = in_channels |
|
|
self.embed_dim = embed_dim |
|
|
kernel_size = [temporal_patch_size, patch_size, patch_size] |
|
|
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
target_dtype = self.proj.weight.dtype |
|
|
hidden_states = hidden_states.view( |
|
|
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size |
|
|
) |
|
|
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionRotaryEmbedding(nn.Module): |
|
|
inv_freq: torch.Tensor |
|
|
|
|
|
def __init__(self, dim: int, theta: float = 10000.0): |
|
|
super().__init__() |
|
|
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) |
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
|
|
def forward(self, seqlen: int) -> torch.Tensor: |
|
|
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) |
|
|
freqs = torch.outer(seq, self.inv_freq) |
|
|
return freqs |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionPatchMerger(nn.Module): |
|
|
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2): |
|
|
super().__init__() |
|
|
self.hidden_size = context_dim * (spatial_merge_size ** 2) |
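        # Each merged token aggregates a spatial_merge_size x spatial_merge_size group of patch
        # features, so the MLP input width is context_dim * merge_size**2 (4x for the default 2x2).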
|
|
self.ln_q = DiffusionVL_Qwen2_5_VL_RMSNorm(context_dim, eps=1e-6) |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(self.hidden_size, self.hidden_size), |
|
|
nn.GELU(), |
|
|
nn.Linear(self.hidden_size, dim), |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) |
|
|
return x |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionAttention(nn.Module): |
|
|
def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig) -> None: |
|
|
super().__init__() |
|
|
self.dim = config.hidden_size |
|
|
self.num_heads = config.num_heads |
|
|
self.head_dim = self.dim // self.num_heads |
|
|
self.num_key_value_groups = 1 |
|
|
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) |
|
|
self.proj = nn.Linear(self.dim, self.dim) |
|
|
self.scaling = self.head_dim**-0.5 |
|
|
self.config = config |
|
|
self.attention_dropout = 0.0 |
|
|
self.is_causal = False |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
cu_seqlens: torch.Tensor, |
|
|
rotary_pos_emb: Optional[torch.Tensor] = None, |
|
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
**kwargs, |
|
|
) -> torch.Tensor: |
|
|
seq_length = hidden_states.shape[0] |
|
|
query_states, key_states, value_states = ( |
|
|
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) |
|
|
) |
|
|
        if position_embeddings is None:
            # Fall back to the raw rotary angle table when precomputed (cos, sin) are not provided.
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos, sin = emb.cos(), emb.sin()
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
|
|
|
|
query_states = query_states.transpose(0, 1).unsqueeze(0) |
|
|
key_states = key_states.transpose(0, 1).unsqueeze(0) |
|
|
value_states = value_states.transpose(0, 1).unsqueeze(0) |
|
|
|
|
|
attention_interface: Callable = eager_attention_forward |
|
|
if getattr(self.config, "_attn_implementation", "eager") != "eager": |
|
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
|
|
|
|
|
if getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2": |
|
|
|
|
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() |
|
|
attn_output, _ = attention_interface( |
|
|
self, |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask=None, |
|
|
scaling=self.scaling, |
|
|
dropout=0.0 if not self.training else self.attention_dropout, |
|
|
cu_seq_lens_q=cu_seqlens, |
|
|
cu_seq_lens_k=cu_seqlens, |
|
|
max_length_q=max_seqlen, |
|
|
max_length_k=max_seqlen, |
|
|
is_causal=False, |
|
|
**kwargs, |
|
|
) |
|
|
else: |
|
|
|
|
|
lengths = cu_seqlens[1:] - cu_seqlens[:-1] |
|
|
splits = [ |
|
|
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) |
|
|
] |
|
|
|
|
|
attn_outputs = [ |
|
|
attention_interface( |
|
|
self, |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
attention_mask=None, |
|
|
scaling=self.scaling, |
|
|
dropout=0.0 if not self.training else self.attention_dropout, |
|
|
is_causal=False, |
|
|
**kwargs, |
|
|
)[0] |
|
|
for q, k, v in zip(*splits) |
|
|
] |
|
|
attn_output = torch.cat(attn_outputs, dim=1) |
|
|
|
|
|
attn_output = attn_output.reshape(seq_length, -1).contiguous() |
|
|
attn_output = self.proj(attn_output) |
|
|
return attn_output |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionBlock(GradientCheckpointingLayer): |
|
|
def __init__(self, config, attn_implementation: str = "sdpa") -> None: |
|
|
super().__init__() |
|
|
self.norm1 = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=1e-6) |
|
|
self.norm2 = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=1e-6) |
|
|
self.attn = DiffusionVL_Qwen2_5_VL_VisionAttention(config=config) |
|
|
self.mlp = DiffusionVL_Qwen2_5_VL_VisionMLP(config, bias=True) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
cu_seqlens: torch.Tensor, |
|
|
rotary_pos_emb: Optional[torch.Tensor] = None, |
|
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
**kwargs, |
|
|
) -> torch.Tensor: |
|
|
hidden_states = hidden_states + self.attn( |
|
|
self.norm1(hidden_states), |
|
|
cu_seqlens=cu_seqlens, |
|
|
rotary_pos_emb=rotary_pos_emb, |
|
|
position_embeddings=position_embeddings, |
|
|
**kwargs, |
|
|
) |
|
|
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionPreTrainedModel(PreTrainedModel): |
|
|
config_class = DiffusionVL_Qwen2_5_VL_VisionConfig |
|
|
base_model_prefix = "model" |
|
|
supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["DiffusionVL_Qwen2_5_VL_VisionBlock"] |
|
|
_supports_flash_attn_2 = True |
|
|
_supports_sdpa = True |
|
|
_supports_attention_backend = True |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionTransformer(DiffusionVL_Qwen2_5_VL_VisionPreTrainedModel): |
|
|
config_class = DiffusionVL_Qwen2_5_VL_VisionConfig |
|
|
_no_split_modules = ["DiffusionVL_Qwen2_5_VL_VisionBlock"] |
|
|
|
|
|
def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig, *inputs, **kwargs) -> None: |
|
|
super().__init__(config, *inputs, **kwargs) |
|
|
self.spatial_merge_size = config.spatial_merge_size |
|
|
self.patch_size = config.patch_size |
|
|
self.fullatt_block_indexes = config.fullatt_block_indexes |
|
|
self.window_size = config.window_size |
|
|
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size |
|
|
|
|
|
self.patch_embed = DiffusionVL_Qwen2_5_VL_VisionPatchEmbed( |
|
|
patch_size=config.patch_size, |
|
|
temporal_patch_size=config.temporal_patch_size, |
|
|
in_channels=config.in_channels, |
|
|
embed_dim=config.hidden_size, |
|
|
) |
|
|
|
|
|
head_dim = config.hidden_size // config.num_heads |
|
|
self.rotary_pos_emb = DiffusionVL_Qwen2_5_VL_VisionRotaryEmbedding(head_dim // 2) |
|
|
|
|
|
self.blocks = nn.ModuleList([DiffusionVL_Qwen2_5_VL_VisionBlock(config) for _ in range(config.depth)]) |
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: |
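        """Build (height, width) rotary angle tables for every patch described by grid_thw.

        Patches are ordered in the same spatial_merge_size-grouped layout used by the patch
        merger; the returned tensor has shape (num_patches, head_dim // 2).
        """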
|
|
|
|
|
pos_ids = [] |
|
|
for t, h, w in grid_thw: |
|
|
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) |
|
|
hpos_ids = hpos_ids.reshape( |
|
|
h // self.spatial_merge_size, |
|
|
self.spatial_merge_size, |
|
|
w // self.spatial_merge_size, |
|
|
self.spatial_merge_size, |
|
|
) |
|
|
hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten() |
|
|
|
|
|
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) |
|
|
wpos_ids = wpos_ids.reshape( |
|
|
h // self.spatial_merge_size, |
|
|
self.spatial_merge_size, |
|
|
w // self.spatial_merge_size, |
|
|
self.spatial_merge_size, |
|
|
) |
|
|
wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten() |
|
|
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) |
|
|
pos_ids = torch.cat(pos_ids, dim=0) |
|
|
max_grid_size = grid_thw[:, 1:].max() |
|
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) |
|
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) |
|
|
return rotary_pos_emb |
|
|
|
|
|
def get_window_index(self, grid_thw: torch.Tensor): |
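        """Compute the permutation that groups merged vision tokens into local windows.

        Returns:
            window_index: permutation over merged (spatial_merge_size**2-grouped) token indices
                that makes each attention window contiguous.
            cu_window_seqlens: cumulative window lengths counted in pre-merge patch units,
                used as varlen boundaries for windowed attention.
        """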
|
|
|
|
|
window_index: list = [] |
|
|
cu_window_seqlens: list = [0] |
|
|
window_index_id = 0 |
|
|
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size |
|
|
|
|
|
for grid_t, grid_h, grid_w in grid_thw: |
|
|
llm_grid_h = grid_h // self.spatial_merge_size |
|
|
llm_grid_w = grid_w // self.spatial_merge_size |
|
|
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) |
|
|
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size |
|
|
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size |
|
|
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size |
|
|
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size |
|
|
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) |
|
|
index_padded = index_padded.reshape( |
|
|
grid_t, |
|
|
num_windows_h, |
|
|
vit_merger_window_size, |
|
|
num_windows_w, |
|
|
vit_merger_window_size, |
|
|
) |
|
|
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( |
|
|
grid_t, |
|
|
num_windows_h * num_windows_w, |
|
|
vit_merger_window_size, |
|
|
vit_merger_window_size, |
|
|
) |
|
|
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) |
|
|
index_padded = index_padded.reshape(-1) |
|
|
index_new = index_padded[index_padded != -100] |
|
|
window_index.append(index_new + window_index_id) |
|
|
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] |
|
|
cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) |
|
|
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() |
|
|
window_index = torch.cat(window_index, dim=0) |
|
|
return window_index, cu_window_seqlens |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs): |
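        """Encode flattened image patches.

        Args:
            hidden_states: flattened patches of shape
                (num_patches, in_channels * temporal_patch_size * patch_size**2).
            grid_thw: (num_images, 3) tensor of (temporal, height, width) patch counts.

        Returns:
            Tuple of (patch features reordered by window_index, window_index); the MM projector
            undoes this reordering after merging.
        """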
|
|
|
|
|
hidden_states = self.patch_embed(hidden_states) |
|
|
rotary_pos_emb = self.rot_pos_emb(grid_thw) |
|
|
window_index, cu_window_seqlens = self.get_window_index(grid_thw) |
|
|
cu_window_seqlens = torch.tensor( |
|
|
cu_window_seqlens, |
|
|
device=hidden_states.device, |
|
|
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, |
|
|
) |
|
|
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) |
|
|
|
|
|
seq_len, _ = hidden_states.size() |
|
|
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) |
|
|
hidden_states = hidden_states[window_index, :, :] |
|
|
hidden_states = hidden_states.reshape(seq_len, -1) |
|
|
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) |
|
|
rotary_pos_emb = rotary_pos_emb[window_index, :, :] |
|
|
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) |
|
|
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
|
|
position_embeddings = (emb.cos(), emb.sin()) |
|
|
|
|
|
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # int32 keeps varlen flash-attention kernels happy; fall back to grid_thw's dtype under torch.jit tracing.
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
|
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
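        # Blocks listed in fullatt_block_indexes attend across the full image (cu_seqlens);
        # all other blocks use windowed attention bounded by cu_window_seqlens.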
|
|
|
|
|
for layer_num, blk in enumerate(self.blocks): |
|
|
if layer_num in self.fullatt_block_indexes: |
|
|
cu_seqlens_now = cu_seqlens |
|
|
else: |
|
|
cu_seqlens_now = cu_window_seqlens |
|
|
|
|
|
hidden_states = blk( |
|
|
hidden_states, |
|
|
cu_seqlens=cu_seqlens_now, |
|
|
position_embeddings=position_embeddings, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
|
|
|
return hidden_states, window_index |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionTower(nn.Module): |
|
|
|
|
|
def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig): |
|
|
super().__init__() |
|
|
self.vision_tower = DiffusionVL_Qwen2_5_VL_VisionTransformer(config) |
|
|
self.spatial_merge_size = config.spatial_merge_size |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor = None): |
|
|
"""Returns (hidden_states, window_index) tuple for MMProjector.""" |
|
|
return self.vision_tower(hidden_states, grid_thw) |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_MMProjector(nn.Module): |
|
|
|
|
|
def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig): |
|
|
super().__init__() |
|
|
self.merger = DiffusionVL_Qwen2_5_VL_VisionPatchMerger( |
|
|
dim=config.out_hidden_size, |
|
|
context_dim=config.hidden_size, |
|
|
spatial_merge_size=config.spatial_merge_size, |
|
|
) |
|
|
|
|
|
def forward(self, features_tuple): |
|
|
"""Forward pass with merger and window index reversal.""" |
|
|
if isinstance(features_tuple, tuple): |
|
|
hidden_states, window_index = features_tuple |
|
|
|
|
|
projected_features = self.merger(hidden_states) |
|
|
|
|
|
reverse_indices = torch.argsort(window_index) |
|
|
final_features = projected_features[reverse_indices, :] |
|
|
return final_features |
|
|
else: |
|
|
|
|
|
return self.merger(features_tuple) |
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_RotaryEmbedding(nn.Module): |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
dim = config.hidden_size // config.num_attention_heads |
|
|
inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) |
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
|
|
def forward(self, x, position_ids): |
|
|
""" |
|
|
Args: |
|
|
x: Input tensor for dtype reference |
|
|
position_ids: Position IDs with shape (3, batch_size, seq_length) for M-RoPE |
|
|
or (batch_size, seq_length) for standard RoPE (will be converted to 3D) |
|
|
|
|
|
Returns: |
|
|
cos, sin: Tensors of shape (3, batch, seq_len, head_dim) for M-RoPE |
|
|
""" |
|
|
|
|
|
if position_ids.ndim == 2: |
|
|
|
|
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) |
|
|
|
|
|
|
|
|
if position_ids.ndim == 3 and position_ids.shape[0] == 3: |
|
|
|
|
|
|
|
|
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand( |
|
|
3, position_ids.shape[1], -1, 1 |
|
|
) |
|
|
|
|
|
position_ids_expanded = position_ids[:, :, None, :].float() |
|
|
|
|
|
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
|
|
with torch.autocast(device_type=device_type, enabled=False): |
|
|
|
|
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) |
|
|
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
cos = emb.cos() |
|
|
sin = emb.sin() |
|
|
|
|
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
else: |
|
|
|
|
|
inv_freq_expanded = self.inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1) |
|
|
position_ids_expanded = position_ids[:, None, :].float() |
|
|
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) |
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
cos = emb.cos() |
|
|
sin = emb.sin() |
|
|
return cos.to(x.dtype), sin.to(x.dtype) |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_MLP(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.hidden_size = config.hidden_size |
|
|
self.intermediate_size = config.intermediate_size |
|
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
|
self.act_fn = nn.SiLU() |
|
|
|
|
|
def forward(self, x): |
|
|
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_Attention(nn.Module): |
|
|
"""Non-causal attention for diffusion-based generation with KV-cache support.""" |
|
|
|
|
|
def __init__(self, config, layer_idx): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layer_idx = layer_idx |
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.head_dim = self.hidden_size // self.num_heads |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.scaling = self.head_dim ** -0.5 |
|
|
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) |
|
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) |
|
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) |
|
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
|
|
|
|
|
|
self.is_causal = False |
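        # Diffusion decoding is bidirectional within a block, so attention is never causal.
        # store_kv (see forward) decides whether a pass commits its keys/values to the cache
        # (prefill / finalized blocks) or only reads previously cached context (denoising steps).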
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states, |
|
|
attention_mask=None, |
|
|
position_ids=None, |
|
|
past_key_values=None, |
|
|
output_attentions=False, |
|
|
use_cache=False, |
|
|
cache_position=None, |
|
|
position_embeddings=None, |
|
|
store_kv=False, |
|
|
**kwargs, |
|
|
): |
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
if position_embeddings is not None: |
|
|
cos, sin = position_embeddings |
|
|
query_states, key_states = apply_multimodal_rotary_pos_emb( |
|
|
query_states, key_states, cos, sin, |
|
|
self.config.rope_scaling.get("mrope_section", [16, 24, 24]) |
|
|
) |
|
|
|
|
|
|
|
|
if past_key_values is not None and use_cache: |
|
|
cache_kwargs = {"cache_position": cache_position} |
|
|
if store_kv: |
|
|
|
|
|
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
else: |
|
|
|
|
|
cached_key = past_key_values.key_cache[self.layer_idx] if self.layer_idx < len(past_key_values.key_cache) else None |
|
|
cached_value = past_key_values.value_cache[self.layer_idx] if self.layer_idx < len(past_key_values.value_cache) else None |
|
|
if cached_key is not None and cached_value is not None: |
|
|
key_states = torch.cat([cached_key, key_states], dim=2) |
|
|
value_states = torch.cat([cached_value, value_states], dim=2) |
|
|
|
|
|
|
|
|
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
if isinstance(attention_mask, dict): |
|
|
|
|
|
attn_mask = attention_mask.get("full_attention", None) |
|
|
else: |
|
|
attn_mask = attention_mask |
|
|
else: |
|
|
attn_mask = None |
|
|
|
|
|
if attn_mask is not None: |
|
|
attn_output = F.scaled_dot_product_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attn_mask=attn_mask, |
|
|
dropout_p=0.0, |
|
|
is_causal=False, |
|
|
scale=self.scaling, |
|
|
) |
|
|
else: |
|
|
attn_output = F.scaled_dot_product_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
dropout_p=0.0, |
|
|
is_causal=False, |
|
|
scale=self.scaling, |
|
|
) |
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1) |
|
|
attn_output = self.o_proj(attn_output) |
|
|
|
|
|
return attn_output, None |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_DecoderLayer(nn.Module): |
|
|
def __init__(self, config, layer_idx): |
|
|
super().__init__() |
|
|
self.hidden_size = config.hidden_size |
|
|
self.self_attn = DiffusionVL_Qwen2_5_VL_Attention(config, layer_idx) |
|
|
self.mlp = DiffusionVL_Qwen2_5_VL_MLP(config) |
|
|
self.input_layernorm = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.post_attention_layernorm = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states, |
|
|
attention_mask=None, |
|
|
position_ids=None, |
|
|
past_key_values=None, |
|
|
output_attentions=False, |
|
|
use_cache=False, |
|
|
cache_position=None, |
|
|
position_embeddings=None, |
|
|
store_kv=False, |
|
|
**kwargs, |
|
|
): |
|
|
residual = hidden_states |
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
hidden_states, attn_weights = self.self_attn( |
|
|
hidden_states=hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
output_attentions=output_attentions, |
|
|
use_cache=use_cache, |
|
|
cache_position=cache_position, |
|
|
position_embeddings=position_embeddings, |
|
|
store_kv=store_kv, |
|
|
**kwargs, |
|
|
) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
return hidden_states, attn_weights |
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_PreTrainedModel(PreTrainedModel): |
|
|
|
|
|
config_class = DiffusionVL_Qwen2_5_VL_Config |
|
|
base_model_prefix = "model" |
|
|
supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["DiffusionVL_Qwen2_5_VL_DecoderLayer", "DiffusionVL_Qwen2_5_VL_VisionBlock"] |
|
|
|
|
|
def _init_weights(self, module: nn.Module) -> None: |
|
|
"""Initialize the weights.""" |
|
|
std = self.config.initializer_range |
|
|
if isinstance(module, nn.Linear): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_Model(DiffusionVL_Qwen2_5_VL_PreTrainedModel): |
|
|
|
|
|
def __init__(self, config: DiffusionVL_Qwen2_5_VL_Config): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
|
|
|
|
|
|
self.vision_tower = DiffusionVL_Qwen2_5_VL_VisionTower(config.vision_config) |
|
|
self.mm_projector = DiffusionVL_Qwen2_5_VL_MMProjector(config.vision_config) |
|
|
|
|
|
|
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) |
|
|
self.layers = nn.ModuleList([ |
|
|
DiffusionVL_Qwen2_5_VL_DecoderLayer(config, layer_idx) |
|
|
for layer_idx in range(config.num_hidden_layers) |
|
|
]) |
|
|
self.norm = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.rotary_emb = DiffusionVL_Qwen2_5_VL_RotaryEmbedding(config) |
|
|
|
|
|
|
|
|
self.bd3lm_block_size = config.bd3lm_block_size |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.embed_tokens |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.embed_tokens = value |
|
|
|
|
|
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): |
|
|
""" |
|
|
Encodes images into continuous embeddings through vision tower and mm_projector. |
|
|
|
|
|
Args: |
|
|
pixel_values: Image tensor |
|
|
image_grid_thw: Grid dimensions (temporal, height, width) for each image |
|
|
|
|
|
Returns: |
|
|
Image embeddings ready to be merged with text embeddings |
|
|
""" |
|
|
pixel_values = pixel_values.to(dtype=self.vision_tower.vision_tower.patch_embed.proj.weight.dtype) |
|
|
hidden_states = self.vision_tower(pixel_values, image_grid_thw) |
|
|
image_embeds = self.mm_projector(hidden_states) |
|
|
return image_embeds |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids=None, |
|
|
attention_mask=None, |
|
|
position_ids=None, |
|
|
past_key_values=None, |
|
|
inputs_embeds=None, |
|
|
use_cache=None, |
|
|
output_attentions=None, |
|
|
output_hidden_states=None, |
|
|
return_dict=None, |
|
|
cache_position=None, |
|
|
store_kv=False, |
|
|
pixel_values=None, |
|
|
image_grid_thw=None, |
|
|
**kwargs, |
|
|
): |
|
|
"""Forward pass with optional vision input processing.""" |
|
|
output_attentions = output_attentions or False |
|
|
output_hidden_states = output_hidden_states or False |
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
return_dict = return_dict if return_dict is not None else True |
|
|
|
|
|
        # The module-level IMAGE_TOKEN_INDEX (-200) marks LLaVA-style image placeholders.
|
|
|
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
if pixel_values is not None and image_grid_thw is not None: |
|
|
|
|
|
image_features = self.get_image_features(pixel_values, image_grid_thw) |
|
|
|
|
|
|
|
|
spatial_merge_size = self.vision_tower.spatial_merge_size |
|
|
split_sizes = (image_grid_thw.prod(dim=1) // (spatial_merge_size ** 2)).tolist() |
|
|
image_features_list = list(torch.split(image_features, split_sizes)) |
|
|
|
|
|
|
|
|
batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] |
|
|
            new_inputs_embeds_list = []
            # image_features_list covers all images in the batch, so the index is global
            # rather than reset per sample.
            cur_image_idx = 0
|
|
|
|
|
for batch_idx in range(batch_size): |
|
|
cur_input_ids = input_ids[batch_idx] if input_ids is not None else None |
|
|
cur_embeds = inputs_embeds[batch_idx] |
|
|
|
|
|
if cur_input_ids is None or (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: |
|
|
new_inputs_embeds_list.append(cur_embeds) |
|
|
continue |
|
|
|
|
|
|
|
|
image_positions = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() |
|
|
image_token_indices = [-1] + image_positions + [len(cur_input_ids)] |
|
|
|
|
|
|
|
|
                cur_new_embeds = []
|
|
|
|
|
for i in range(len(image_token_indices) - 1): |
|
|
start = image_token_indices[i] + 1 |
|
|
end = image_token_indices[i + 1] |
|
|
|
|
|
|
|
|
if start < end: |
|
|
cur_new_embeds.append(cur_embeds[start:end]) |
|
|
|
|
|
|
|
|
if i < len(image_positions) and cur_image_idx < len(image_features_list): |
|
|
cur_new_embeds.append(image_features_list[cur_image_idx].to(cur_embeds.dtype)) |
|
|
cur_image_idx += 1 |
|
|
|
|
|
if cur_new_embeds: |
|
|
new_inputs_embeds_list.append(torch.cat(cur_new_embeds, dim=0)) |
|
|
else: |
|
|
new_inputs_embeds_list.append(cur_embeds) |
|
|
|
|
|
|
|
|
max_len = max(x.shape[0] for x in new_inputs_embeds_list) |
|
|
hidden_size = new_inputs_embeds_list[0].shape[-1] |
|
|
inputs_embeds = torch.zeros( |
|
|
batch_size, max_len, hidden_size, |
|
|
dtype=new_inputs_embeds_list[0].dtype, |
|
|
device=new_inputs_embeds_list[0].device |
|
|
) |
|
|
for i, embed in enumerate(new_inputs_embeds_list): |
|
|
inputs_embeds[i, :embed.shape[0]] = embed |
|
|
|
|
|
batch_size, seq_length = inputs_embeds.shape[:2] |
|
|
|
|
|
if cache_position is None: |
|
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
|
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) |
|
|
|
|
|
if position_ids is None: |
|
|
|
|
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
|
|
|
|
|
position_embeddings = self.rotary_emb(inputs_embeds, position_ids) |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
all_hidden_states = () if output_hidden_states else None |
|
|
all_attentions = () if output_attentions else None |
|
|
|
|
|
for layer in self.layers: |
|
|
if output_hidden_states: |
|
|
all_hidden_states += (hidden_states,) |
|
|
|
|
|
hidden_states, attn_weights = layer( |
|
|
hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
output_attentions=output_attentions, |
|
|
use_cache=use_cache, |
|
|
cache_position=cache_position, |
|
|
position_embeddings=position_embeddings, |
|
|
store_kv=store_kv, |
|
|
) |
|
|
|
|
|
if output_attentions: |
|
|
all_attentions += (attn_weights,) |
|
|
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
|
all_hidden_states += (hidden_states,) |
|
|
|
|
|
return BaseModelOutputWithPast( |
|
|
last_hidden_state=hidden_states, |
|
|
past_key_values=past_key_values, |
|
|
hidden_states=all_hidden_states, |
|
|
attentions=all_attentions, |
|
|
) |
|
|
|
|
|
|
|
|
class DiffusionVL_Qwen2_5_VL_ForConditionalGeneration(DiffusionVL_Qwen2_5_VL_PreTrainedModel): |
|
|
r""" |
|
|
DiffusionVL Model with a language modeling head for diffusion-based generation. |
|
|
|
|
|
This model uses block diffusion instead of autoregressive |
|
|
generation. The `generate()` method implements the diffusion denoising process. |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
|
|
def __init__(self, config: DiffusionVL_Qwen2_5_VL_Config): |
|
|
super().__init__(config) |
|
|
self.model = DiffusionVL_Qwen2_5_VL_Model(config) |
|
|
self.vocab_size = config.vocab_size |
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
|
|
|
|
|
|
self.mask_token_id = config.mask_token_id |
|
|
self.block_size = config.bd3lm_block_size |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_model(self): |
|
|
return self.model |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.model.embed_tokens |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.model.embed_tokens = value |
|
|
|
|
|
def tie_weights(self): |
|
|
"""Tie weights if config.tie_word_embeddings is True (3B model).""" |
|
|
if getattr(self.config, "tie_word_embeddings", False): |
|
|
|
|
|
super().tie_weights() |
|
|
|
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.lm_head |
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
|
self.lm_head = new_embeddings |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids=None, |
|
|
attention_mask=None, |
|
|
position_ids=None, |
|
|
past_key_values=None, |
|
|
inputs_embeds=None, |
|
|
labels=None, |
|
|
use_cache=None, |
|
|
output_attentions=None, |
|
|
output_hidden_states=None, |
|
|
return_dict=None, |
|
|
pixel_values=None, |
|
|
image_grid_thw=None, |
|
|
**kwargs, |
|
|
): |
|
|
return_dict = return_dict if return_dict is not None else True |
|
|
|
|
|
|
|
|
        if pixel_values is not None and inputs_embeds is None:
            # Merge projected image features into the text embeddings at the
            # IMAGE_TOKEN_INDEX placeholders before running the language model.
            inputs_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                images=pixel_values,
                image_grid_thws=image_grid_thw,
            )
            input_ids = None
|
|
|
|
|
outputs = self.model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
inputs_embeds=inputs_embeds, |
|
|
use_cache=use_cache, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=True, |
|
|
) |
|
|
|
|
|
hidden_states = outputs.last_hidden_state |
|
|
logits = self.lm_head(hidden_states) |
|
|
|
|
|
loss = None |
|
|
if labels is not None: |
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
loss = F.cross_entropy( |
|
|
shift_logits.view(-1, self.vocab_size), |
|
|
shift_labels.view(-1), |
|
|
ignore_index=-100, |
|
|
) |
|
|
|
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
past_key_values=outputs.past_key_values, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
    def _merge_vision_text(self, input_ids, vision_features):
        """Embed text tokens only.

        Note: despite its name, this helper does not splice `vision_features` into the
        sequence; the full vision-text merge lives in `prepare_inputs_labels_for_multimodal`,
        which `forward()` and `generate()` use.
        """
        text_embeds = self.model.embed_tokens(input_ids)
        return text_embeds
|
|
|
|
|
@torch.no_grad() |
|
|
def generate( |
|
|
self, |
|
|
inputs: Optional[torch.Tensor] = None, |
|
|
images: Optional[torch.Tensor] = None, |
|
|
image_sizes: Optional[torch.Tensor] = None, |
|
|
image_grid_thws: Optional[torch.Tensor] = None, |
|
|
modalities: Optional[List] = None, |
|
|
gen_length: int = 256, |
|
|
steps: int = 8, |
|
|
temperature: float = 0.0, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
Diffusion-based generation using BD3LM algorithm. |
|
|
|
|
|
Follows the same logic as DiffusionVLQwenVLForCausalLM.generate(): |
|
|
1. If images provided, call prepare_inputs_labels_for_multimodal |
|
|
2. Otherwise, just embed the input tokens |
|
|
3. Call generate_with_bd3lm |
|
|
|
|
|
Args: |
|
|
inputs: Input token IDs (prompt) [batch_size, seq_len] |
|
|
images: Image tensor (pixel_values) for vision inputs |
|
|
image_sizes: Image sizes |
|
|
image_grid_thws: Grid dimensions for vision inputs (num_images, 3) |
|
|
modalities: List of modalities (e.g., ["image"]) |
|
|
gen_length: Number of tokens to generate |
|
|
steps: Number of diffusion steps per block |
|
|
temperature: Sampling temperature (0 for greedy) |
|
|
|
|
|
Returns: |
|
|
Generated token IDs |
|
|
""" |
|
|
if modalities is None: |
|
|
modalities = ["image"] |
|
|
|
|
|
if images is not None: |
|
|
inputs_embeds = self.prepare_inputs_labels_for_multimodal( |
|
|
input_ids=inputs, |
|
|
images=images, |
|
|
image_grid_thws=image_grid_thws, |
|
|
) |
|
|
else: |
|
|
inputs_embeds = self.get_input_embeddings()(inputs) |
|
|
|
|
|
|
|
|
return self.generate_with_bd3lm( |
|
|
inputs_embeds=inputs_embeds, |
|
|
gen_length=gen_length, |
|
|
steps=steps, |
|
|
temperature=temperature, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
def prepare_inputs_labels_for_multimodal( |
|
|
self, |
|
|
input_ids: torch.Tensor, |
|
|
images: torch.Tensor, |
|
|
image_grid_thws: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Prepare inputs_embeds by merging text embeddings with image features. |
|
|
|
|
|
Uses LLaVA format: IMAGE_TOKEN_INDEX (-200) as placeholder. |
|
|
|
|
|
Args: |
|
|
input_ids: Input token IDs with IMAGE_TOKEN_INDEX (-200) as image placeholders |
|
|
images: Pixel values tensor |
|
|
image_grid_thws: Grid dimensions for each image |
|
|
|
|
|
Returns: |
|
|
inputs_embeds: Merged text + image embeddings |
|
|
""" |
|
|
        # The module-level IMAGE_TOKEN_INDEX (-200) marks LLaVA-style image placeholders.
|
|
|
|
|
device = input_ids.device |
|
|
batch_size = input_ids.shape[0] |
|
|
|
|
|
|
|
|
if image_grid_thws is not None: |
|
|
if not isinstance(image_grid_thws, torch.Tensor): |
|
|
image_grid_thw = torch.tensor(image_grid_thws, device=device) |
|
|
else: |
|
|
image_grid_thw = image_grid_thws.to(device) |
|
|
else: |
|
|
raise ValueError("image_grid_thws is required for vision processing") |
|
|
|
|
|
|
|
|
image_features = self.model.get_image_features(images, image_grid_thw) |
|
|
|
|
|
|
|
|
spatial_merge_size = self.model.vision_tower.spatial_merge_size |
|
|
split_sizes = (image_grid_thw.prod(dim=1) // (spatial_merge_size ** 2)).tolist() |
|
|
image_features_list = list(torch.split(image_features, split_sizes)) |
|
|
|
|
|
|
|
|
        new_input_embeds_list = []
        # image_features_list covers all images in the batch, so the index is global
        # rather than reset per sample.
        cur_image_idx = 0
|
|
|
|
|
for batch_idx in range(batch_size): |
|
|
cur_input_ids = input_ids[batch_idx] |
|
|
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum().item() |
|
|
|
|
|
if num_images == 0: |
|
|
|
|
|
cur_input_embeds = self.get_input_embeddings()(cur_input_ids) |
|
|
new_input_embeds_list.append(cur_input_embeds) |
|
|
continue |
|
|
|
|
|
|
|
|
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [len(cur_input_ids)] |
|
|
|
|
|
cur_input_ids_noim = [] |
|
|
for idx in range(len(image_token_indices) - 1): |
|
|
start = image_token_indices[idx] + 1 |
|
|
end = image_token_indices[idx + 1] |
|
|
if start < end: |
|
|
cur_input_ids_noim.append(cur_input_ids[start:end]) |
|
|
|
|
|
if cur_input_ids_noim: |
|
|
cur_input_embeds_noim = self.get_input_embeddings()(torch.cat(cur_input_ids_noim)) |
|
|
split_sizes_text = [x.shape[0] for x in cur_input_ids_noim] |
|
|
cur_input_embeds_noim_split = list(torch.split(cur_input_embeds_noim, split_sizes_text)) |
|
|
else: |
|
|
cur_input_embeds_noim_split = [] |
|
|
|
|
|
            cur_new_input_embeds = []
|
|
|
|
|
for idx in range(num_images + 1): |
|
|
if idx < len(cur_input_embeds_noim_split): |
|
|
cur_new_input_embeds.append(cur_input_embeds_noim_split[idx]) |
|
|
if idx < num_images and cur_image_idx < len(image_features_list): |
|
|
cur_image_features = image_features_list[cur_image_idx] |
|
|
target_dtype = cur_input_embeds_noim_split[0].dtype if cur_input_embeds_noim_split else images.dtype |
|
|
cur_new_input_embeds.append(cur_image_features.to(target_dtype)) |
|
|
cur_image_idx += 1 |
|
|
|
|
|
if cur_new_input_embeds: |
|
|
|
|
|
target_device = cur_new_input_embeds[0].device |
|
|
cur_new_input_embeds = [t.to(target_device) for t in cur_new_input_embeds] |
|
|
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) |
|
|
else: |
|
|
cur_new_input_embeds = self.get_input_embeddings()(cur_input_ids) |
|
|
|
|
|
new_input_embeds_list.append(cur_new_input_embeds) |
|
|
|
|
|
|
|
|
max_len = max(x.shape[0] for x in new_input_embeds_list) |
|
|
hidden_size = new_input_embeds_list[0].shape[-1] |
|
|
dtype = new_input_embeds_list[0].dtype |
|
|
|
|
|
inputs_embeds = torch.zeros(batch_size, max_len, hidden_size, dtype=dtype, device=device) |
|
|
for i, embed in enumerate(new_input_embeds_list): |
|
|
inputs_embeds[i, :embed.shape[0]] = embed.to(device) |
|
|
|
|
|
return inputs_embeds |
|
|
|
|
|
@torch.no_grad() |
|
|
def generate_with_bd3lm( |
|
|
self, |
|
|
inputs_embeds: torch.FloatTensor, |
|
|
gen_length: int = 256, |
|
|
steps: int = 8, |
|
|
temperature: float = 0.0, |
|
|
top_k: int = 0, |
|
|
top_p: float = 1.0, |
|
|
remasking_strategy: str = 'low_confidence_static', |
|
|
use_kv_cache: bool = True, |
|
|
confidence_threshold: float = 0.85, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
BD3LM generation algorithm with KV-cache support. |
|
|
|
|
|
Args: |
|
|
inputs_embeds: Input embeddings (prompt) |
|
|
gen_length: Number of tokens to generate |
|
|
steps: Number of diffusion steps per block |
|
|
temperature: Sampling temperature (0 for greedy) |
|
|
top_k: Top-k sampling parameter |
|
|
top_p: Top-p (nucleus) sampling parameter |
|
|
remasking_strategy: 'low_confidence_static', 'low_confidence_dynamic', or 'sequential' |
|
|
use_kv_cache: Whether to use KV cache (default True) |
|
|
confidence_threshold: Threshold for low_confidence_dynamic strategy |
|
|
|
|
|
Returns: |
|
|
Generated token IDs |
|
|
""" |
|
|
device = inputs_embeds.device |
|
|
batch_size = inputs_embeds.shape[0] |
|
|
prompt_len = inputs_embeds.shape[1] |
|
|
block_size = self.block_size |
|
|
mask_id = self.mask_token_id |
|
|
|
|
|
|
|
|
num_blocks = (prompt_len + gen_length + block_size - 1) // block_size |
|
|
total_length = num_blocks * block_size |
|
|
|
|
|
|
|
|
x_ids = torch.full((batch_size, total_length), mask_id, dtype=torch.long, device=device) |
|
|
|
|
|
embed_layer = self.get_input_embeddings() |
|
|
mask_embed = embed_layer(torch.tensor([mask_id], device=embed_layer.weight.device)) |
|
|
mask_embed = mask_embed.to(device) |
|
|
x_embeds = mask_embed.repeat(batch_size, total_length, 1) |
|
|
x_embeds[:, :prompt_len] = inputs_embeds.clone() |
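        # Only embeddings are available for the prompt (it may already contain merged image
        # features), so prompt token ids are approximated below by projecting through lm_head;
        # they are used only for bookkeeping and never appear in the returned generation.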
|
|
|
|
|
|
|
|
prompt_logits = self.lm_head(inputs_embeds) |
|
|
prompt_ids = torch.argmax(prompt_logits, dim=-1) |
|
|
x_ids[:, :prompt_len] = prompt_ids |
|
|
|
|
|
|
|
|
        # Block-causal mask: full (bidirectional) attention inside each block, causal
        # attention across blocks. Zeros keep a position; -inf masks it out (additive mask).
        block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device)).to(inputs_embeds.dtype)
        block_diffusion_mask = (
            block_mask.repeat_interleave(block_size, dim=0)
            .repeat_interleave(block_size, dim=1)
            .unsqueeze(0)
            .unsqueeze(1)
        )
        block_diffusion_mask = torch.where(
            block_diffusion_mask == 0.0, torch.full_like(block_diffusion_mask, float("-inf")), 0.0
        )
|
|
|
|
|
position_ids = torch.arange(total_length, device=device).unsqueeze(0).expand(batch_size, -1) |
|
|
|
|
|
|
|
|
prefill_blocks = prompt_len // block_size |
|
|
prefill_length = prefill_blocks * block_size |
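        # Prefill: run every block that is fully covered by the prompt once with store_kv=True
        # so its keys/values are cached; denoising of later blocks only reads this context.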
|
|
|
|
|
past_key_values = DynamicCache() if use_kv_cache else None |
|
|
if use_kv_cache and prefill_length > 0: |
|
|
prefill_embeds = x_embeds[:, :prefill_length] |
|
|
prefill_mask = block_diffusion_mask[:, :, :prefill_length, :prefill_length] |
|
|
prefill_pos_ids = position_ids[:, :prefill_length] |
|
|
|
|
|
|
|
|
model_mask = {"full_attention": prefill_mask, "sliding_attention": prefill_mask} |
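            # The decoder layers accept either a plain additive mask or a dict keyed by
            # attention type; both entries point at the same block-diffusion mask here.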
|
|
|
|
|
prefill_outputs = self.model( |
|
|
inputs_embeds=prefill_embeds, |
|
|
attention_mask=model_mask, |
|
|
position_ids=prefill_pos_ids, |
|
|
past_key_values=past_key_values, |
|
|
use_cache=True, |
|
|
store_kv=True |
|
|
) |
|
|
prefill_logits = self.lm_head(prefill_outputs.last_hidden_state).float() |
|
|
self.last_prefill_logits = prefill_logits[:, -1:, :].clone() |
|
|
past_key_values = prefill_outputs.past_key_values |
|
|
|
|
|
|
|
|
num_transfer_tokens = self._get_num_transfer_tokens(block_size, steps) |
|
|
eos_token_id = kwargs.get('eos_token_id', 151645) |
|
|
|
|
|
|
|
|
for block_idx in range(prefill_blocks, num_blocks): |
|
|
block_start = block_idx * block_size |
|
|
block_end = block_start + block_size |
|
|
|
|
|
cur_block_embeds = x_embeds[:, block_start:block_end].clone() |
|
|
cur_block_ids = x_ids[:, block_start:block_end] |
|
|
|
|
|
cur_mask = block_diffusion_mask[:, :, block_start:block_end, :block_end] |
|
|
cur_pos_ids = position_ids[:, block_start:block_end] |
|
|
|
|
|
|
|
|
model_mask = {"full_attention": cur_mask, "sliding_attention": cur_mask} |
|
|
|
|
|
|
|
|
for step in range(steps + 1): |
|
|
|
|
|
is_mask = torch.all(torch.abs(cur_block_embeds - mask_embed.to(cur_block_embeds.device)) < 1e-5, dim=-1) |
|
|
if not is_mask.any(): |
|
|
|
|
|
if use_kv_cache: |
|
|
_ = self.model( |
|
|
inputs_embeds=cur_block_embeds, |
|
|
attention_mask=model_mask, |
|
|
position_ids=cur_pos_ids, |
|
|
past_key_values=past_key_values, |
|
|
use_cache=True, |
|
|
store_kv=True |
|
|
) |
|
|
break |
|
|
|
|
|
|
|
|
if use_kv_cache: |
|
|
outputs = self.model( |
|
|
inputs_embeds=cur_block_embeds, |
|
|
attention_mask=model_mask, |
|
|
position_ids=cur_pos_ids, |
|
|
past_key_values=past_key_values, |
|
|
use_cache=True, |
|
|
store_kv=False |
|
|
) |
|
|
logits = self.lm_head(outputs.last_hidden_state).float() |
|
|
else: |
|
|
|
|
|
context_embeds = x_embeds[:, :block_end].clone() |
|
|
context_embeds[:, block_start:block_end] = cur_block_embeds |
|
|
context_mask = block_diffusion_mask[:, :, :block_end, :block_end] |
|
|
context_pos_ids = position_ids[:, :block_end] |
|
|
context_model_mask = {"full_attention": context_mask, "sliding_attention": context_mask} |
|
|
|
|
|
outputs = self.model( |
|
|
inputs_embeds=context_embeds, |
|
|
attention_mask=context_model_mask, |
|
|
position_ids=context_pos_ids, |
|
|
past_key_values=None, |
|
|
use_cache=False, |
|
|
store_kv=False |
|
|
) |
|
|
logits = self.lm_head(outputs.last_hidden_state[:, block_start:block_end]).float() |
|
|
|
|
|
|
|
|
x0, x0_p = self._sample_tokens(logits, temperature, top_k, top_p) |
|
|
|
|
|
|
|
|
num_to_transfer = num_transfer_tokens[step].item() |
|
|
|
|
|
|
|
|
target_device = x0.device |
|
|
is_mask = is_mask.to(target_device) |
|
|
x0_p = x0_p.to(target_device) |
|
|
|
|
|
transfer_mask = torch.zeros_like(x0, dtype=torch.bool) |
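                # Choose which masked positions to commit this step:
                #   sequential             - first num_to_transfer masked slots, left to right
                #   low_confidence_static  - the num_to_transfer highest-probability predictions
                #   low_confidence_dynamic - every prediction above confidence_threshold,
                #                            topped up to num_to_transfer if too few qualify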
|
|
|
|
|
if remasking_strategy == 'sequential': |
|
|
for j in range(batch_size): |
|
|
if is_mask[j].any(): |
|
|
mask_positions = is_mask[j].nonzero(as_tuple=True)[0] |
|
|
num_to_select = min(num_to_transfer, len(mask_positions)) |
|
|
selected_positions = mask_positions[:num_to_select] |
|
|
transfer_mask[j, selected_positions] = True |
|
|
|
|
|
elif remasking_strategy == 'low_confidence_static': |
|
|
confidence = torch.where(is_mask, x0_p, torch.tensor(-torch.inf, device=target_device)) |
|
|
for j in range(batch_size): |
|
|
num_masks = is_mask[j].sum().item() |
|
|
k = min(num_to_transfer, num_masks) |
|
|
if k > 0 and not torch.all(torch.isinf(confidence[j])): |
|
|
_, idx = torch.topk(confidence[j], k) |
|
|
transfer_mask[j, idx] = True |
|
|
|
|
|
elif remasking_strategy == 'low_confidence_dynamic': |
|
|
confidence = torch.where(is_mask, x0_p, torch.tensor(-torch.inf, device=target_device)) |
|
|
for j in range(batch_size): |
|
|
high_conf_mask = confidence[j] > confidence_threshold |
|
|
num_high_confidence = high_conf_mask.sum().item() |
|
|
if num_high_confidence >= num_to_transfer: |
|
|
transfer_mask[j] = high_conf_mask |
|
|
else: |
|
|
num_masks = is_mask[j].sum().item() |
|
|
k = min(num_to_transfer, num_masks) |
|
|
if k > 0: |
|
|
_, idx = torch.topk(confidence[j], k) |
|
|
transfer_mask[j, idx] = True |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown remasking strategy: {remasking_strategy}") |
|
|
|
|
|
|
|
|
cur_block_ids = cur_block_ids.to(x0.device) |
|
|
cur_block_ids = torch.where(transfer_mask, x0, cur_block_ids) |
|
|
|
|
|
embed_layer = self.get_input_embeddings() |
|
|
x0_embeds = embed_layer(x0.to(embed_layer.weight.device)) |
|
|
cur_block_embeds = cur_block_embeds.to(x0_embeds.device) |
|
|
cur_block_embeds = torch.where(transfer_mask.unsqueeze(-1).to(x0_embeds.device), x0_embeds, cur_block_embeds) |
|
|
|
|
|
|
|
|
x_embeds[:, block_start:block_end] = cur_block_embeds.to(x_embeds.device) |
|
|
x_ids[:, block_start:block_end] = cur_block_ids.to(x_ids.device) |
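            # Early stop: once an EOS token appears among the newly generated positions of
            # this block, remaining blocks are left masked and generation ends.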
|
|
|
|
|
|
|
|
if block_end > prompt_len: |
|
|
gen_start_in_block = max(prompt_len, block_start) |
|
|
gen_ids_check = x_ids[:, gen_start_in_block:block_end] |
|
|
if eos_token_id in gen_ids_check: |
|
|
break |
|
|
|
|
|
|
|
|
return x_ids[:, prompt_len:prompt_len + gen_length] |
|
|
|
|
|
def _sample_tokens(self, logits, temperature=0.0, top_k=0, top_p=1.0): |
|
|
"""Sample tokens with temperature, top-k, and top-p.""" |
|
|
batch_size = logits.shape[0] |
|
|
seq_len = logits.shape[1] |
|
|
vocab_size = logits.shape[-1] |
|
|
|
|
|
logits_2d = logits.reshape(-1, vocab_size) |
|
|
|
|
|
if temperature == 0: |
|
|
|
|
|
tokens = torch.argmax(logits_2d, dim=-1, keepdim=True) |
|
|
probs = F.softmax(logits_2d, dim=-1) |
|
|
token_probs = torch.gather(probs, -1, tokens) |
|
|
else: |
|
|
|
|
|
logits_scaled = logits_2d / temperature |
|
|
|
|
|
|
|
|
if top_k > 0: |
|
|
values, _ = torch.topk(logits_scaled, top_k) |
|
|
min_values = values[:, -1:] |
|
|
logits_scaled = torch.where(logits_scaled < min_values, float('-inf'), logits_scaled) |
|
|
|
|
|
|
|
|
if top_p < 1.0: |
|
|
sorted_logits, sorted_indices = torch.sort(logits_scaled, descending=True) |
|
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
|
|
sorted_mask = cumulative_probs > top_p |
|
|
sorted_mask[:, 1:] = sorted_mask[:, :-1].clone() |
|
|
sorted_mask[:, 0] = False |
|
|
mask_indices = torch.scatter( |
|
|
torch.zeros_like(logits_scaled, dtype=torch.bool), |
|
|
-1, sorted_indices, sorted_mask |
|
|
) |
|
|
logits_scaled = logits_scaled.masked_fill(mask_indices, float('-inf')) |
|
|
|
|
|
probs = F.softmax(logits_scaled, dim=-1) |
|
|
tokens = torch.multinomial(probs, num_samples=1) |
|
|
token_probs = torch.gather(probs, -1, tokens) |
|
|
|
|
|
return tokens.view(batch_size, seq_len), token_probs.view(batch_size, seq_len) |
|
|
|
|
|
def _get_num_transfer_tokens(self, block_length, steps): |
|
|
"""Calculate how many tokens to unmask at each step.""" |
|
|
if steps == 0: |
|
|
return torch.zeros(1, dtype=torch.int64) |
|
|
base = block_length // steps |
|
|
remainder = block_length % steps |
|
|
num_transfer = torch.zeros(steps + 1, dtype=torch.int64) + base |
|
|
num_transfer[:remainder] += 1 |
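        # Illustration: block_length=32, steps=8 -> base=4, remainder=0, i.e. four tokens are
        # unmasked per step; the extra (steps+1)-th entry covers the final pass, which exits
        # early once no masked positions remain in the block.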
|
|
return num_transfer |
|
|
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
|
|
|
AutoConfig.register("diffusionvl_qwen2_5_vl", DiffusionVL_Qwen2_5_VL_Config) |
|
|
AutoModelForCausalLM.register(DiffusionVL_Qwen2_5_VL_Config, DiffusionVL_Qwen2_5_VL_ForConditionalGeneration) |
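# Usage sketch (illustrative only; the checkpoint path and the preprocessing are placeholders,
# not defined in this file):
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("path/to/diffusionvl-checkpoint",
#                                                trust_remote_code=True)
#   # input_ids: prompt ids with IMAGE_TOKEN_INDEX (-200) at image positions
#   # pixel_values / image_grid_thw: flattened patches and (num_images, 3) grid sizes
#   out_ids = model.generate(inputs=input_ids, images=pixel_values,
#                            image_grid_thws=image_grid_thw, gen_length=256, steps=8)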
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
"DiffusionVL_Qwen2_5_VL_Config", |
|
|
"DiffusionVL_Qwen2_5_VL_VisionConfig", |
|
|
"DiffusionVL_Qwen2_5_VL_PreTrainedModel", |
|
|
"DiffusionVL_Qwen2_5_VL_Model", |
|
|
"DiffusionVL_Qwen2_5_VL_ForConditionalGeneration", |
|
|
] |
|
|
|