Spaces:
Sleeping
Sleeping
Update model/DCAMA.py
Browse files- model/DCAMA.py +3 -3
model/DCAMA.py
CHANGED
|
@@ -146,7 +146,7 @@ class DCAMA(nn.Module):
|
|
| 146 |
# n_support_feats = [n_support_feats[arg] for arg in args]
|
| 147 |
n_simis = n_simi[args].max()
|
| 148 |
else:
|
| 149 |
-
n_simis = torch.tensor(0.).float()
|
| 150 |
return logit_mask, n_simis
|
| 151 |
|
| 152 |
def extract_feats(self, img):
|
|
@@ -453,11 +453,11 @@ class DCAMA(nn.Module):
|
|
| 453 |
n_support_feats4 = []
|
| 454 |
for idx in selected:
|
| 455 |
n_support_feats4.append(n_support_feats[idx])
|
| 456 |
-
support_masks = support_masks[:, torch.tensor(selected).long()
|
| 457 |
n_support_feats = n_support_feats4
|
| 458 |
simi_map = None
|
| 459 |
else:
|
| 460 |
-
n_simis = torch.tensor(0.).float()
|
| 461 |
simi_map = None
|
| 462 |
|
| 463 |
logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
|
|
|
|
| 146 |
# n_support_feats = [n_support_feats[arg] for arg in args]
|
| 147 |
n_simis = n_simi[args].max()
|
| 148 |
else:
|
| 149 |
+
n_simis = torch.tensor(0.).float()
|
| 150 |
return logit_mask, n_simis
|
| 151 |
|
| 152 |
def extract_feats(self, img):
|
|
|
|
| 453 |
n_support_feats4 = []
|
| 454 |
for idx in selected:
|
| 455 |
n_support_feats4.append(n_support_feats[idx])
|
| 456 |
+
support_masks = support_masks[:, torch.tensor(selected).long(), :, :]
|
| 457 |
n_support_feats = n_support_feats4
|
| 458 |
simi_map = None
|
| 459 |
else:
|
| 460 |
+
n_simis = torch.tensor(0.).float()
|
| 461 |
simi_map = None
|
| 462 |
|
| 463 |
logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
|