CUHKWilliam commited on
Commit
7ec5dd7
·
verified ·
1 Parent(s): 9332003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -13,7 +13,6 @@ from torchvision import transforms
13
  from PIL import Image
14
  import torch.nn.functional as F
15
  import ipdb
16
- import ipdb
17
 
18
  img_mean = [0.485, 0.456, 0.406]
19
  img_std = [0.229, 0.224, 0.225]
@@ -86,7 +85,7 @@ def inference_mask1(
86
  pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
87
  pred_mask = pred_mask.detach().cpu().numpy()[0]
88
  output_img = query_img_np.copy()
89
- output_img[pred_mask] = np.array([255, 0, 0])
90
  output_img = (output_img).astype(np.uint8)
91
  return output_img
92
 
 
13
  from PIL import Image
14
  import torch.nn.functional as F
15
  import ipdb
 
16
 
17
  img_mean = [0.485, 0.456, 0.406]
18
  img_std = [0.229, 0.224, 0.225]
 
85
  pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
86
  pred_mask = pred_mask.detach().cpu().numpy()[0]
87
  output_img = query_img_np.copy()
88
+ output_img[pred_mask > 0] = np.array([255, 0, 0])
89
  output_img = (output_img).astype(np.uint8)
90
  return output_img
91