CUHKWilliam commited on
Commit
8fe3c03
·
verified ·
1 Parent(s): e803103

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -38,6 +38,8 @@ def inference_mask1(
38
  for i in range(len(prompt)):
39
  mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).cpu()
40
  mask[mask > 0] = 1
 
 
41
  if mask.sum() == 0:
42
  break
43
  mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
@@ -45,7 +47,6 @@ def inference_mask1(
45
  support_img = Image.fromarray(prompt[i]['background'][..., :3])
46
  support_img = transformation(support_img)
47
  support_imgs.append(support_img)
48
- return (support_imgs[0].detach().cpu().numpy() * 255).astype(np.uint8)
49
  model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
50
  model.eval()
51
  model.cpu()
 
38
  for i in range(len(prompt)):
39
  mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).cpu()
40
  mask[mask > 0] = 1
41
+ return (mask.detach().cpu().numpy() * 255).astype(np.uint8)
42
+
43
  if mask.sum() == 0:
44
  break
45
  mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
 
47
  support_img = Image.fromarray(prompt[i]['background'][..., :3])
48
  support_img = transformation(support_img)
49
  support_imgs.append(support_img)
 
50
  model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
51
  model.eval()
52
  model.cpu()