# liveness_detection/model.py
import torch
from torch import nn
from torchvision import transforms
from facenet_pytorch import InceptionResnetV1


def create_vggface2_model(num_classes: int = 2,
                          seed: int = 42):
    """Creates an InceptionResnetV1 (VGGFace2) model and its image transforms.

    Args:
        num_classes (int, optional): number of classes in the classifier head.
            Defaults to 2.
        seed (int, optional): random seed value. Defaults to 42.

    Returns:
        model (torch.nn.Module): vggface2 feature extractor model.
        transforms (torchvision.transforms): vggface2 image transforms.
    """
    # Load the InceptionResnetV1 pretrained on VGGFace2
    model_pred = InceptionResnetV1(pretrained='vggface2', classify=True, num_classes=num_classes)
    # Drop the final five layers (pooling, dropout, linear, batch norm, logits);
    # an equivalent classifier head is rebuilt below
    model_pred = nn.Sequential(*list(model_pred.children())[:-5])
    class Flatten(nn.Module):
        """Flattens all dimensions except the batch dimension."""
        def forward(self, x):
            return x.view(x.size(0), -1)
    # Freeze the pretrained backbone so only the new classifier head trains
    for param in model_pred.parameters():
        param.requires_grad = False
    # Recreate the classifier head, seeding its initialisation for reproducibility.
    # Assigning a submodule to an nn.Sequential registers it after the existing
    # modules, so it runs last in the forward pass.
    torch.manual_seed(seed)
    model_pred.classifier = torch.nn.Sequential(
        torch.nn.AdaptiveAvgPool2d(output_size=1),
        torch.nn.Dropout(p=0.6, inplace=False),
        Flatten(),
        torch.nn.Linear(in_features=1792,
                        out_features=512,
                        bias=False),
        torch.nn.BatchNorm1d(512,
                             eps=0.001,
                             momentum=0.1,
                             affine=True,
                             track_running_stats=True),
        torch.nn.Linear(in_features=512,
                        out_features=num_classes,  # one output unit per class
                        bias=True))
    # Transform pipeline for input images
    data_transform = transforms.Compose([
        # Resize to 160x160, the input size FaceNet/InceptionResnetV1 was trained with
        transforms.Resize(size=(160, 160)),
        # Flip images randomly on the horizontal axis (p = probability of a flip)
        transforms.RandomHorizontalFlip(p=0.5),
        # Convert to torch.Tensor and scale pixel values from [0, 255] to [0.0, 1.0]
        transforms.ToTensor()
    ])
    return model_pred, data_transform
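

# A minimal smoke-test sketch: builds the model and runs a random tensor through
# it to check the output shape. The variable names below are illustrative, and
# the random input merely stands in for a real, transformed face crop.
if __name__ == "__main__":
    model, transform = create_vggface2_model(num_classes=2, seed=42)
    model.eval()
    dummy = torch.rand(1, 3, 160, 160)  # batch of one 160x160 RGB image
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 2])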