# Multiple-shots / app.py: Gradio demo for few-shot segmentation with DCAMA
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from torchvision import transforms

from model.DCAMA import DCAMA
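
# ImageNet normalization statistics and the model's 384x384 input size.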
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
img_size = 384
transformation = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(img_mean, img_std),
])


def inference_mask1(query_img, *prompt):
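    """Segment the query image using up to ten annotated support examples.

    Each element of `prompt` is a Gradio ImageMask value: a dict holding a
    'background' image and the drawn 'layers', which together define one
    support image/mask pair.
    """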
    query_img = Image.fromarray(query_img)
    org_qry_imsize = query_img.size  # (width, height) of the original query image
    query_img_np = np.asarray(query_img)  # kept for drawing the output overlay
    query_img = transformation(query_img)
    support_masks = []
    support_imgs = []
    for i in range(len(prompt)):
        # Stop at the first empty support slot; later slots are ignored.
        if prompt[i] is None or not prompt[i].get('layers'):
            break
        # Collapse all drawn layers (and their channels) into one binary mask.
        mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)).float()
        if mask.sum() == 0:
            break
        # Resize the mask to the model's input resolution (nearest keeps it binary).
        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), query_img.size()[-2:],
                             mode='nearest').squeeze(0).squeeze(0)
        support_masks.append(mask)
        # Drop any alpha channel before transforming the support image.
        support_img = Image.fromarray(prompt[i]['background'][..., :3])
        support_imgs.append(transformation(support_img))
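
    # torch.stack below would fail on an empty list, so surface a
    # user-visible error when no support mask was drawn.
    if len(support_masks) == 0:
        raise gr.Error("Draw a mask on at least one support image.")
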
    # Build DCAMA with a ResNet-50 backbone and run inference on CPU.
    # (The model is reloaded on every request; loading it once at module
    # level would be faster.)
    model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
    model.eval()
    model.cpu()
    params = model.state_dict()
    state_dict = torch.load('model_45.pt', map_location=torch.device('cpu'))
    # Some checkpoints wrap the weights in a 'state_dict' entry.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    # Drop auxiliary 'scorer' weights that have no counterpart in this model.
    state_dict = {k: v for k, v in state_dict.items() if 'scorer' not in k}
    # Checkpoint keys may not match the model's parameter names exactly,
    # so remap them positionally onto the model's own key order.
    for k1, k2 in zip(list(state_dict.keys()), list(params.keys())):
        state_dict[k2] = state_dict.pop(k1)
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Fill parameters missing from the checkpoint with the model's
        # freshly initialized values, then retry a strict load.
        for k in params.keys():
            if k not in state_dict:
                state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)
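
    # Add batch (and shot) dimensions: the query becomes [1, 3, H, W]; the
    # stacked supports become [1, nshot, 3, H, W] and [1, nshot, H, W].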
    query_img = query_img.unsqueeze(0)
    support_img = torch.stack(support_imgs, dim=0).unsqueeze(0)
    support_masks = torch.stack(support_masks, dim=0).unsqueeze(0)
    print("query_img:", query_img.size())
    print("support_img:", support_img.size())
    print("support_masks:", support_masks.size())
    batch = {
        "support_masks": support_masks,
        "support_imgs": support_img,
        "query_img": query_img,
        "org_query_imsize": [torch.tensor([org_qry_imsize[0]]), torch.tensor([org_qry_imsize[1]])],
    }
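    # Run n-shot inference with however many usable supports were provided.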
    nshot = support_masks.size(1)
    pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
    pred_mask = pred_mask.detach().cpu().numpy()[0]
    # Paint predicted foreground pixels solid red on a copy of the query image.
    output_img = query_img_np.copy()
    output_img[pred_mask > 0] = np.array([255, 0, 0])
    output_img = output_img.astype(np.uint8)
    return output_img
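

# Build the UI: one query image plus ten optional support image/mask slots.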
inputs = [gr.Image(label='query')]
for i in range(10):
    inputs.append(gr.ImageMask(label=f'support {i}'))
demo_mask = gr.Interface(
    fn=inference_mask1,
    inputs=inputs,
    outputs=[gr.Image(label="output")],
)
demo = gr.TabbedInterface([demo_mask], ['demo'])
demo.launch()