# Multiple-shots / app.py
# Last commit 922346a ("Update app.py") by CUHKWilliam — 1.83 kB.
import functools
import os
import time

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn

from common import utils
from common.config import parse_opts
from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from common.vis import Visualizer
from data.dataset import FSSDataset
from model.DCAMA import DCAMA
# NOTE(review): these preprocessing constants are assumptions, not read from the
# checkpoint. DCAMA is commonly trained at 384x384 with ImageNet mean/std —
# confirm against the FSSDataset transform used to train model_45.pt.
_IMG_SIZE = 384
_IMG_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMG_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the DCAMA model on CPU and load ``model_45.pt`` into it.

    Cached so the backbone and checkpoint are read once per process instead of
    on every Gradio request.

    Returns:
        The loaded model in eval mode on CPU.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint still does not
            match the model (strict loading).
    """
    model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
    model.eval()
    model.cpu()

    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host.
    checkpoint = torch.load('model_45.pt', map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    # 'scorer' weights are excluded from the checkpoint before loading.
    ckpt_items = [(k, v) for k, v in checkpoint.items() if 'scorer' not in k]

    # The model's own state dict supplies the target key names and the
    # fallback values for keys the checkpoint does not cover.
    params = model.state_dict()

    # Rename checkpoint entries positionally onto the model's keys (presumably
    # the checkpoint was saved under different names, e.g. a wrapper prefix —
    # confirm both orders match). Building a fresh dict prevents a renamed key
    # from overwriting an entry that has not been renamed yet.
    remapped = {new_key: value for new_key, (_, value) in zip(params, ckpt_items)}
    # Surplus checkpoint entries keep their names so strict loading reports them.
    for key, value in ckpt_items[len(params):]:
        remapped.setdefault(key, value)
    # Model keys missing from the checkpoint fall back to the model's values.
    for key, value in params.items():
        remapped.setdefault(key, value)

    model.load_state_dict(remapped, strict=True)
    return model


def _collect_prompts(prompt):
    """Return ``(support_imgs, support_masks)`` from the filled-in ImageMask slots.

    Each slot is expected to be a Gradio ImageMask value: a dict with a
    'background' image and a list of drawn 'layers'. Slots that are None, have
    no background, or have no layers are skipped, so not all slots must be used.
    """
    support_imgs, support_masks = [], []
    for editor_value in prompt:
        if not editor_value or editor_value.get('background') is None:
            continue
        layers = editor_value.get('layers') or []
        if not layers:
            continue
        # Union of all drawn layers; any non-zero channel (layers are presumably
        # RGBA) counts as painted foreground.
        mask = np.stack(layers, axis=0).any(0).any(-1)
        support_imgs.append(editor_value['background'])
        support_masks.append(mask)
    return support_imgs, support_masks


def _image_to_tensor(img):
    """Return ``img`` resized to _IMG_SIZE and normalized as a (3, S, S) float tensor.

    Assumes 0-255 pixel values; grayscale is replicated to three channels and
    an alpha channel, if any, is dropped.
    """
    arr = np.asarray(img)
    if arr.ndim == 2:
        arr = np.repeat(arr[..., None], 3, axis=-1)
    # ImageMask backgrounds may come back as RGBA — keep only RGB.
    arr = np.ascontiguousarray(arr[..., :3])
    arr = cv2.resize(arr, (_IMG_SIZE, _IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    arr = (arr.astype(np.float32) / 255.0 - _IMG_MEAN) / _IMG_STD
    return torch.from_numpy(arr).permute(2, 0, 1).contiguous()


def _mask_to_tensor(mask):
    """Return a boolean (H, W) mask resized to _IMG_SIZE as a {0, 1} float tensor."""
    # Nearest-neighbour keeps the mask binary after resizing.
    resized = cv2.resize(
        mask.astype(np.uint8), (_IMG_SIZE, _IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )
    return torch.from_numpy(resized.astype(np.float32))


def inference_mask1(
    query_img,
    *prompt,
):
    """Segment ``query_img`` using the masked support prompts (Gradio callback).

    Args:
        query_img: Image from the ``gr.Image`` input — presumably an (H, W, 3)
            uint8 array.
        *prompt: One value per ``gr.ImageMask`` input; empty slots are ignored.

    Returns:
        An (H, W) uint8 mask at the query's size, 255 where the model predicts
        label 1 (assumed to be foreground — TODO confirm).

    Raises:
        gr.Error: if no query image or no usable prompt was supplied.
    """
    if query_img is None:
        raise gr.Error('Please provide a query image.')
    support_imgs, support_masks = _collect_prompts(prompt)
    if not support_imgs:
        raise gr.Error('Please draw a mask on at least one prompt image.')
    nshot = len(support_imgs)

    query_h, query_w = np.asarray(query_img).shape[:2]
    # NOTE(review): batch layout follows the upstream DCAMA/FSS convention
    # (query (1,3,S,S), supports (1,nshot,3,S,S), masks (1,nshot,S,S)) —
    # confirm against model.DCAMA.predict_mask_nshot. 'org_query_imsize' is
    # (width, height) and is presumably only read because the model is built
    # with the third DCAMA constructor argument set to True.
    batch = {
        'query_img': _image_to_tensor(query_img).unsqueeze(0),
        'support_imgs': torch.stack(
            [_image_to_tensor(im) for im in support_imgs]
        ).unsqueeze(0),
        'support_masks': torch.stack(
            [_mask_to_tensor(m) for m in support_masks]
        ).unsqueeze(0),
        'org_query_imsize': (torch.tensor([query_w]), torch.tensor([query_h])),
    }

    model = _load_model()
    # The model is used unwrapped (no DataParallel), so there is no ``.module``.
    with torch.no_grad():
        pred_mask, _simi, _simi_map = model.predict_mask_nshot(batch, nshot=nshot)

    # pred_mask is presumably a (1, H', W') label map — confirm.
    mask = pred_mask[0].detach().cpu().numpy().astype(np.uint8)
    if mask.shape != (query_h, query_w):
        mask = cv2.resize(mask, (query_w, query_h), interpolation=cv2.INTER_NEAREST)
    return mask * 255
# Gradio UI: the query image comes first, followed by five masked prompt
# editors; all are passed positionally to inference_mask1.
query_input = gr.Image(label='query')
inputs = [query_input] + [gr.ImageMask(label=f'prompt{idx}') for idx in range(5)]

output_image = gr.Image(label="output")
demo_mask = gr.Interface(
    fn=inference_mask1,
    inputs=inputs,
    outputs=[output_image],
)

# A single-tab wrapper around the mask demo, then start the server.
demo = gr.TabbedInterface([demo_mask], ['demo'])
demo.launch()