nraffa commited on
Commit
65d5ae1
·
1 Parent(s): 15865f6

flatten class added

Browse files
Files changed (1) hide show
  1. model.py +8 -0
model.py CHANGED
@@ -23,6 +23,14 @@ def create_vggface2_model(num_classes:int=2,
23
  layer_list = list(model_pred.children())[-5:] # all final layers
24
  model_pred = nn.Sequential(*list(model_pred.children())[:-5])
25
 
 
 
 
 
 
 
 
 
26
  for param in model_pred.parameters():
27
  param.requires_grad = False
28
 
 
23
  layer_list = list(model_pred.children())[-5:] # all final layers
24
  model_pred = nn.Sequential(*list(model_pred.children())[:-5])
25
 
26
+ class Flatten(nn.Module):
27
+ def __init__(self):
28
+ super(Flatten, self).__init__()
29
+
30
+ def forward(self, x):
31
+ x = x.view(x.size(0), -1)
32
+ return x
33
+
34
  for param in model_pred.parameters():
35
  param.requires_grad = False
36