|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torchvision.models.video as models
|
|
|
|
|
|
class TimeSformerBlock(nn.Module):
    """One divided space-time attention block (TimeSformer-style).

    The input is a flattened sequence of per-frame patch tokens of shape
    ``(B, T * P, D)`` where ``T == num_frames`` and ``P`` is the number of
    spatial patches per frame. Temporal self-attention is applied across
    frames at each spatial location, then spatial self-attention across
    patches within each frame, then a pre-norm MLP. All three sub-layers
    use residual connections (pre-norm formulation).

    Args:
        dim: token embedding dimension (must be divisible by ``num_heads``,
            enforced by ``nn.MultiheadAttention``).
        num_heads: number of attention heads for both attention layers.
        num_frames: temporal length ``T`` the block expects at forward time.
    """

    def __init__(self, dim, num_heads, num_frames):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn_time = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.attn_space = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm3 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.num_frames = num_frames

    def forward(self, x):
        """Apply temporal attention, spatial attention, and MLP.

        Args:
            x: tensor of shape ``(B, T * P, D)``.

        Returns:
            Tensor of the same shape ``(B, T * P, D)``.

        Raises:
            ValueError: if the sequence length is not divisible by
                ``num_frames`` (the original ``.view`` would otherwise fail
                with an opaque reshape error).
        """
        B, TP, D = x.shape
        T = self.num_frames
        if TP % T != 0:
            raise ValueError(
                f"sequence length {TP} is not divisible by num_frames {T}"
            )
        P = TP // T

        # --- Temporal attention: attend over T at each spatial location. ---
        # (B, T*P, D) -> (B*P, T, D) so each spatial patch forms a batch item.
        xt = x.view(B, T, P, D).permute(0, 2, 1, 3).reshape(B * P, T, D)
        xt_res = xt
        xt = self.norm1(xt)  # pre-norm
        xt, _ = self.attn_time(xt, xt, xt)
        xt = xt + xt_res
        # Restore original token layout (B, T*P, D).
        x = xt.view(B, P, T, D).permute(0, 2, 1, 3).reshape(B, TP, D)

        # --- Spatial attention: attend over P within each frame. ---
        # (B, T*P, D) -> (B*T, P, D) so each frame forms a batch item.
        xs = x.view(B, T, P, D).reshape(B * T, P, D)
        xs_res = xs
        xs = self.norm2(xs)  # pre-norm
        xs, _ = self.attn_space(xs, xs, xs)
        xs = xs + xs_res
        x = xs.view(B, T, P, D).reshape(B, TP, D)

        # --- Feed-forward with residual. ---
        x = x + self.mlp(self.norm3(x))
        return x
|
|
|
|
|
|
class FeatureFusionNetwork(nn.Module):
    """Two-stream video classifier fusing a 3D-CNN and a patch transformer.

    Stream 1: R3D-18 backbone (classifier replaced with identity) producing
    a 512-dim clip feature. Stream 2: ViT-style per-frame patch embedding
    followed by one TimeSformer block, mean-pooled over all tokens. The two
    features are concatenated and classified by a small MLP into 2 logits.

    Input:  ``x`` of shape ``(B, 3, T, H, W)``; the positional embedding is
            sized for T == 16 frames of 112x112 — larger inputs would
            overrun it (assumed enforced by the caller — TODO confirm).
    Output: logits of shape ``(B, 2)``.
    """

    def __init__(self):
        super(FeatureFusionNetwork, self).__init__()

        # CNN stream: R3D-18 with its final fc removed -> 512-dim feature.
        self.cnn = models.r3d_18(weights=None)
        self.cnn.fc = nn.Identity()

        # Transformer-stream hyperparameters.
        self.patch_size = 16
        self.embed_dim = 256
        self.img_size = 112
        self.num_patches = (self.img_size // self.patch_size) ** 2  # 49 per frame
        self.num_frames = 16

        # Non-overlapping conv patch embedding applied per frame.
        self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        # NOTE(review): the CLS token never reaches the transformer — the
        # original forward concatenated it, added its positional embedding,
        # then sliced it off again. The parameter is kept registered so
        # existing checkpoints still load, but the dead concat/slice has
        # been removed from forward() (numerically identical result).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # Slot 0 is the (unused) CLS position; slots 1.. are patch positions.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_frames * self.num_patches + 1, self.embed_dim))

        # Single divided space-time attention block.
        self.transformer_layer = TimeSformerBlock(self.embed_dim, num_heads=4, num_frames=self.num_frames)

        # Fusion head over concatenated CNN (512) + transformer (embed_dim) features.
        self.fusion_fc = nn.Sequential(
            nn.Linear(512 + self.embed_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 2)
        )

    def forward(self, x):
        """Fuse CNN and transformer features and classify.

        Args:
            x: video tensor of shape ``(B, 3, T, H, W)``.

        Returns:
            Logits of shape ``(B, 2)``.
        """
        # CNN stream: (B, 3, T, H, W) -> (B, 512).
        cnn_feat = self.cnn(x)

        # Transformer stream: embed each frame independently.
        b, c, t, h, w = x.shape
        frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        patches = self.patch_embed(frames).flatten(2).transpose(1, 2)  # (B*T, P, D)
        patches = patches.reshape(b, t * self.num_patches, self.embed_dim)

        # Patch tokens take positional slots 1..T*P+1 (slot 0 belongs to the
        # CLS token, which the original code concatenated and immediately
        # discarded before the transformer). This slice is numerically
        # identical to the original concat-then-slice, without the dead work.
        patch_tokens = patches + self.pos_embed[:, 1:patches.size(1) + 1, :]

        out_patches = self.transformer_layer(patch_tokens)
        trans_feat = out_patches.mean(dim=1)  # mean-pool over tokens -> (B, D)

        # Late fusion: concatenate both stream features and classify.
        combined = torch.cat((cnn_feat, trans_feat), dim=1)
        out = self.fusion_fc(combined)
        return out
|
|
|
|