Spaces:
Runtime error
Runtime error
File size: 2,482 Bytes
2d57b43 15865f6 2d57b43 65d5ae1 3b661d3 65d5ae1 3b661d3 65d5ae1 2d57b43 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
import torchvision
from torchvision import transforms
from torch import nn
from facenet_pytorch import InceptionResnetV1
def create_vggface2_model(num_classes: int = 2,
                          seed: int = 42):
    """Create a frozen InceptionResnetV1 (VGGFace2) backbone with a new classifier head.

    The VGGFace2-pretrained InceptionResnetV1 from ``facenet_pytorch`` is
    truncated before its last five children and every remaining parameter is
    frozen. A freshly initialised, trainable classifier head is then attached
    as the ``classifier`` submodule.

    Args:
        num_classes (int, optional): number of output units of the classifier
            head. Defaults to 2.
        seed (int, optional): value passed to ``torch.manual_seed`` right
            before the classifier head is built, so the head's random
            initialisation is reproducible. Note that this reseeds torch's
            global RNG. Defaults to 42.

    Returns:
        model (torch.nn.Module): frozen VGGFace2 feature extractor followed by
            the trainable classifier head.
        transforms (torchvision.transforms.Compose): image transforms
            (resize to 160x160, random horizontal flip, convert to tensor).
    """
    # Load the VGGFace2-pretrained backbone. Its final layers (including the
    # logits layer sized by num_classes) are discarded by the slice below.
    backbone = InceptionResnetV1(pretrained='vggface2',
                                 classify=True,
                                 num_classes=num_classes)

    # Keep everything except the last five children. The head built below
    # re-creates a pool -> dropout -> flatten -> linear -> batch-norm stack
    # (presumably mirroring the dropped layers; confirm against facenet_pytorch)
    # and adds a task-specific output layer.
    model_pred = nn.Sequential(*list(backbone.children())[:-5])

    # Freeze the pretrained feature extractor. Only the head, which is created
    # after this loop, will have trainable parameters.
    for param in model_pred.parameters():
        param.requires_grad = False

    # Seed right before building the head so its initialisation is reproducible.
    torch.manual_seed(seed)

    # Assigning a Module attribute on nn.Sequential registers it as the last
    # submodule, so forward() runs the backbone and then this head.
    # torch.nn.Flatten replaces a function-local Flatten class, which could not
    # be pickled by torch.save(model). It has no parameters, so the state_dict
    # keys (classifier.0 ... classifier.5) are the same as before.
    model_pred.classifier = torch.nn.Sequential(
        torch.nn.AdaptiveAvgPool2d(output_size=1),
        torch.nn.Dropout(p=0.6, inplace=False),
        torch.nn.Flatten(),
        # 1792 must match the channel count of the truncated backbone's output.
        torch.nn.Linear(in_features=1792,
                        out_features=512,
                        bias=False),
        torch.nn.BatchNorm1d(512,
                             eps=0.001,
                             momentum=0.1,
                             affine=True,
                             track_running_stats=True),
        # One output unit per target class.
        torch.nn.Linear(in_features=512,
                        out_features=num_classes,
                        bias=True),
    )

    # Image transforms. 160x160 is the input size recommended for FaceNet.
    # NOTE(review): RandomHorizontalFlip makes this transform stochastic. If it
    # is also used at inference time, predictions for the same image may vary.
    data_transform = transforms.Compose([
        transforms.Resize(size=(160, 160)),
        transforms.RandomHorizontalFlip(p=0.5),  # 50% chance of a horizontal flip
        transforms.ToTensor(),  # also scales 0-255 pixel values to 0.0-1.0
    ])

    return model_pred, data_transform
|