""" DualViTok model configuration """
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from .configuration_movqgan import MoVQConfig
logger = logging.get_logger(__name__)
class SemanticEncoderConfig(PretrainedConfig):
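    r"""
    Configuration class for the semantic encoder of a [`DualViTok`] model. The argument
    descriptions below are best-effort glosses inferred from the field names and defaults;
    see the modeling code for exact usage.

    Args:
        pretrained_semantic_encoder (`str`, *optional*, defaults to `'Emova-ollm/qwen2vit600m'`):
            Name or path of the pretrained vision backbone used as the semantic encoder.
        z_channels (`int`, *optional*, defaults to 32):
            Number of channels of the latent passed to the semantic quantizer.
        num_blocks (`int`, *optional*, defaults to 4):
            Number of encoder blocks.
        embed_dim (`int`, *optional*, defaults to 1280):
            Hidden dimension of the encoder blocks.
        out_layer (`str`, *optional*, defaults to `'linear'`):
            Type of the output projection layer.
        target_mlp (`str`, *optional*, defaults to `'norm'`):
            Variant of the target-feature projection/normalization.
    """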
model_type = "DualViTokSemanticEncoder"
def __init__(
self,
pretrained_semantic_encoder='Emova-ollm/qwen2vit600m',
z_channels=32,
num_blocks=4,
embed_dim=1280,
out_layer='linear',
target_mlp='norm',
**kwargs
):
super().__init__(**kwargs)
self.pretrained_semantic_encoder = pretrained_semantic_encoder
self.z_channels = z_channels
self.num_blocks = num_blocks
self.out_layer = out_layer
self.embed_dim = embed_dim
self.target_mlp = target_mlp
class SemanticDecoderConfig(PretrainedConfig):
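    r"""
    Configuration class for the semantic decoder of a [`DualViTok`] model. As above, the
    descriptions are inferred from the field names; see the modeling code for exact usage.

    Args:
        z_channels (`int`, *optional*, defaults to 32):
            Number of channels of the quantized latent consumed by the decoder.
        num_blocks (`int`, *optional*, defaults to 4):
            Number of decoder blocks.
        embed_dim (`int`, *optional*, defaults to 1280):
            Hidden dimension of the decoder blocks.
        out_layer (`str`, *optional*, defaults to `'linear_norm'`):
            Type of the output projection layer.
        out_channels (`int`, *optional*, defaults to 3584):
            Channel dimension of the reconstructed semantic features.
    """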
model_type = "DualViTokSemanticDecoder"
def __init__(
self,
z_channels=32,
num_blocks=4,
embed_dim=1280,
out_layer='linear_norm',
out_channels=3584,
**kwargs
):
super().__init__(**kwargs)
self.z_channels = z_channels
self.num_blocks = num_blocks
self.embed_dim = embed_dim
self.out_layer = out_layer
self.out_channels = out_channels
class DualViTokConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`DualViTok`]. It is used to instantiate a
    DualViTok model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a configuration similar to that of the VQ model presented in the paper.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
semantic_encoder (`dict`, *optional*):
Configuration dictionary for the semantic encoder. If `None`, defaults to `SemanticEncoderConfig()`.
The provided dictionary will be unpacked to initialize a `SemanticEncoderConfig` instance.
        semantic_decoder (`dict`, *optional*):
            Configuration dictionary for the semantic decoder. If `None`, defaults to `SemanticDecoderConfig()`.
            The provided dictionary will be unpacked to initialize a `SemanticDecoderConfig` instance.
        pixel_encoder (`dict`, *optional*):
            Configuration dictionary for the pixel pathway's VQ-VAE encoder (e.g., `MoVQConfig`). If `None`,
            defaults to `MoVQConfig()`. The provided dictionary will be unpacked to initialize a `MoVQConfig` instance.
        pixel_decoder (`dict`, *optional*):
            Configuration dictionary for the pixel decoder. If `None`, the decoder reuses the `pixel_encoder`
            configuration; otherwise the dictionary is unpacked to initialize a separate `MoVQConfig` instance.
semantic_quantizer_type (`str`, *optional*, defaults to `'simvq'`):
Type of the quantizer for semantic tokens (e.g., `'simvq'`, `'ema_simvq'`).
pixel_quantizer_type (`str`, *optional*, defaults to `'simvq'`):
Type of the quantizer for pixel tokens (e.g., `'simvq'`, `'ema_simvq'`).
semantic_quantizer_codebook_size (`int`, *optional*, defaults to 32768):
Number of entries in the codebook for the semantic quantizer.
pixel_quantizer_codebook_size (`int`, *optional*, defaults to 98304):
Number of entries in the codebook for the pixel quantizer.
        attn_implementation (`str`, *optional*, defaults to `'sdpa'`):
            The attention implementation to use: `'sdpa'` (scaled dot-product attention), `'flash_attention_2'`
            (if installed), or `'eager'` (the vanilla PyTorch implementation).
```python
>>> from transformers import DualViTok, DualViTokConfig
    >>> # Initializing a DualViTok configuration with default values
>>> configuration = DualViTokConfig()
    >>> # Initializing a model from the configuration
>>> model = DualViTok(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
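    >>> # Nested sub-configs can be overridden with plain dictionaries
    >>> custom_configuration = DualViTokConfig(semantic_encoder={"z_channels": 64})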
```"""
model_type = "DualViTok"
def __init__(
self,
semantic_encoder=None,
semantic_decoder=None,
pixel_encoder=None,
pixel_decoder=None,
semantic_quantizer_type='simvq',
pixel_quantizer_type='simvq',
semantic_quantizer_codebook_size=32768,
pixel_quantizer_codebook_size=98304,
attn_implementation='sdpa',
**kwargs,
):
super().__init__(**kwargs)
if semantic_encoder is None:
self.semantic_encoder = SemanticEncoderConfig()
else:
self.semantic_encoder = SemanticEncoderConfig(**semantic_encoder)
if semantic_decoder is None:
self.semantic_decoder = SemanticDecoderConfig()
else:
self.semantic_decoder = SemanticDecoderConfig(**semantic_decoder)
self.semantic_quantizer_type = semantic_quantizer_type
self.pixel_quantizer_type = pixel_quantizer_type
self.semantic_quantizer_codebook_size = semantic_quantizer_codebook_size
self.pixel_quantizer_codebook_size = pixel_quantizer_codebook_size
if pixel_encoder is None:
self.pixel_encoder = MoVQConfig()
else:
self.pixel_encoder = MoVQConfig(**pixel_encoder)
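        # Without an explicit pixel_decoder config, the decoder reuses the pixel_encoder's MoVQConfig instance.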
self.pixel_decoder = self.pixel_encoder if pixel_decoder is None else MoVQConfig(**pixel_decoder)
self.attn_implementation = attn_implementation
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], attn_implementation='sdpa', **kwargs) -> "PretrainedConfig":
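        # Overridden only to thread `attn_implementation` through to `from_dict`;
        # otherwise this mirrors the default `PretrainedConfig.from_pretrained`.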
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, attn_implementation=attn_implementation, **kwargs)
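

if __name__ == "__main__":
    # Minimal smoke-test sketch. Run it as a module, e.g. `python -m <package>.configuration_dualvitok`
    # (package name depending on where this file lives), since the relative import of
    # MoVQConfig above prevents executing the file directly as a script.
    cfg = DualViTokConfig(semantic_encoder={"z_channels": 64})
    assert cfg.semantic_encoder.z_channels == 64
    # With no pixel_decoder given, the decoder falls back to the encoder's config object.
    assert cfg.pixel_decoder is cfg.pixel_encoder
    print(cfg.semantic_quantizer_type, cfg.semantic_quantizer_codebook_size)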