CUHKWilliam commited on
Commit
d3fd8e8
·
verified ·
1 Parent(s): e38c273

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -37,6 +37,8 @@ def inference_mask1(
37
  support_imgs = []
38
  for i in range(len(prompt)):
39
  mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).cpu()
 
 
40
  mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
41
  support_masks.append(mask)
42
  support_img = Image.fromarray(prompt[i]['background'][..., :3])
@@ -81,7 +83,8 @@ def inference_mask1(
81
  nshot = support_masks.size(1)
82
  pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
83
  pred_mask = pred_mask.detach().cpu().numpy()[0]
84
- output_img = query_img_np * 0.5 + 0.5 * np.array([[[1, 0, 0]]]) * np.expand_dims(pred_mask, axis=-1)
 
85
  output_img = (output_img * 255).astype(np.uint8)
86
  return output_img
87
 
 
37
  support_imgs = []
38
  for i in range(len(prompt)):
39
  mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).cpu()
40
+ if mask.sum() == 0:
41
+ break
42
  mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
43
  support_masks.append(mask)
44
  support_img = Image.fromarray(prompt[i]['background'][..., :3])
 
83
  nshot = support_masks.size(1)
84
  pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
85
  pred_mask = pred_mask.detach().cpu().numpy()[0]
86
+ output_img = query_img_np.copy()
87
+ output_img[pred_mask] = np.array([1, 0, 0])
88
  output_img = (output_img * 255).astype(np.uint8)
89
  return output_img
90