# Multiple-shots / app.py
# Hugging Face Space by CUHKWilliam — "Update app.py", commit 9daa1fa (verified), 3.35 kB
import torch.nn as nn
import torch
from model.DCAMA import DCAMA
from common import utils
import cv2
import numpy as np
import os
import gradio as gr
import time
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import ipdb
# ImageNet channel statistics used to normalize inputs for the ResNet backbone.
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
# Side length of the square image fed to the network.
img_size = 384

# Preprocessing applied to both query and support images:
# resize to the network resolution, convert to tensor, normalize.
_pipeline_steps = [
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(img_mean, img_std),
]
transformation = transforms.Compose(_pipeline_steps)
def _load_model():
    """Build the DCAMA model once, load checkpoint weights, and cache it.

    The original code rebuilt the model and re-read 'model_45.pt' from disk on
    every inference request; the loaded model is now cached on the function so
    the expensive setup happens only on the first call.

    The checkpoint's parameter names may not match the model's (e.g. it may
    contain an extra 'scorer' head), so 'scorer' keys are dropped and the
    remaining keys are remapped onto the model's parameter names positionally.
    """
    cached = getattr(_load_model, "_cached", None)
    if cached is not None:
        return cached
    model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
    model.eval()
    model.cpu()
    params = model.state_dict()
    state_dict = torch.load('model_45.pt', map_location=torch.device('cpu'))
    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']
    # Drop the 'scorer' head, which the inference-time model does not have.
    state_dict = {k: v for k, v in state_dict.items() if 'scorer' not in k}
    # Remap checkpoint keys onto the model's parameter names by position.
    for ckpt_key, model_key in zip(list(state_dict.keys()), params.keys()):
        state_dict[model_key] = state_dict.pop(ckpt_key)
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # strict=True raises RuntimeError on missing keys; fill any missing
        # entries from the freshly initialized model and retry.
        for k in params.keys():
            if k not in state_dict.keys():
                state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)
    _load_model._cached = model
    return model


def inference_mask1(
    query_img,
    *prompt,
):
    """Segment `query_img` few-shot style, using the user-drawn support masks.

    Args:
        query_img: HxWx3 uint8 numpy array from the gradio 'query' image input.
        *prompt: gradio ImageMask dicts, each with a 'background' image and a
            list of drawn 'layers'. Slots are consumed in order; the first slot
            whose drawn mask is empty ends the support set.

    Returns:
        HxWx3 uint8 numpy array: the query image with predicted foreground
        pixels painted red.

    Raises:
        ValueError: if no prompt slot contains a non-empty mask.
    """
    query_img = Image.fromarray(query_img)
    org_qry_imsize = query_img.size  # (width, height) per PIL convention
    query_img_np = np.asarray(query_img)
    query_img = transformation(query_img)
    support_masks = []
    support_imgs = []
    for slot in prompt:
        # Merge all drawn layers (and their channels) into one boolean mask.
        mask = torch.from_numpy(np.stack(slot['layers'], axis=0).any(0).any(-1)).cpu()
        if mask.sum() == 0:
            # First empty slot terminates the support set.
            break
        # Nearest-neighbor resize to the network input resolution.
        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
                             query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
        support_masks.append(mask)
        # [..., :3] drops a possible alpha channel from the gradio image.
        support_img = Image.fromarray(slot['background'][..., :3])
        support_imgs.append(transformation(support_img))
    if not support_imgs:
        # torch.stack([]) would raise an opaque error; fail with a clear message.
        raise ValueError('at least one support image with a non-empty mask is required')
    model = _load_model()
    query_img = query_img.unsqueeze(0)
    support_img = torch.stack(support_imgs, dim=0).unsqueeze(0)
    support_masks = torch.stack(support_masks, dim=0).unsqueeze(0)
    print("query_img:", query_img.size())
    print("support_img:", support_img.size())
    print("support_masks:", support_masks.size())
    batch = {
        "support_masks": support_masks,
        "support_imgs": support_img,
        "query_img": query_img,
        "org_query_imsize": [torch.tensor([org_qry_imsize[0]]), torch.tensor([org_qry_imsize[1]])],
    }
    nshot = support_masks.size(1)
    pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
    pred_mask = pred_mask.detach().cpu().numpy()[0]
    # Paint predicted foreground red on a copy of the original query image.
    # NOTE(review): assumes pred_mask is already at the original query
    # resolution — confirm against predict_mask_nshot.
    output_img = query_img_np.copy()
    output_img[pred_mask > 0] = np.array([255, 0, 0])
    return output_img.astype(np.uint8)
# One query image followed by ten optional sketch-mask support slots.
inputs = [gr.Image(label='query')]
inputs += [gr.ImageMask(label='support {}'.format(i)) for i in range(10)]

demo_mask = gr.Interface(
    fn=inference_mask1,
    inputs=inputs,
    outputs=[gr.Image(label="output")],
)
demo = gr.TabbedInterface([demo_mask], ['demo'])
demo.launch()