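"""Gradio demo for DCAMA few-shot segmentation.

The user uploads a query image and draws masks over up to ten support
images; the model then segments the corresponding object in the query
and returns the query image with the prediction painted in red.
"""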
import numpy as np
import torch
import torch.nn.functional as F

from torchvision import transforms
from PIL import Image
import gradio as gr

from model.DCAMA import DCAMA

# Standard ImageNet normalization statistics, matching the pretrained backbone.
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]

# Every query and support image is resized to a fixed 384x384 before inference.
img_size = 384
transformation = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(img_mean, img_std),
])
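# For reference: the transform maps an RGB PIL image of any size to a
# normalized float tensor of shape [3, 384, 384]. A quick sanity check
# (hypothetical file name):
#   t = transformation(Image.open('example.jpg').convert('RGB'))
#   assert t.shape == (3, img_size, img_size)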

def inference_mask1(query_img, *prompt):
    """Segment the query image using the user-drawn support image/mask pairs."""
    query_img = Image.fromarray(query_img)
    org_qry_imsize = query_img.size  # (width, height) of the original query
    query_img_np = np.asarray(query_img)

    query_img = transformation(query_img)
    support_masks = []
    support_imgs = []
    for i in range(len(prompt)):
        # Untouched support slots may arrive as None or with no drawn layers.
        if prompt[i] is None or not prompt[i]['layers']:
            break
        # Collapse all drawn editor layers into one binary foreground mask.
        mask = torch.from_numpy(np.stack(prompt[i]['layers'], axis=0).any(0).any(-1))
        # Stop at the first support slot without any annotation.
        if mask.sum() == 0:
            break
        # Resize the mask to the network input size; nearest-neighbour keeps it binary.
        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
                             query_img.size()[-2:], mode='nearest').squeeze(0).squeeze(0)
        support_masks.append(mask)
        # The support image is the editor background; drop its alpha channel.
        support_img = Image.fromarray(prompt[i]['background'][..., :3])
        support_imgs.append(transformation(support_img))
    # NOTE: the backbone and checkpoint are rebuilt on every request, which is
    # simple but slow; loading the model once at module level would be faster.
    model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
    model.eval()
    model.cpu()
    params = model.state_dict()
    state_dict = torch.load('model_45.pt', map_location=torch.device('cpu'))
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    # Drop auxiliary 'scorer' weights that have no counterpart in this model.
    state_dict = {k: v for k, v in state_dict.items() if 'scorer' not in k}

    # Rename checkpoint keys positionally onto the model's parameter names;
    # this assumes both dicts list the parameters in the same order.
    for k1, k2 in zip(list(state_dict.keys()), list(params.keys())):
        state_dict[k2] = state_dict.pop(k1)

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Fall back: fill any missing keys with freshly initialized parameters.
        for k in params.keys():
            if k not in state_dict:
                state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)
    
    # Guard against the case where no support mask was drawn at all.
    if not support_imgs:
        raise gr.Error('Draw a mask on at least one support image.')

    # Batch of size 1: images are [1, nshot, 3, H, W], masks are [1, nshot, H, W].
    query_img = query_img.unsqueeze(0)
    support_imgs = torch.stack(support_imgs, dim=0).unsqueeze(0)
    support_masks = torch.stack(support_masks, dim=0).unsqueeze(0)
    print("query_img:", query_img.size())
    print("support_imgs:", support_imgs.size())
    print("support_masks:", support_masks.size())
    batch = {
        "support_masks": support_masks,
        "support_imgs": support_imgs,
        "query_img": query_img,
        "org_query_imsize": [torch.tensor([org_qry_imsize[0]]), torch.tensor([org_qry_imsize[1]])],
    }
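    # DCAMA ("Dense Cross-Query-and-Support Attention Weighted Mask Aggregation",
    # ECCV 2022) aggregates support-mask values via dense query-support attention.
    # predict_mask_nshot also returns similarity diagnostics (simi, simi_map),
    # which this demo does not display.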
    nshot = support_masks.size(1)
    with torch.no_grad():
        pred_mask, simi, simi_map = model.predict_mask_nshot(batch, nshot=nshot)
    pred_mask = pred_mask.detach().cpu().numpy()[0]

    # Paint the predicted foreground red on a copy of the original query image.
    output_img = query_img_np.copy()
    output_img[pred_mask > 0] = np.array([255, 0, 0])
    return output_img.astype(np.uint8)
    

# One query image plus up to ten support slots, each with a drawable mask.
inputs = [gr.Image(label='query')]
for i in range(10):
    inputs.append(gr.ImageMask(label='support {}'.format(i)))

demo_mask = gr.Interface(
    fn=inference_mask1,
    inputs=inputs,
    outputs=[gr.Image(label='output')],
)

demo = gr.TabbedInterface([demo_mask], ['demo'])

if __name__ == '__main__':
    demo.launch()
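# A minimal programmatic call that bypasses the UI (file names are
# hypothetical; gr.ImageMask delivers the same dict structure):
#   query = np.asarray(Image.open('query.jpg').convert('RGB'))
#   prompt = {'background': np.asarray(Image.open('support.png').convert('RGBA')),
#             'layers': [np.asarray(Image.open('strokes.png').convert('RGBA'))]}
#   overlay = inference_mask1(query, prompt)  # query image with prediction in red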