import torch
import torch.nn.functional as F
import numpy as np
import gradio as gr
from torchvision import transforms
from PIL import Image

from model.DCAMA import DCAMA

# ImageNet normalization statistics used by the ResNet-50 backbone.
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
img_size = 384

transformation = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(img_mean, img_std),
])


def inference_mask1(query_img, *prompt):
    """Segment `query_img` using the annotated support images in `prompt`."""
    query_img = Image.fromarray(query_img)
    org_qry_imsize = query_img.size
    query_img_np = np.asarray(query_img)
    query_img = transformation(query_img)

    support_masks = []
    support_imgs = []
    for i in range(len(prompt)):
        # Unused support slots arrive as None (or with no drawn layers).
        if prompt[i] is None or not prompt[i]['layers']:
            break
        # Collapse all drawn editor layers into one binary mask: any painted
        # pixel in any layer/channel counts as foreground.
        mask = torch.from_numpy(
            np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)
        ).float()
        if mask.sum() == 0:
            break
        # Resize the mask to the transformed image resolution.
        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0),
                             size=query_img.size()[-2:],
                             mode='nearest').squeeze(0).squeeze(0)
        support_masks.append(mask)

        # The editor background holds the support image itself (drop alpha).
        support_img = Image.fromarray(prompt[i]['background'][..., :3])
        support_imgs.append(transformation(support_img))

    if len(support_masks) == 0:
        raise gr.Error("Draw a mask on at least one support image.")

    # NOTE: the model is rebuilt and the checkpoint reloaded on every request;
    # see the cached-loading sketch at the end of this file.
    model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
    model.eval()
    model.cpu()

    params = model.state_dict()
    state_dict = torch.load('model_45.pt', map_location=torch.device('cpu'))
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    # Drop checkpoint entries with no counterpart in this model.
    state_dict = {k: v for k, v in state_dict.items() if 'scorer' not in k}
    # Remap checkpoint keys onto this model's parameter names by position
    # (assumes both dicts enumerate parameters in the same order).
    state_dict = {k_new: v
                  for k_new, (_, v) in zip(params.keys(), state_dict.items())}
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Fall back to the model's own initialization for any missing keys.
        for k in params:
            if k not in state_dict:
                state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)

    query_img = query_img.unsqueeze(0)                              # [1, 3, H, W]
    support_img = torch.stack(support_imgs, dim=0).unsqueeze(0)     # [1, nshot, 3, H, W]
    support_masks = torch.stack(support_masks, dim=0).unsqueeze(0)  # [1, nshot, H, W]
    print("query_img:", query_img.size())
    print("support_img:", support_img.size())
    print("support_masks:", support_masks.size())

    batch = {
        "support_masks": support_masks,
        "support_imgs": support_img,
        "query_img": query_img,
        # Original (width, height), used to upsample the prediction back to
        # the input resolution.
        "org_query_imsize": [torch.tensor([org_qry_imsize[0]]),
                             torch.tensor([org_qry_imsize[1]])],
    }
    nshot = support_masks.size(1)
    pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)

    # Paint the predicted foreground solid red on the original query image.
    pred_mask = pred_mask.detach().cpu().numpy()[0]
    output_img = query_img_np.copy()
    output_img[pred_mask > 0] = np.array([255, 0, 0])
    return output_img.astype(np.uint8)


inputs = [gr.Image(label='query')]
for i in range(10):
    inputs.append(gr.ImageMask(label='support {}'.format(i)))

demo_mask = gr.Interface(fn=inference_mask1,
                         inputs=inputs,
                         outputs=[gr.Image(label="output")])
demo = gr.TabbedInterface([demo_mask], ['demo'])
demo.launch()
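
# ---------------------------------------------------------------------------
# Sketch (an assumption, not part of the original demo): building DCAMA and
# reloading 'model_45.pt' inside inference_mask1 repeats the same expensive
# work on every request. A small module-level cache like `get_model` below
# could do that work once; inference_mask1 would call get_model() instead of
# constructing DCAMA itself, with the checkpoint load/remap logic moved in
# here unchanged. Shown after launch() only for readability; in practice it
# would sit near the top of the file.
_MODEL = None

def get_model():
    """Build the DCAMA model on first call and return the cached instance."""
    global _MODEL
    if _MODEL is None:
        m = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
        m.eval()
        m.cpu()
        _MODEL = m
    return _MODEL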