faisalishfaq2005 committed
Commit c0ddffe · 1 Parent(s): b8dafec

updated model.py

Files changed (1)
  1. model.py +17 -15
model.py CHANGED
@@ -1,7 +1,6 @@
 import torch
-import torchvision
 import torch.nn as nn
-import torch.optim as optim
+import torchvision
 import math
 
 class ImprovedEfficientBackbone(nn.Module):
@@ -11,8 +10,7 @@ class ImprovedEfficientBackbone(nn.Module):
         self.features = self.efficientnet.features
 
     def forward(self, x):
-        return self.features(x)
-
+        return self.features(x)
 
 class ImprovedPatchEmbedding(nn.Module):
     def __init__(self, in_channels=1280, embed_dim=384):
@@ -25,7 +23,7 @@ class ImprovedPatchEmbedding(nn.Module):
         Output: [B, 49, 384]
         """
         B, C, H, W = x.shape
-        x = x.flatten(2).transpose(1, 2)
+        x = x.flatten(2).transpose(1, 2)
         x = self.proj(x)
         return x
 
@@ -41,23 +39,24 @@ class ImprovedViTBlock(nn.Module):
             nn.GELU(),
             nn.Linear(embed_dim * mlp_ratio, embed_dim)
         )
-        self.dropout = nn.Dropout(0.1)
+        self.dropout = nn.Dropout(0.2)
 
     def forward(self, x):
         x = x + self.dropout(self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0])
         x = x + self.dropout(self.mlp(self.norm2(x)))
         return x
 
-
-
 class ImprovedEfficientViT(nn.Module):
-    def __init__(self, embed_dim=384, depth=8, num_heads=4):
+    def __init__(self, embed_dim=384, depth=6, num_heads=4):
         super().__init__()
         self.backbone = ImprovedEfficientBackbone()
         self.patch_embed = ImprovedPatchEmbedding(embed_dim=embed_dim)
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
-        self.register_buffer("pos_embed", self._get_sinusoidal_encoding(50, embed_dim))  # Use sin/cos
+        self.register_buffer("pos_embed", self._get_sinusoidal_encoding(50, embed_dim))
+
+        self.patch_dropout = nn.Dropout(0.2)
+        self.pos_dropout = nn.Dropout(0.2)
 
         self.blocks = nn.ModuleList([ImprovedViTBlock(embed_dim, num_heads) for _ in range(depth)])
 
@@ -65,6 +64,7 @@ class ImprovedEfficientViT(nn.Module):
             nn.LayerNorm(embed_dim),
             nn.Linear(embed_dim, 128),
             nn.GELU(),
+            nn.Dropout(0.3),
             nn.Linear(128, 1)
         )
 
@@ -82,16 +82,18 @@ class ImprovedEfficientViT(nn.Module):
         return pe.unsqueeze(0)
 
     def forward(self, x):
-        features = self.backbone(x)
-        tokens = self.patch_embed(features)
+        features = self.backbone(x)
+        tokens = self.patch_embed(features)
+        tokens = self.patch_dropout(tokens)
 
         B = tokens.shape[0]
-        cls_tokens = self.cls_token.expand(B, -1, -1)
+        cls_tokens = self.cls_token.expand(B, -1, -1)
         x = torch.cat((cls_tokens, tokens), dim=1)
         x = x + self.pos_embed[:, :x.size(1), :]
+        x = self.pos_dropout(x)
 
         for block in self.blocks:
             x = block(x)
 
-        cls_final = x[:, 0]
-        return self.head(cls_final)
+        cls_final = x[:, 0]
+        return self.head(cls_final)
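The new __init__ registers pos_embed from self._get_sinusoidal_encoding(50, embed_dim), but that helper's body lies outside the changed hunks; only its closing return pe.unsqueeze(0) appears as context. For reference, a minimal sketch of the fixed sin/cos table such a helper typically builds (an assumption, not necessarily this file's exact implementation):

import math
import torch

def get_sinusoidal_encoding(num_tokens: int, embed_dim: int) -> torch.Tensor:
    # Hypothetical stand-in for ImprovedEfficientViT._get_sinusoidal_encoding,
    # whose body is not shown in the diff.
    position = torch.arange(num_tokens, dtype=torch.float32).unsqueeze(1)       # [num_tokens, 1]
    div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / embed_dim))                    # [embed_dim / 2]
    pe = torch.zeros(num_tokens, embed_dim)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions: sine
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions: cosine
    return pe.unsqueeze(0)                         # [1, num_tokens, embed_dim]

A 50-position table covers the CLS token plus the 49 patch tokens, which is why forward can slice it with self.pos_embed[:, :x.size(1), :].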
 
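Not part of the commit, but the reworked forward path (with the added patch_dropout, pos_dropout, and Dropout(0.3) in the head) can be sanity-checked by pushing a dummy batch through and inspecting shapes. The 224x224 input and the 1280-channel, 7x7 backbone feature map are assumptions consistent with in_channels=1280 and the [B, 49, 384] docstring; the backbone constructor and the expected input size are not shown in the diff:

import torch
# from model import ImprovedEfficientViT   # assuming the updated file is importable as model.py

model = ImprovedEfficientViT(embed_dim=384, depth=6, num_heads=4)
model.eval()   # the Dropout(0.2)/Dropout(0.3) layers are active only in train() mode

with torch.no_grad():
    x = torch.randn(2, 3, 224, 224)          # assumed input: batch of 2 RGB images, 224x224
    features = model.backbone(x)             # expected [2, 1280, 7, 7] for an EfficientNet-style backbone
    tokens = model.patch_embed(features)     # expected [2, 49, 384], per the docstring
    logits = model(x)                        # expected [2, 1] from the Linear(128, 1) head

print(features.shape, tokens.shape, logits.shape)

Passing depth=6 is redundant with the new default; it is spelled out here only to mirror the updated signature.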