import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig
)
from typing import Optional, Dict, Any


class JapaneseCLIPConfig(PretrainedConfig):
    """Configuration class for the Japanese CLIP model."""

    model_type = "japanese-clip"

    def __init__(
        self,
        text_model_name="cl-tohoku/bert-base-japanese-v3",
        image_embed_dim=512,
        text_embed_dim=512,
        temperature=0.07,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.text_model_name = text_model_name
        self.image_embed_dim = image_embed_dim
        self.text_embed_dim = text_embed_dim
        self.temperature = temperature


class JapaneseCLIPModel(PreTrainedModel):
    """Hugging Face-compatible Japanese CLIP model."""

    config_class = JapaneseCLIPConfig

    def __init__(self, config):
        super().__init__(config)

        # Import torchvision lazily so it stays an optional dependency
        try:
            from torchvision.models import resnet50, ResNet50_Weights
        except ImportError:
            raise ImportError(
                "torchvision is required for this model. "
                "Install it with: pip install torchvision"
            )

        # Image encoder (ResNet50 backbone) with the classifier head replaced
        # by a projection to the image embedding dimension
        self.image_encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.image_encoder.fc = nn.Linear(
            self.image_encoder.fc.in_features,
            config.image_embed_dim
        )

        # Text encoder (Japanese BERT)
        self.text_encoder = AutoModel.from_pretrained(config.text_model_name)

        # Projection layers into the shared embedding space
        self.text_projection = nn.Linear(
            self.text_encoder.config.hidden_size,
            config.text_embed_dim
        )
        self.image_projection = nn.Linear(
            config.image_embed_dim,
            config.text_embed_dim
        )

        # Normalization layers
        self.image_norm = nn.LayerNorm(config.text_embed_dim)
        self.text_norm = nn.LayerNorm(config.text_embed_dim)

        # Learnable temperature, stored as log(1 / T) (CLIP-style logit scale)
        self.temperature = nn.Parameter(
            torch.ones([]) * np.log(1 / config.temperature)
        )

    def encode_image(self, pixel_values):
        """Encode images into the shared embedding space."""
        image_features = self.image_encoder(pixel_values)
        image_features = self.image_projection(image_features)
        image_features = self.image_norm(image_features)
        return F.normalize(image_features, dim=-1)

    def encode_text(self, input_ids, attention_mask):
        """Encode text into the shared embedding space."""
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Use the [CLS] token representation as the sentence embedding
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_features = self.text_projection(text_features)
        text_features = self.text_norm(text_features)
        return F.normalize(text_features, dim=-1)

    def get_image_features(self, pixel_values):
        """Return image features."""
        return self.encode_image(pixel_values)

    def get_text_features(self, input_ids, attention_mask):
        """Return text features."""
        return self.encode_text(input_ids, attention_mask)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Forward pass."""
        outputs = {}

        if pixel_values is not None:
            outputs['image_features'] = self.encode_image(pixel_values)

        if input_ids is not None and attention_mask is not None:
            outputs['text_features'] = self.encode_text(input_ids, attention_mask)

        if 'image_features' in outputs and 'text_features' in outputs:
            # Cosine similarity scaled by the learned temperature
            similarity = torch.matmul(
                outputs['image_features'],
                outputs['text_features'].T
            )
            temperature = self.temperature.exp()
            outputs['logits_per_image'] = similarity * temperature
            outputs['logits_per_text'] = outputs['logits_per_image'].T
            outputs['temperature'] = temperature

        return outputs


# Register the custom classes with the Auto classes
from transformers import AutoConfig, AutoModel

AutoConfig.register("japanese-clip", JapaneseCLIPConfig)
AutoModel.register(JapaneseCLIPConfig, JapaneseCLIPModel)
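

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module; added for illustration).
# It shows one way to instantiate the config/model, run a contrastive forward
# pass, and rely on the AutoConfig/AutoModel registration above. Assumptions:
# the first run downloads the BERT and ResNet50 weights, the tokenizer for
# cl-tohoku/bert-base-japanese-v3 additionally needs `fugashi`/`unidic-lite`,
# and the image tensors below are random placeholders, not real data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v3")

    config = JapaneseCLIPConfig()
    model = JapaneseCLIPModel(config)
    model.eval()

    # Dummy batch: 2 RGB images (224x224) and 2 Japanese captions
    pixel_values = torch.randn(2, 3, 224, 224)
    texts = ["猫がソファで寝ている", "青い空と白い雲"]
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model(
            pixel_values=pixel_values,
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    # logits_per_image has shape (num_images, num_texts): temperature-scaled
    # cosine similarities between every image and every caption
    print(outputs["logits_per_image"].shape)

    # Because "japanese-clip" is registered with AutoConfig/AutoModel, a saved
    # checkpoint can be reloaded generically (the path is hypothetical):
    # model.save_pretrained("./japanese-clip-demo")
    # reloaded = AutoModel.from_pretrained("./japanese-clip-demo")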