import torch
from torch import nn
import torch.nn.functional as F

from utils import NestedTensor, nested_tensor_from_tensor_list
from backbone import build_backbone
from transformer import build_transformer

class Caption(nn.Module):
    def __init__(self, backbone, transformer, hidden_dim, vocab_size):
        super().__init__()
        self.backbone = backbone
        # 1x1 conv projects the backbone's feature channels down to the
        # transformer's hidden dimension.
        self.input_proj = nn.Conv2d(
            backbone.num_channels, hidden_dim, kernel_size=1)
        self.transformer = transformer
        # Maps each decoder hidden state to logits over the vocabulary.
        self.mlp = MLP(hidden_dim, 512, vocab_size, 3)

    def forward(self, samples, target, target_mask):
        # Wrap raw image tensors into a NestedTensor (padded batch + padding mask).
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)

        # CNN features and positional encodings; keep only the last feature map.
        features, pos = self.backbone(samples)
        src, mask = features[-1].decompose()
        assert mask is not None

        # Encode the image features and decode the target caption tokens.
        hs = self.transformer(self.input_proj(src), mask,
                              pos[-1], target, target_mask)
        # (seq_len, batch, hidden_dim) -> (batch, seq_len, hidden_dim) -> vocab logits.
        out = self.mlp(hs.permute(1, 0, 2))
        return out
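
# The decoder is trained with teacher forcing, so at inference time captions
# must be produced autoregressively. A minimal greedy-decoding sketch, not
# part of the original module: `start_token`, `end_token`, and `max_len` are
# assumed tokenizer/config values, and `model` is a trained Caption instance.
#
#   @torch.no_grad()
#   def greedy_decode(model, image, start_token, end_token, max_len=128):
#       caption = torch.zeros((1, max_len), dtype=torch.long)
#       cap_mask = torch.ones((1, max_len), dtype=torch.bool)  # True = padding
#       caption[:, 0] = start_token
#       cap_mask[:, 0] = False
#       for i in range(max_len - 1):
#           logits = model(image, caption, cap_mask)   # (1, max_len, vocab_size)
#           next_id = logits[:, i, :].argmax(dim=-1)
#           if next_id.item() == end_token:
#               break
#           caption[:, i + 1] = next_id
#           cap_mask[:, i + 1] = False
#       return caption
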
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Hidden sizes for the intermediate layers, e.g. num_layers=3 gives
        # input_dim -> hidden_dim -> hidden_dim -> output_dim.
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k)
                                    for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        # ReLU after every layer except the last, which emits raw logits.
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
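
# For the prediction head above, MLP(hidden_dim, 512, vocab_size, 3) builds
# Linear(hidden_dim, 512) -> Linear(512, 512) -> Linear(512, vocab_size),
# with ReLU between the hidden layers and raw logits at the output.
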
def build_model(config):
    backbone = build_backbone(config)
    transformer = build_transformer(config)

    model = Caption(backbone, transformer, config.hidden_dim, config.vocab_size)
    # Standard token-level cross-entropy over the vocabulary.
    criterion = torch.nn.CrossEntropyLoss()

    return model, criterion
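
# Training usage sketch (illustrative, not part of the original module).
# Assumes `config` exposes `hidden_dim`, `vocab_size`, and whatever fields
# build_backbone/build_transformer expect; the image and caption tensors
# below are hypothetical dummy inputs.
#
#   model, criterion = build_model(config)
#   images = torch.randn(2, 3, 224, 224)
#   caps = torch.zeros(2, 128, dtype=torch.long)        # caption token ids
#   cap_masks = torch.ones(2, 128, dtype=torch.bool)    # True where padded
#
#   # Teacher forcing: feed caps[:, :-1], predict caps[:, 1:].
#   logits = model(images, caps[:, :-1], cap_masks[:, :-1])
#   loss = criterion(logits.permute(0, 2, 1), caps[:, 1:])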