"""Gradio demo for few-shot segmentation with a DCAMA model.

The user supplies one query image and up to ``NUM_PROMPTS`` prompt images,
each with a hand-drawn mask (``gr.ImageMask``). The masked prompts form the
support set, and the model's predicted mask for the query image is shown.
"""
import functools
import os
import time

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn

from common import utils
from common.config import parse_opts
from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from common.vis import Visualizer
from data.dataset import FSSDataset
from model.DCAMA import DCAMA

# Values that were hard-coded inline in the original script.
BACKBONE = 'resnet50'
BACKBONE_WEIGHTS = 'resnet50_a1h-35c100f8.pth'
CHECKPOINT_PATH = 'model_45.pt'
NUM_PROMPTS = 5

# NOTE(review): the input resolution and ImageNet normalisation below are
# assumed to match the training pipeline (FSSDataset) -- confirm.
IMG_SIZE = 384
IMG_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMG_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the DCAMA model on CPU, load ``CHECKPOINT_PATH``, and cache it.

    Checkpoint entries whose key contains ``'scorer'`` are dropped. The
    remaining tensors are assigned to the model's state_dict entries *by
    position* (checkpoint key names are ignored). Model entries not covered
    by the checkpoint keep their freshly initialised values.

    Raises:
        RuntimeError: If the filtered checkpoint has more tensors than the
            model has state_dict entries, or if ``load_state_dict`` rejects
            the mapped tensors (e.g. shape mismatch).
    """
    # Third argument passed through unchanged from the original; presumably
    # DCAMA's `use_original_imgsize` flag -- TODO confirm.
    model = DCAMA(BACKBONE, BACKBONE_WEIGHTS, True)
    model.eval()
    model.cpu()

    # map_location lets a checkpoint saved from GPU load on a CPU-only host.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    weights = [value for key, value in checkpoint.items() if 'scorer' not in key]

    # The original code referenced an undefined `params`; the model's own
    # state_dict is what the renaming logic needs.
    params = model.state_dict()
    if len(weights) > len(params):
        raise RuntimeError(
            f'checkpoint has {len(weights)} tensors but the model only has '
            f'{len(params)} state_dict entries'
        )
    # Positional renaming assumes the checkpoint stores tensors in the same
    # order as the model (e.g. only a key prefix differs) -- TODO confirm.
    # A fresh dict avoids clobbering entries whose names collide mid-rename.
    renamed = dict(zip(params.keys(), weights))
    # Fill anything the checkpoint did not cover so a strict load succeeds.
    for key, value in params.items():
        renamed.setdefault(key, value)
    model.load_state_dict(renamed, strict=True)
    return model


def _to_rgb(img):
    """Return *img* as a contiguous HxWx3 array (grayscale expanded, alpha dropped)."""
    img = np.asarray(img)
    if img.ndim == 2:
        img = np.repeat(img[..., None], 3, axis=-1)
    # Slicing off alpha yields a non-contiguous view, which cv2 may reject.
    return np.ascontiguousarray(img[..., :3])


def _image_to_tensor(img):
    """Resize an image to IMG_SIZE x IMG_SIZE and return a normalised (3, S, S) float tensor."""
    rgb = _to_rgb(img)
    resized = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    normed = (resized.astype(np.float32) / 255.0 - IMG_MEAN) / IMG_STD
    return torch.from_numpy(normed).permute(2, 0, 1).contiguous()


def _mask_to_tensor(mask):
    """Resize a boolean HxW mask to IMG_SIZE x IMG_SIZE and return it as a float {0, 1} tensor."""
    resized = cv2.resize(
        np.ascontiguousarray(mask, dtype=np.uint8),
        (IMG_SIZE, IMG_SIZE),
        interpolation=cv2.INTER_NEAREST,  # nearest keeps the mask binary
    )
    return torch.from_numpy(resized.astype(np.float32))


def _parse_prompt(prompt):
    """Return ``(image, mask)`` from one ImageMask value, or None if the slot is unusable.

    The mask marks every pixel where any drawn layer has any non-zero channel.
    A slot is skipped when it is None, has no background image, or has no
    drawn layers.
    """
    if prompt is None:
        return None
    background = prompt.get('background')
    layers = prompt.get('layers') or []
    if background is None or len(layers) == 0:
        return None
    mask = np.stack(layers, axis=0).any(0).any(-1)
    return background, mask


def inference_mask1(
    query_img,
    *prompt,
):
    """Segment *query_img* using the masked prompt images as the support set.

    Args:
        query_img: Query image as delivered by ``gr.Image`` (a numpy array),
            or None when nothing was uploaded.
        *prompt: One value per ``gr.ImageMask`` input. Each is a dict with
            ``'background'`` (the image) and ``'layers'`` (the drawn
            strokes), or None for an empty slot. Unusable slots are skipped.

    Returns:
        A uint8 mask (0 or 255 per pixel) for display in the output image.

    Raises:
        gr.Error: If there is no query image or no usable prompt.
    """
    if query_img is None:
        raise gr.Error('Please provide a query image.')
    supports = [s for s in map(_parse_prompt, prompt) if s is not None]
    if not supports:
        raise gr.Error('Please provide at least one prompt image with a drawn mask.')

    query_rgb = _to_rgb(query_img)
    height, width = query_rgb.shape[:2]
    # NOTE(review): keys and shapes follow the batches FSSDataset appears to
    # produce -- query (1, 3, S, S), supports (1, nshot, 3, S, S), masks
    # (1, nshot, S, S), original query size as (W, H) -- confirm against
    # DCAMA.predict_mask_nshot.
    batch = {
        'query_img': _image_to_tensor(query_rgb).unsqueeze(0),
        'support_imgs': torch.stack(
            [_image_to_tensor(img) for img, _ in supports]
        ).unsqueeze(0),
        'support_masks': torch.stack(
            [_mask_to_tensor(mask) for _, mask in supports]
        ).unsqueeze(0),
        'org_query_imsize': (torch.tensor([width]), torch.tensor([height])),
    }

    model = _load_model()
    with torch.no_grad():
        # The model is not wrapped in DataParallel, so call it directly
        # (the original `model.module` does not exist).
        pred_mask, _simi, _simi_map = model.predict_mask_nshot(
            batch, nshot=len(supports)
        )
    # Assumes pred_mask is a label map with a leading batch dim of 1 and
    # foreground > 0 -- TODO confirm.
    mask = pred_mask.squeeze().cpu().numpy()
    return (mask > 0).astype(np.uint8) * 255


inputs = [gr.Image(label='query')] + [
    gr.ImageMask(label='prompt{}'.format(i)) for i in range(NUM_PROMPTS)
]
demo_mask = gr.Interface(
    fn=inference_mask1,
    inputs=inputs,
    outputs=[gr.Image(label="output")],
)
demo = gr.TabbedInterface([demo_mask], ['demo'])

if __name__ == '__main__':
    demo.launch()