File size: 1,827 Bytes
1301060
 
 
 
 
 
 
 
 
 
 
 
 
 
b065c48
e044413
a26cb5e
 
1301060
 
 
 
 
 
 
 
 
 
 
 
922346a
1301060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd7bf2f
922346a
1301060
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch.nn as nn
import torch

from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.vis import Visualizer
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset
import cv2
import numpy as np
import os

import gradio as gr
import time


# Model files used by the demo.
_BACKBONE = 'resnet50'
_FEATURE_EXTRACTOR_PATH = 'resnet50_a1h-35c100f8.pth'
_CHECKPOINT_PATH = 'model_45.pt'
# NOTE(review): assumes the checkpoint was trained on square 384x384 inputs
# normalised with ImageNet mean/std (DCAMA's dataset defaults) — confirm
# against the training config of model_45.pt.
_IMG_SIZE = 384
_IMG_MEAN = (0.485, 0.456, 0.406)
_IMG_STD = (0.229, 0.224, 0.225)

# Loaded lazily by _load_model() so the checkpoint is read once, not per request.
_MODEL = None


def _load_model():
    """Build DCAMA on CPU, load the demo checkpoint, and cache the result."""
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    model = DCAMA(_BACKBONE, _FEATURE_EXTRACTOR_PATH, True)
    model.eval()
    model.cpu()
    # The model's own state dict supplies the target key names (and their
    # order) and the fallback values for keys the checkpoint lacks.
    params = model.state_dict()

    # The model lives on CPU, so remap tensors that may have been saved from
    # a GPU run; without this torch.load fails on a CUDA-less machine.
    state_dict = torch.load(_CHECKPOINT_PATH, map_location='cpu')
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # Drop every entry whose key mentions 'scorer' (presumably weights of a
    # head this model does not define — confirm).
    ckpt_items = [(k, v) for k, v in state_dict.items() if 'scorer' not in k]

    # Rename checkpoint entries positionally onto the model's key names
    # (presumably to absorb naming differences such as a 'module.' prefix).
    # Building a new dict avoids the in-place pop/reassign clobbering an
    # existing entry when names overlap at different positions. Entries past
    # the end of the model's keys keep their checkpoint names, so a strict
    # load still reports them as unexpected.
    model_keys = list(params.keys())
    renamed = {}
    for idx, (ckpt_key, value) in enumerate(ckpt_items):
        target = model_keys[idx] if idx < len(model_keys) else ckpt_key
        renamed.setdefault(target, value)

    # Keys the checkpoint does not provide keep the model's initial values;
    # filling them up front replaces the old try/bare-except retry.
    missing = [k for k in params if k not in renamed]
    if missing:
        print('Warning: {} parameter(s) missing from {}; using the model\'s '
              'initial values for them.'.format(len(missing), _CHECKPOINT_PATH))
        for k in missing:
            renamed[k] = params[k]
    model.load_state_dict(renamed, strict=True)

    _MODEL = model
    return model


def _resize(tensor, mode):
    """Resize a 4-D (N, C, H, W) tensor to _IMG_SIZE x _IMG_SIZE."""
    if mode == 'nearest':
        return nn.functional.interpolate(tensor, size=(_IMG_SIZE, _IMG_SIZE), mode=mode)
    return nn.functional.interpolate(
        tensor, size=(_IMG_SIZE, _IMG_SIZE), mode=mode, align_corners=False)


def _image_to_tensor(image):
    """Convert an HxW / HxWxC image array into a normalised 3xSxS float tensor."""
    arr = np.asarray(image)
    if arr.ndim == 2:
        arr = arr[..., None]
    if arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    # Keep RGB only; an editor background may carry an alpha channel.
    # Assumes uint8 pixel values in [0, 255].
    arr = arr[..., :3].astype(np.float32) / 255.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    tensor = _resize(tensor, 'bilinear')
    mean = torch.tensor(_IMG_MEAN).view(1, 3, 1, 1)
    std = torch.tensor(_IMG_STD).view(1, 3, 1, 1)
    return ((tensor - mean) / std)[0]


def _mask_to_tensor(mask):
    """Convert a boolean HxW mask into a 0/1 float SxS tensor (nearest resize)."""
    tensor = torch.from_numpy(np.asarray(mask, dtype=np.float32))[None, None]
    return _resize(tensor, 'nearest')[0, 0]


def _prompt_to_support(prompt):
    """Return (image, mask) from one ImageMask value, or None if it is unusable."""
    if not prompt or prompt.get('background') is None or not prompt.get('layers'):
        return None
    # A pixel is foreground if any channel of any drawn layer is non-zero.
    mask = np.stack(prompt['layers'], axis=0).any(0).any(-1)
    return prompt['background'], mask


def inference_mask1(query_img, *prompt):
    """Segment ``query_img`` using the mask-annotated ``prompt`` images as support.

    Gradio handler: the first input is the query image, and every following
    input is an ``ImageMask`` value. Prompts that are empty (``None``, no
    background image, or no drawn layer) are skipped. The remaining ones form
    an n-shot support set.

    Args:
        query_img: query image array (assumed HxW or HxWxC uint8, as produced
            by ``gr.Image``), or None when the user supplied nothing.
        *prompt: ``ImageMask`` values, i.e. dicts with ``'background'`` (the
            support image) and ``'layers'`` (the drawn strokes).

    Returns:
        A uint8 mask image: 255 where the model predicts foreground, 0 elsewhere.

    Raises:
        gr.Error: if there is no query image or no usable prompt.
    """
    if query_img is None:
        raise gr.Error('Please provide a query image.')
    supports = [s for s in map(_prompt_to_support, prompt) if s is not None]
    if not supports:
        raise gr.Error('Please provide at least one prompt image with a drawn mask.')

    support_imgs = [img for img, _ in supports]
    support_masks = [mask for _, mask in supports]
    nshot = len(supports)

    query_h, query_w = np.asarray(query_img).shape[:2]
    # NOTE(review): keys and layout mirror DCAMA's data loader batches
    # (query (1,3,S,S), supports (1,n,3,S,S), masks (1,n,S,S)); confirm
    # against this fork's predict_mask_nshot.
    batch = {
        'query_img': _image_to_tensor(query_img).unsqueeze(0),
        'support_imgs': torch.stack([_image_to_tensor(img) for img in support_imgs]).unsqueeze(0),
        'support_masks': torch.stack([_mask_to_tensor(m) for m in support_masks]).unsqueeze(0),
        # Ordered (width, height) like PIL's Image.size, which is presumably
        # what the use_original_imgsize=True path reads back — confirm.
        'org_query_imsize': [torch.tensor([query_w]), torch.tensor([query_h])],
    }

    # The model is not wrapped in nn.DataParallel, so call it directly
    # (there is no `.module` attribute).
    model = _load_model()
    with torch.no_grad():
        pred_mask, _, _ = model.predict_mask_nshot(batch, nshot=nshot)

    # Assumes pred_mask is (batch, H, W) with non-zero meaning foreground.
    return (pred_mask[0].cpu().numpy() > 0).astype(np.uint8) * 255


# UI wiring: one query image followed by five mask-annotated prompts. The
# component order must match inference_mask1(query_img, *prompt).
inputs = [gr.Image(label='query')] + [
    gr.ImageMask(label='prompt{}'.format(slot)) for slot in range(5)
]

demo_mask = gr.Interface(
    fn=inference_mask1,
    inputs=inputs,
    outputs=[gr.Image(label="output")],
)

# Put the single interface in a tabbed container and start the server.
demo = gr.TabbedInterface([demo_mask], ['demo'])
demo.launch()