CUHKWilliam commited on
Commit
1301060
·
1 Parent(s): c70812a
Files changed (1) hide show
  1. app.py +59 -37
app.py CHANGED
@@ -1,44 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import time
3
 
4
 
5
- def sleep(im):
6
- time.sleep(5)
7
- return [im["background"], im["layers"][0], im["layers"][1], im["composite"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
9
 
10
- support_im_masks = [None for _ in range(1000)]
11
- def predict(im):
12
- for i in range(len(support_im_masks)):
13
- if support_im_masks[i] is None:
14
- break
15
- support_im_mask = support_im_masks[i]
16
  import ipdb;ipdb.set_trace()
17
- pass
18
-
19
- with gr.Blocks() as demo:
20
- b = gr.Button("Add Textbox")
21
- b2 = gr.Button("Generate Masks")
22
- b2.click(predict)
23
-
24
- num = gr.State(0)
25
-
26
- b.click(lambda x:x+1, num, num)
27
-
28
- with gr.Row():
29
- query_im = gr.Image(label='query image')
30
-
31
- @gr.render(inputs=num)
32
- def show_support_imgs(n):
33
- with gr.Column():
34
- for i in range(n):
35
- support_im = gr.ImageEditor(
36
- label="support image {}".format(i),
37
- type="numpy",
38
- crop_size="1:1",
39
- )
40
- support_im_masks[i] = support_im
41
-
42
-
43
- if __name__ == "__main__":
44
- demo.launch(debug=True)
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ from model.DCAMA import DCAMA
5
+ from common.logger import Logger, AverageMeter
6
+ from common.vis import Visualizer
7
+ from common.evaluation import Evaluator
8
+ from common.config import parse_opts
9
+ from common import utils
10
+ from data.dataset import FSSDataset
11
+ import cv2
12
+ import numpy as np
13
+ import os
14
+
15
  import gradio as gr
16
  import time
17
 
18
 
19
+ def inference_mask1(
20
+ query_img,
21
+ *prompt,
22
+ ):
23
+ support_masks = []
24
+ support_imgs = []
25
+ for i in range(len(prompt)):
26
+ mask = np.stack(prompt[i]['layers'], axis=0).any(0).any(-1)
27
+ support_masks.append(mask)
28
+ support_imgs.append(prompt[i]['background'])
29
+ model = DCAMA('resnet50', 'resnet50_a1h-35c100f8.pth', True)
30
+ model.eval()
31
+ model.cuda()
32
+ state_dict = torch.load('model_45.pt')
33
+ if 'state_dict' in state_dict.keys():
34
+ state_dict = state_dict['state_dict']
35
+ state_dict2 = {}
36
+ for k, v in state_dict.items():
37
+ if 'scorer' not in k:
38
+ state_dict2[k] = v
39
+ state_dict = state_dict2
40
+
41
+ for k1, k2 in zip(list(state_dict.keys()), params.keys()):
42
+ state_dict[k2] = state_dict.pop(k1)
43
+
44
+ try:
45
+ model.load_state_dict(state_dict, strict=True)
46
+ except:
47
+ for k in params.keys():
48
+ if k not in state_dict.keys():
49
+ state_dict[k] = params[k]
50
+ model.load_state_dict(state_dict, strict=True)
51
 
52
+ pred_mask, simi, simi_map = model.module.predict_mask_nshot(batch, nshot=nshot)
53
 
 
 
 
 
 
 
54
  import ipdb;ipdb.set_trace()
55
+
56
+
57
+ inputs = [gr.Image(label='query')]
58
+ for i in range(5):
59
+ inputs.append(gr.ImageMask(label='prompt{}'.format(i)))
60
+ demo_mask = gr.Interface(fn=inference_mask1,
61
+ inputs=inputs,
62
+ outputs=[gr.Image(label="output")],
63
+ )
64
+
65
+ demo = gr.TabbedInterface([demo_mask], ['demo'])
66
+ demo.launch()