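# Gradio demo for few-shot segmentation with DCAMA: upload a query image and draw
# masks over one or more support images; the model predicts the matching region in
# the query image and returns it overlaid in red.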
import torch.nn as nn
import torch
from model.DCAMA import DCAMA
from common import utils
import cv2
import numpy as np
import os
import gradio as gr
import time
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
# ImageNet normalization statistics and the input resolution expected by the backbone.
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
img_size = 384
transformation = transforms.Compose([transforms.Resize(size=(img_size, img_size)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(img_mean, img_std)])
def inference_mask1(query_img, *prompt):
    query_img = Image.fromarray(query_img)
    org_qry_imsize = query_img.size
    query_img_np = np.asarray(query_img)
    query_img = transformation(query_img)

    # Build the support set from the ImageMask prompts: each prompt carries a
    # background image plus the mask layers drawn over it.
    support_masks = []
    support_imgs = []
    for i in range(len(prompt)):
        if prompt[i] is None or len(prompt[i]['layers']) == 0:
            # No more support prompts were provided.
            break
        # Collapse all drawn layers (and their channels) into a single binary mask.
        mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).cpu()
        mask[mask > 0] = 1
        if mask.sum() == 0:
            break
        # Resize the mask to the network input resolution.
        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
                             query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
        support_masks.append(mask)
        support_img = Image.fromarray(prompt[i]['background'][..., :3])
        support_img = transformation(support_img)
        support_imgs.append(support_img)
    # Build the DCAMA model (ResNet-50 backbone) and load the fine-tuned checkpoint.
    # The model is rebuilt on every request, which keeps the demo simple but adds latency.
    model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
    model.eval()
    model.cpu()
    params = model.state_dict()

    state_dict = torch.load('model_45.pt', map_location=torch.device('cpu'))
    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']
    # Discard any 'scorer' weights; they are not used by this model.
    state_dict = {k: v for k, v in state_dict.items() if 'scorer' not in k}
    # Remap checkpoint keys onto the model's parameter names by position.
    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
        state_dict[k2] = state_dict.pop(k1)
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Fill in any parameters missing from the checkpoint with the model's defaults.
        for k in params.keys():
            if k not in state_dict.keys():
                state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)
    # Assemble the batch expected by model.predict_mask_nshot: a batch dimension on
    # the query and (batch, nshot, ...) tensors for the supports.
    query_img = query_img.unsqueeze(0)
    support_imgs = torch.stack(support_imgs, dim=0).unsqueeze(0)
    support_masks = torch.stack(support_masks, dim=0).unsqueeze(0)
    print("query_img:", query_img.size())
    print("support_imgs:", support_imgs.size())
    print("support_masks:", support_masks.size())
    batch = {
        "support_masks": support_masks,
        "support_imgs": support_imgs,
        "query_img": query_img,
        "org_query_imsize": [torch.tensor([org_qry_imsize[0]]), torch.tensor([org_qry_imsize[1]])],
    }
    nshot = support_masks.size(1)
    pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)

    # Paint the predicted region red on a copy of the original query image.
    pred_mask = pred_mask.detach().cpu().numpy()[0]
    output_img = query_img_np.copy()
    output_img[pred_mask > 0] = np.array([255, 0, 0])
    output_img = output_img.astype(np.uint8)
    return output_img
# One query image plus up to 10 support image/mask prompts.
inputs = [gr.Image(label='query')]
for i in range(10):
    inputs.append(gr.ImageMask(label='support {}'.format(i)))

demo_mask = gr.Interface(fn=inference_mask1,
                         inputs=inputs,
                         outputs=[gr.Image(label="output")])
demo = gr.TabbedInterface([demo_mask], ['demo'])
demo.launch()
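# Note: demo.launch() with no arguments is sufficient on Hugging Face Spaces. When
# self-hosting (an assumption, not part of this Space), the bind address and port can
# be set explicitly, e.g. demo.launch(server_name="0.0.0.0", server_port=7860).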