import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig
)
from typing import Optional, Dict, Any

class JapaneseCLIPConfig(PretrainedConfig):
    """Configuration class for the Japanese CLIP model."""
    model_type = "japanese-clip"

    def __init__(
        self,
        text_model_name="cl-tohoku/bert-base-japanese-v3",
        image_embed_dim=512,
        text_embed_dim=512,
        temperature=0.07,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.text_model_name = text_model_name
        self.image_embed_dim = image_embed_dim
        self.text_embed_dim = text_embed_dim
        self.temperature = temperature

class JapaneseCLIPModel(PreTrainedModel):
    """Hugging Face-compatible Japanese CLIP model."""
    config_class = JapaneseCLIPConfig

    def __init__(self, config):
        super().__init__(config)

        # Import torchvision lazily so a missing dependency raises a clear error
        try:
            from torchvision.models import resnet50
        except ImportError:
            raise ImportError("torchvision is required for this model. Install it with: pip install torchvision")

        # Image encoder (ResNet50 backbone) with its classifier head replaced
        # by a linear layer into the image embedding dimension
        self.image_encoder = resnet50(pretrained=True)
        self.image_encoder.fc = nn.Linear(
            self.image_encoder.fc.in_features,
            config.image_embed_dim
        )

        # Text encoder (Japanese BERT)
        self.text_encoder = AutoModel.from_pretrained(config.text_model_name)

        # Projection layers into the shared embedding space
        self.text_projection = nn.Linear(
            self.text_encoder.config.hidden_size,
            config.text_embed_dim
        )
        self.image_projection = nn.Linear(
            config.image_embed_dim,
            config.text_embed_dim
        )

        # Normalization layers
        self.image_norm = nn.LayerNorm(config.text_embed_dim)
        self.text_norm = nn.LayerNorm(config.text_embed_dim)

        # Learnable temperature, stored CLIP-style as a log logit scale
        self.temperature = nn.Parameter(
            torch.ones([]) * np.log(1 / config.temperature)
        )
    def encode_image(self, pixel_values):
        """Encode images into the shared embedding space."""
        image_features = self.image_encoder(pixel_values)
        image_features = self.image_projection(image_features)
        image_features = self.image_norm(image_features)
        return F.normalize(image_features, dim=-1)

    def encode_text(self, input_ids, attention_mask):
        """Encode text into the shared embedding space."""
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Use the [CLS] token representation as the sentence embedding
        text_features = text_outputs.last_hidden_state[:, 0, :]
        text_features = self.text_projection(text_features)
        text_features = self.text_norm(text_features)
        return F.normalize(text_features, dim=-1)

    def get_image_features(self, pixel_values):
        """Return image features."""
        return self.encode_image(pixel_values)

    def get_text_features(self, input_ids, attention_mask):
        """Return text features."""
        return self.encode_text(input_ids, attention_mask)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Forward pass."""
        outputs = {}
        if pixel_values is not None:
            outputs['image_features'] = self.encode_image(pixel_values)
        if input_ids is not None and attention_mask is not None:
            outputs['text_features'] = self.encode_text(input_ids, attention_mask)
        if 'image_features' in outputs and 'text_features' in outputs:
            # Cosine similarity between all image/text pairs, scaled by the
            # exponentiated temperature (the learned logit scale)
            similarity = torch.matmul(
                outputs['image_features'],
                outputs['text_features'].T
            )
            temperature = self.temperature.exp()
            outputs['logits_per_image'] = similarity * temperature
            outputs['logits_per_text'] = outputs['logits_per_image'].T
            outputs['temperature'] = temperature
        return outputs
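
# The logits returned by forward() can feed a standard CLIP-style symmetric
# cross-entropy loss. A minimal sketch follows; the helper name
# `clip_contrastive_loss` is illustrative and is not used by the model above.
def clip_contrastive_loss(logits_per_image: torch.Tensor,
                          logits_per_text: torch.Tensor) -> torch.Tensor:
    """Symmetric InfoNCE loss: image i is assumed to match text i in the batch."""
    labels = torch.arange(logits_per_image.size(0), device=logits_per_image.device)
    loss_images = F.cross_entropy(logits_per_image, labels)
    loss_texts = F.cross_entropy(logits_per_text, labels)
    return (loss_images + loss_texts) / 2
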
# Register the custom config and model with the transformers Auto classes
from transformers import AutoConfig
AutoConfig.register("japanese-clip", JapaneseCLIPConfig)
AutoModel.register(JapaneseCLIPConfig, JapaneseCLIPModel)
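
# A brief usage sketch (assumes the pretrained BERT and ResNet50 weights can be
# downloaded; the Japanese BERT tokenizer may also need fugashi/unidic-lite, and the
# random pixel tensor merely stands in for properly preprocessed 224x224 RGB images).
if __name__ == "__main__":
    config = JapaneseCLIPConfig()
    model = JapaneseCLIPModel(config)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_name)

    texts = ["犬が公園を走っている", "猫がソファで寝ている"]  # "A dog running in the park", "A cat sleeping on the sofa"
    text_inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    pixel_values = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, **text_inputs)
    print(outputs["logits_per_image"].shape)  # torch.Size([2, 2])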