Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision | |
| from torchvision import transforms | |
| from torch import nn | |
| from facenet_pytorch import InceptionResnetV1 | |
def _build_classifier_head(num_classes: int = 2,
                           in_features: int = 1792,
                           embedding_dim: int = 512) -> nn.Sequential:
    """Build the trainable classification head attached to the frozen backbone.

    Recreates the final layers that ``create_vggface2_model`` slices off the
    pretrained backbone (pool -> dropout -> flatten -> linear -> batchnorm ->
    logits), with freshly initialised weights and ``num_classes`` outputs.

    Args:
        num_classes: number of output logits.
        in_features: channel count of the backbone feature map fed to the head
            (1792 for the truncated InceptionResnetV1 trunk used here).
        embedding_dim: width of the intermediate embedding layer.

    Returns:
        nn.Sequential mapping an ``(N, in_features, H, W)`` feature map to
        ``(N, num_classes)`` logits.
    """
    return nn.Sequential(
        nn.AdaptiveAvgPool2d(output_size=1),
        nn.Dropout(p=0.6, inplace=False),
        # (N, C, 1, 1) -> (N, C). Built-in module rather than a locally defined
        # class, so the assembled model can be pickled with torch.save(model).
        nn.Flatten(),
        nn.Linear(in_features=in_features,
                  out_features=embedding_dim,
                  bias=False),
        nn.BatchNorm1d(embedding_dim,
                       eps=0.001,
                       momentum=0.1,
                       affine=True,
                       track_running_stats=True),
        nn.Linear(in_features=embedding_dim,
                  out_features=num_classes,  # one output unit per class
                  bias=True),
    )


def create_vggface2_model(num_classes: int = 2,
                          seed: int = 42):
    """Creates an InceptionResnetV1 (VGGFace2-pretrained) classifier and transforms.

    The pretrained backbone is truncated before its final five layers and
    frozen. A freshly initialised, trainable classification head is attached
    as the ``classifier`` submodule.

    Args:
        num_classes (int, optional): number of classes in the classifier head.
            Defaults to 2.
        seed (int, optional): random seed passed to ``torch.manual_seed``
            before the classifier head is initialised, so its initial weights
            are reproducible. Defaults to 42.

    Returns:
        model (torch.nn.Module): frozen vggface2 feature extractor followed by
            a trainable classifier head.
        transforms (torchvision.transforms.Compose): vggface2 image transforms.
    """
    # Load the pretrained backbone; its own final layers are discarded below.
    backbone = InceptionResnetV1(pretrained='vggface2',
                                 classify=True,
                                 num_classes=num_classes)
    # Keep everything except the last five child modules (the original head).
    model_pred = nn.Sequential(*list(backbone.children())[:-5])

    # Freeze the pretrained layers so only the new head is trained.
    for param in model_pred.parameters():
        param.requires_grad = False

    # Seed right before building the head so its initial weights are
    # reproducible. Note: this reseeds torch's global RNG.
    torch.manual_seed(seed)
    # Assigning a Module attribute on an nn.Sequential registers it as the
    # last entry of its submodules, so the head also runs at the end of
    # forward(). It is created after the freeze loop, so it stays trainable.
    model_pred.classifier = _build_classifier_head(num_classes=num_classes)

    data_transform = transforms.Compose([
        # Resize to 160x160, per the FaceNet training recommendation.
        transforms.Resize(size=(160, 160)),
        # Flip horizontally with 50% probability.
        # NOTE(review): this is a training-time augmentation; if this same
        # transform is used for prediction, outputs are non-deterministic.
        transforms.RandomHorizontalFlip(p=0.5),
        # Convert to a torch.Tensor; pixel values go from [0, 255] to [0.0, 1.0].
        transforms.ToTensor(),
    ])
    return model_pred, data_transform