# SafeQwen2.5-VL-7B / configuration_safeqwen2_5_vl.py
# ywlee88's picture
# Upload folder using huggingface_hub
# c4cf665 verified
"""
SafeQwen2.5-VL Configuration
This configuration class extends the official Qwen2_5_VLConfig to add safety-aware
classification capabilities for multimodal content moderation.
Author: SafeQwen Team
"""
from typing import Optional, List
from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig
class SafeQwen2_5_VLConfig(Qwen2_5_VLConfig):
    """
    Configuration for the SafeQwen2.5-VL model.

    Builds on `Qwen2_5_VLConfig` and additionally stores the hyperparameters
    of a safety classification head used to flag potentially unsafe content
    in images.

    Args:
        safety_categories (`List[str]`, *optional*):
            Names of the safety classes. When omitted (or empty), the
            20-entry HoliSafe taxonomy listed in `__init__` is used.
        safety_head_hidden_scale (`float`, *optional*, defaults to 4.0):
            Scale factor for the safety head hidden size relative to the
            model hidden size.
        safety_loss_lambda (`float`, *optional*, defaults to 1.0):
            Weight of the safety classification loss during training.
        safety_num_hidden_layers (`int`, *optional*, defaults to 1):
            Number of hidden layers in the safety classification MLP.
        **kwargs:
            Forwarded unchanged to `Qwen2_5_VLConfig.__init__`.

    In addition to the arguments above, every instance exposes
    `num_safety_categories`, the length of `safety_categories`.
    """

    # NOTE(review): this looks like the base Qwen2.5-VL model type string,
    # presumably reused on purpose -- confirm before changing it.
    model_type = "qwen2_5_vl"

    def __init__(
        self,
        # Safety specific parameters
        safety_categories: Optional[List[str]] = None,
        safety_head_hidden_scale: float = 4.0,
        safety_loss_lambda: float = 1.0,
        safety_num_hidden_layers: int = 1,
        **kwargs
    ):
        """Initialise the base config, then attach the safety-head settings."""
        # Let the parent class consume every non-safety option first.
        super().__init__(**kwargs)

        if safety_categories:
            chosen_categories = safety_categories
        else:
            # Fallback: HoliSafe 20-category safety taxonomy. A fresh list is
            # built on each call, so instances never share the default.
            chosen_categories = [
                "safe",
                "gender",
                "race",
                "religion",
                "harassment",
                "disability_discrimination",
                "drug_related_hazards",
                "property_crime",
                "facial_data_exposure",
                "identity_data_exposure",
                "physical_self_injury",
                "suicide",
                "animal_abuse",
                "obscene_gestures",
                "physical_altercation",
                "terrorism",
                "weapon_related_violence",
                "sexual_content",
                "financial_advice",
                "medical_advice",
            ]
        self.safety_categories = chosen_categories

        self.safety_head_hidden_scale = safety_head_hidden_scale
        self.safety_loss_lambda = safety_loss_lambda
        self.safety_num_hidden_layers = safety_num_hidden_layers

        # Derived value: always recomputed from the category list above.
        self.num_safety_categories = len(self.safety_categories)