CUHKWilliam commited on
Commit
bdf0ffa
·
verified ·
1 Parent(s): 79d8b10

Update model/DCAMA.py

Browse files
Files changed (1) hide show
  1. model/DCAMA.py +3 -3
model/DCAMA.py CHANGED
@@ -146,7 +146,7 @@ class DCAMA(nn.Module):
146
  # n_support_feats = [n_support_feats[arg] for arg in args]
147
  n_simis = n_simi[args].max()
148
  else:
149
- n_simis = torch.tensor(0.).float().cuda()
150
  return logit_mask, n_simis
151
 
152
  def extract_feats(self, img):
@@ -453,11 +453,11 @@ class DCAMA(nn.Module):
453
  n_support_feats4 = []
454
  for idx in selected:
455
  n_support_feats4.append(n_support_feats[idx])
456
- support_masks = support_masks[:, torch.tensor(selected).long().cuda(), :, :]
457
  n_support_feats = n_support_feats4
458
  simi_map = None
459
  else:
460
- n_simis = torch.tensor(0.).float().cuda()
461
  simi_map = None
462
 
463
  logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
 
146
  # n_support_feats = [n_support_feats[arg] for arg in args]
147
  n_simis = n_simi[args].max()
148
  else:
149
+ n_simis = torch.tensor(0.).float()
150
  return logit_mask, n_simis
151
 
152
  def extract_feats(self, img):
 
453
  n_support_feats4 = []
454
  for idx in selected:
455
  n_support_feats4.append(n_support_feats[idx])
456
+ support_masks = support_masks[:, torch.tensor(selected).long(), :, :]
457
  n_support_feats = n_support_feats4
458
  simi_map = None
459
  else:
460
+ n_simis = torch.tensor(0.).float()
461
  simi_map = None
462
 
463
  logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)