Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -37,6 +37,8 @@ def inference_mask1(
|
|
| 37 |
support_imgs = []
|
| 38 |
for i in range(len(prompt)):
|
| 39 |
mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).cpu()
|
|
|
|
|
|
|
| 40 |
mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
|
| 41 |
support_masks.append(mask)
|
| 42 |
support_img = Image.fromarray(prompt[i]['background'][..., :3])
|
|
@@ -81,7 +83,8 @@ def inference_mask1(
|
|
| 81 |
nshot = support_masks.size(1)
|
| 82 |
pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
|
| 83 |
pred_mask = pred_mask.detach().cpu().numpy()[0]
|
| 84 |
-
output_img = query_img_np
|
|
|
|
| 85 |
output_img = (output_img * 255).astype(np.uint8)
|
| 86 |
return output_img
|
| 87 |
|
|
|
|
| 37 |
support_imgs = []
|
| 38 |
for i in range(len(prompt)):
|
| 39 |
mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).cpu()
|
| 40 |
+
if mask.sum() == 0:
|
| 41 |
+
break
|
| 42 |
mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
|
| 43 |
support_masks.append(mask)
|
| 44 |
support_img = Image.fromarray(prompt[i]['background'][..., :3])
|
|
|
|
| 83 |
nshot = support_masks.size(1)
|
| 84 |
pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
|
| 85 |
pred_mask = pred_mask.detach().cpu().numpy()[0]
|
| 86 |
+
output_img = query_img_np.copy()
|
| 87 |
+
output_img[pred_mask] = np.array([1, 0, 0])
|
| 88 |
output_img = (output_img * 255).astype(np.uint8)
|
| 89 |
return output_img
|
| 90 |
|