CUHKWilliam committed on
Commit c70812a
1 Parent(s): 214c299
Files changed (41)
  1. analyze.py +11 -0
  2. common/__pycache__/config.cpython-38.pyc +0 -0
  3. common/__pycache__/config.cpython-39.pyc +0 -0
  4. common/__pycache__/evaluation.cpython-38.pyc +0 -0
  5. common/__pycache__/evaluation.cpython-39.pyc +0 -0
  6. common/__pycache__/logger.cpython-38.pyc +0 -0
  7. common/__pycache__/logger.cpython-39.pyc +0 -0
  8. common/__pycache__/utils.cpython-38.pyc +0 -0
  9. common/__pycache__/utils.cpython-39.pyc +0 -0
  10. common/__pycache__/vis.cpython-38.pyc +0 -0
  11. common/__pycache__/vis.cpython-39.pyc +0 -0
  12. common/config.py +31 -0
  13. common/evaluation.py +39 -0
  14. common/logger.py +117 -0
  15. common/utils.py +32 -0
  16. common/vis.py +106 -0
  17. gpu_mem_track.py +113 -0
  18. importance_analysis.py +130 -0
  19. model/DCAMA.py +625 -0
  20. model/__pycache__/DCAMA.cpython-38.pyc +0 -0
  21. model/__pycache__/DCAMA.cpython-39.pyc +0 -0
  22. model/base/__pycache__/swin_transformer.cpython-38.pyc +0 -0
  23. model/base/__pycache__/swin_transformer.cpython-39.pyc +0 -0
  24. model/base/__pycache__/transformer.cpython-38.pyc +0 -0
  25. model/base/__pycache__/transformer.cpython-39.pyc +0 -0
  26. model/base/swin_transformer.py +605 -0
  27. model/base/transformer.py +99 -0
  28. modelsize_estimate.py +38 -0
  29. scripts/importance_analysis.sh +16 -0
  30. scripts/test.sh +15 -0
  31. scripts/train.sh +11 -0
  32. scripts/train_1gpu.sh +12 -0
  33. scripts/train_1gpu_retriver.sh +12 -0
  34. scripts/train_2gpu.sh +14 -0
  35. scripts/train_2gpu_retriever.sh +14 -0
  36. scripts/train_4gpu.sh +14 -0
  37. test.py +132 -0
  38. train.py +149 -0
  39. train_1gpu.py +170 -0
  40. train_1gpu_retriever.py +172 -0
  41. train_retriever.py +164 -0
analyze.py ADDED
@@ -0,0 +1,11 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import cv2
+ import os
+
+ with open('debug/stats.txt', 'r') as f:
+     stats = f.readlines()
+ for stat in stats:
+     plt.scatter(float(stat.split(" ")[0]), float(stat.split(' ')[1]), alpha=0.1, s=10, c='red')
+     print(stat)
+ plt.savefig('stats.png')
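Note: this script assumes each line of debug/stats.txt holds a space-separated similarity score and IoU value, as written by the (currently commented-out) logging block in importance_analysis.py below. A minimal sketch of producing a compatible file, with hypothetical values:

# Sketch (assumption): each line of debug/stats.txt is "<similarity> <iou>".
import os

os.makedirs('debug', exist_ok=True)
with open('debug/stats.txt', 'w') as f:
    for simi, iou in [(0.81, 0.67), (0.42, 0.23)]:  # hypothetical values
        f.write('{} {}\n'.format(simi, iou))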
common/__pycache__/config.cpython-38.pyc ADDED
Binary file (1.28 kB).
 
common/__pycache__/config.cpython-39.pyc ADDED
Binary file (1.28 kB).
 
common/__pycache__/evaluation.cpython-38.pyc ADDED
Binary file (1.42 kB).
 
common/__pycache__/evaluation.cpython-39.pyc ADDED
Binary file (1.39 kB).
 
common/__pycache__/logger.cpython-38.pyc ADDED
Binary file (4.34 kB).
 
common/__pycache__/logger.cpython-39.pyc ADDED
Binary file (4.32 kB).
 
common/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.12 kB).
 
common/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.1 kB).
 
common/__pycache__/vis.cpython-38.pyc ADDED
Binary file (4.72 kB).
 
common/__pycache__/vis.cpython-39.pyc ADDED
Binary file (4.68 kB).
 
common/config.py ADDED
@@ -0,0 +1,31 @@
+ r"""config"""
+ import argparse
+
+ def parse_opts():
+     r"""arguments"""
+     parser = argparse.ArgumentParser(description='Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation')
+
+     # common
+     parser.add_argument('--datapath', type=str, default='./datasets')
+     parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss'])
+     parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
+     parser.add_argument('--bsz', type=int, default=20)
+     parser.add_argument('--nworker', type=int, default=8)
+     parser.add_argument('--backbone', type=str, default='swin', choices=['resnet50', 'resnet101', 'swin'])
+     parser.add_argument('--feature_extractor_path', type=str, default='')
+     parser.add_argument('--logpath', type=str, default='./logs')
+
+     # for train
+     parser.add_argument('--lr', type=float, default=1e-3)
+     parser.add_argument('--nepoch', type=int, default=1000)
+     parser.add_argument('--local-rank', default=0, type=int, help='node rank for distributed training')
+
+     # for test
+     parser.add_argument('--load', type=str, default='')
+     parser.add_argument('--nshot', type=int, default=1)
+     parser.add_argument('--visualize', action='store_true')
+     parser.add_argument('--vispath', type=str, default='./vis')
+     parser.add_argument('--use_original_imgsize', action='store_true')
+
+     args = parser.parse_args()
+     return args
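For reference, a hypothetical invocation of these options (paths are placeholders; the scripts under scripts/ presumably pass the same flags on the command line):

# Sketch only: simulate a command line and parse it with parse_opts().
import sys

sys.argv = ['test.py', '--benchmark', 'coco', '--fold', '0', '--nshot', '1',
            '--backbone', 'swin', '--feature_extractor_path', '<swin_checkpoint>.pth',
            '--load', '<trained_model>.pt']
args = parse_opts()
print(args.benchmark, args.fold, args.nshot)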
common/evaluation.py ADDED
@@ -0,0 +1,39 @@
+ r""" Evaluate mask prediction """
+ import torch
+
+
+ class Evaluator:
+     r""" Computes intersection and union between prediction and ground-truth """
+     @classmethod
+     def initialize(cls):
+         cls.ignore_index = 255
+
+     @classmethod
+     def classify_prediction(cls, pred_mask, batch):
+         gt_mask = batch.get('query_mask')
+
+         # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
+         query_ignore_idx = batch.get('query_ignore_idx')
+         if query_ignore_idx is not None:
+             assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
+             query_ignore_idx *= cls.ignore_index
+             gt_mask = gt_mask + query_ignore_idx
+             pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index
+
+         # compute intersection and union of each episode in a batch
+         area_inter, area_pred, area_gt = [], [], []
+         for _pred_mask, _gt_mask in zip(pred_mask, gt_mask):
+             _inter = _pred_mask[_pred_mask == _gt_mask]
+             if _inter.size(0) == 0:  # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1)
+                 _area_inter = torch.tensor([0, 0], device=_pred_mask.device)
+             else:
+                 _area_inter = torch.histc(_inter, bins=2, min=0, max=1)
+             area_inter.append(_area_inter)
+             area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
+             area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
+         area_inter = torch.stack(area_inter).t()
+         area_pred = torch.stack(area_pred).t()
+         area_gt = torch.stack(area_gt).t()
+         area_union = area_pred + area_gt - area_inter
+
+         return area_inter, area_union
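A minimal sketch of how classify_prediction is expected to be called, with dummy float-valued binary masks (shapes and dictionary keys assumed from the code above; no ignore region supplied):

# Usage sketch with dummy data: one episode of a 2x2 mask.
import torch

Evaluator.initialize()
pred_mask = torch.tensor([[[0., 1.], [1., 1.]]])               # (bsz=1, H=2, W=2) predicted labels
batch = {'query_mask': torch.tensor([[[0., 1.], [0., 1.]]])}   # ground truth
area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
iou = area_inter[1].float() / area_union[1].float()            # per-episode foreground IoU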
common/logger.py ADDED
@@ -0,0 +1,117 @@
+ r""" Logging during training/testing """
+ import datetime
+ import logging
+ import os
+
+ from tensorboardX import SummaryWriter
+ import torch
+
+
+ class AverageMeter:
+     r""" Stores loss, evaluation results """
+     def __init__(self, dataset):
+         self.benchmark = dataset.benchmark
+         self.class_ids_interest = dataset.class_ids
+         self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()
+
+         if self.benchmark == 'pascal':
+             self.nclass = 20
+         elif self.benchmark == 'coco':
+             self.nclass = 80
+         elif self.benchmark == 'fss':
+             self.nclass = 1000
+
+         self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
+         self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
+         self.ones = torch.ones_like(self.union_buf)
+         self.loss_buf = []
+
+     def update(self, inter_b, union_b, class_id, loss):
+         self.intersection_buf.index_add_(1, class_id, inter_b.float())
+         self.union_buf.index_add_(1, class_id, union_b.float())
+         if loss is None:
+             loss = torch.tensor(0.0)
+         self.loss_buf.append(loss)
+
+     def compute_iou(self):
+         iou = self.intersection_buf.float() / \
+               torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
+         iou = iou.index_select(1, self.class_ids_interest)
+         miou = iou[1].mean() * 100
+
+         fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
+                   self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100
+
+         return miou, fb_iou
+
+     def write_result(self, split, epoch):
+         iou, fb_iou = self.compute_iou()
+
+         loss_buf = torch.stack(self.loss_buf)
+         msg = '\n*** %s ' % split
+         msg += '[@Epoch %02d] ' % epoch
+         msg += 'Avg L: %6.5f ' % loss_buf.mean()
+         msg += 'mIoU: %5.2f ' % iou
+         msg += 'FB-IoU: %5.2f ' % fb_iou
+
+         msg += '***\n'
+         Logger.info(msg)
+
+     def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
+         if batch_idx % write_batch_idx == 0:
+             msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
+             msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
+             iou, fb_iou = self.compute_iou()
+             if epoch != -1:
+                 loss_buf = torch.stack(self.loss_buf)
+                 msg += 'L: %6.5f ' % loss_buf[-1]
+                 msg += 'Avg L: %6.5f ' % loss_buf.mean()
+             msg += 'mIoU: %5.2f | ' % iou
+             msg += 'FB-IoU: %5.2f' % fb_iou
+             Logger.info(msg)
+
+
+ class Logger:
+     r""" Writes evaluation results of training/testing """
+     @classmethod
+     def initialize(cls, args, training):
+         logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
+         logpath = os.path.join(args.logpath, 'train/fold_' + str(args.fold) + logtime) if training \
+             else os.path.join(args.logpath, 'test/fold_' + args.load.split('/')[-2].split('.')[0] + logtime)
+         if logpath == '': logpath = logtime
+
+         cls.logpath = logpath
+         cls.benchmark = args.benchmark
+         if not os.path.exists(cls.logpath): os.makedirs(cls.logpath)
+
+         logging.basicConfig(filemode='w',
+                             filename=os.path.join(cls.logpath, 'log.txt'),
+                             level=logging.INFO,
+                             format='%(message)s',
+                             datefmt='%m-%d %H:%M:%S')
+
+         # Console log config
+         console = logging.StreamHandler()
+         console.setLevel(logging.INFO)
+         formatter = logging.Formatter('%(message)s')
+         console.setFormatter(formatter)
+         logging.getLogger('').addHandler(console)
+
+         # Tensorboard writer
+         cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
+
+         # Log arguments
+         logging.info('\n:==================== Start =====================')
+         for arg_key in args.__dict__:
+             logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
+         logging.info(':================================================\n')
+
+     @classmethod
+     def info(cls, msg):
+         r""" Writes log message to log.txt """
+         logging.info(msg)
+
+     @classmethod
+     def save_model_miou(cls, model, epoch, val_miou):
+         torch.save(model.state_dict(), os.path.join(cls.logpath, "model_{}.pt".format(epoch)))
+         cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))
common/utils.py ADDED
@@ -0,0 +1,32 @@
+ r""" Helper functions """
+ import random
+
+ import torch
+ import numpy as np
+
+
+ def fix_randseed(seed):
+     r""" Set random seeds for reproducibility """
+     if seed is None:
+         seed = int(random.random() * 1e5)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
+
+
+ def mean(x):
+     return sum(x) / len(x) if len(x) > 0 else 0.0
+
+
+ def to_cuda(batch):
+     for key, value in batch.items():
+         if isinstance(value, torch.Tensor):
+             batch[key] = value.cuda()
+     return batch
+
+
+ def to_cpu(tensor):
+     return tensor.detach().clone().cpu()
common/vis.py ADDED
@@ -0,0 +1,106 @@
+ r""" Visualize model predictions """
+ import os
+
+ from PIL import Image
+ import numpy as np
+ import torchvision.transforms as transforms
+
+ from . import utils
+
+
+ class Visualizer:
+
+     @classmethod
+     def initialize(cls, visualize, vispath='./vis/'):
+         cls.visualize = visualize
+         if not visualize:
+             return
+
+         cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)}
+         for key, value in cls.colors.items():
+             cls.colors[key] = tuple([c / 255 for c in cls.colors[key]])
+
+         cls.mean_img = [0.485, 0.456, 0.406]
+         cls.std_img = [0.229, 0.224, 0.225]
+         cls.to_pil = transforms.ToPILImage()
+         cls.vis_path = vispath
+         if not os.path.exists(cls.vis_path): os.makedirs(cls.vis_path)
+
+     @classmethod
+     def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None):
+         spt_img_b = utils.to_cpu(spt_img_b)
+         spt_mask_b = utils.to_cpu(spt_mask_b)
+         qry_img_b = utils.to_cpu(qry_img_b)
+         qry_mask_b = utils.to_cpu(qry_mask_b)
+         pred_mask_b = utils.to_cpu(pred_mask_b)
+         cls_id_b = utils.to_cpu(cls_id_b)
+
+         for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \
+                 enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)):
+             iou = iou_b[sample_idx] if iou_b is not None else None
+             cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou)
+
+     @classmethod
+     def to_numpy(cls, tensor, type):
+         if type == 'img':
+             return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8)
+         elif type == 'mask':
+             return np.array(tensor).astype(np.uint8)
+         else:
+             raise Exception('Undefined tensor type: %s' % type)
+
+     @classmethod
+     def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None):
+
+         spt_color = cls.colors['blue']
+         qry_color = cls.colors['red']
+         pred_color = cls.colors['red']
+
+         spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs]
+         spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs]
+         spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks]
+         spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)]
+
+         qry_img = cls.to_numpy(qry_img, 'img')
+         qry_pil = cls.to_pil(qry_img)
+         qry_mask = cls.to_numpy(qry_mask, 'mask')
+         pred_mask = cls.to_numpy(pred_mask, 'mask')
+         pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color))
+         qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color))
+
+         merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil])
+
+         iou = iou.item() if iou else 0.0
+         merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg')
+
+     @classmethod
+     def merge_image_pair(cls, pil_imgs):
+         r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """
+
+         canvas_width = sum([pil.size[0] for pil in pil_imgs])
+         canvas_height = max([pil.size[1] for pil in pil_imgs])
+         canvas = Image.new('RGB', (canvas_width, canvas_height))
+
+         xpos = 0
+         for pil in pil_imgs:
+             canvas.paste(pil, (xpos, 0))
+             xpos += pil.size[0]
+
+         return canvas
+
+     @classmethod
+     def apply_mask(cls, image, mask, color, alpha=0.5):
+         r""" Apply mask to the given image. """
+         for c in range(3):
+             image[:, :, c] = np.where(mask == 1,
+                                       image[:, :, c] *
+                                       (1 - alpha) + alpha * color[c] * 255,
+                                       image[:, :, c])
+         return image
+
+     @classmethod
+     def unnormalize(cls, img):
+         img = img.clone()
+         for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img):
+             im_channel.mul_(std).add_(mean)
+         return img
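A minimal sketch of apply_mask on a dummy image, assuming a uint8 HxWx3 array and a binary mask as used above (the dummy shapes and values are illustrative only):

# Sketch: blend the 'red' overlay into the masked pixels of a dummy 4x4 RGB image.
import numpy as np

Visualizer.initialize(visualize=True, vispath='./vis/')
img = np.full((4, 4, 3), 128, dtype=np.uint8)   # dummy grey image
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 1                              # dummy foreground region
blended = Visualizer.apply_mask(img.copy(), mask, Visualizer.colors['red'], alpha=0.5)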
gpu_mem_track.py ADDED
@@ -0,0 +1,113 @@
+ import gc
+ import datetime
+ import inspect
+
+ import torch
+ import numpy as np
+
+ dtype_memory_size_dict = {
+     torch.float64: 64/8,
+     torch.double: 64/8,
+     torch.float32: 32/8,
+     torch.float: 32/8,
+     torch.float16: 16/8,
+     torch.half: 16/8,
+     torch.int64: 64/8,
+     torch.long: 64/8,
+     torch.int32: 32/8,
+     torch.int: 32/8,
+     torch.int16: 16/8,
+     torch.short: 16/8,  # torch.short is 16-bit, i.e. 2 bytes
+     torch.uint8: 8/8,
+     torch.int8: 8/8,
+ }
+ # compatibility with torch 1.0
+ if getattr(torch, "bfloat16", None) is not None:
+     dtype_memory_size_dict[torch.bfloat16] = 16/8
+ if getattr(torch, "bool", None) is not None:
+     dtype_memory_size_dict[torch.bool] = 8/8  # pytorch uses 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571
+
+ def get_mem_space(x):
+     try:
+         ret = dtype_memory_size_dict[x]
+     except KeyError:
+         print(f"dtype {x} is not supported!")
+         ret = 0  # fall back to 0 bytes for unknown dtypes instead of raising NameError
+     return ret
+
+ class MemTracker(object):
+     """
+     Class used to track pytorch memory usage
+     Arguments:
+         detail(bool, default True): whether the function shows the detail gpu memory usage
+         path(str): where to save log file
+         verbose(bool, default False): whether show the trivial exception
+         device(int): GPU number, default is 0
+     """
+     def __init__(self, detail=True, path='', verbose=False, device=0):
+         self.print_detail = detail
+         self.last_tensor_sizes = set()
+         self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
+         self.verbose = verbose
+         self.begin = True
+         self.device = device
+
+     def get_tensors(self):
+         for obj in gc.get_objects():
+             try:
+                 if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
+                     tensor = obj
+                 else:
+                     continue
+                 if tensor.is_cuda:
+                     yield tensor
+             except Exception as e:
+                 if self.verbose:
+                     print('A trivial exception occurred: {}'.format(e))
+
+     def get_tensor_usage(self):
+         sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()]
+         return np.sum(sizes) / 1024**2
+
+     def get_allocate_usage(self):
+         return torch.cuda.memory_allocated() / 1024**2
+
+     def clear_cache(self):
+         gc.collect()
+         torch.cuda.empty_cache()
+
+     def print_all_gpu_tensor(self, file=None):
+         for x in self.get_tensors():
+             print(x.size(), x.dtype, np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2, file=file)
+
+     def track(self):
+         """
+         Track the GPU memory usage
+         """
+         frameinfo = inspect.stack()[1]
+         where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function
+
+         with open(self.gpu_profile_fn, 'a+') as f:
+
+             if self.begin:
+                 f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
+                         f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
+                         f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
+                 self.begin = False
+
+             if self.print_detail is True:
+                 ts_list = [(tensor.size(), tensor.dtype) for tensor in self.get_tensors()]
+                 new_tensor_sizes = {(type(x),
+                                      tuple(x.size()),
+                                      ts_list.count((x.size(), x.dtype)),
+                                      np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2,
+                                      x.dtype) for x in self.get_tensors()}
+                 for t, s, n, m, data_type in new_tensor_sizes - self.last_tensor_sizes:
+                     f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n')
+                 for t, s, n, m, data_type in self.last_tensor_sizes - new_tensor_sizes:
+                     f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n')
+
+                 self.last_tensor_sizes = new_tensor_sizes
+
+             f.write(f"\nAt {where_str:<50}"
+                     f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
+                     f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
importance_analysis.py ADDED
@@ -0,0 +1,130 @@
+ r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
+ import torch.nn as nn
+ import torch
+
+ from model.DCAMA import DCAMA
+ from common.logger import Logger, AverageMeter
+ from common.vis import Visualizer
+ from common.evaluation import Evaluator
+ from common.config import parse_opts
+ from common import utils
+ from data.dataset import FSSDataset
+ import cv2
+ import numpy as np
+ import os
+
+ def test(model, dataloader, nshot):
+     r""" Test """
+
+     # Freeze randomness during testing for reproducibility
+     utils.fix_randseed(0)
+     average_meter = AverageMeter(dataloader.dataset)
+
+     for idx, batch in enumerate(dataloader):
+
+         # 1. forward pass
+         nshot = batch['support_imgs'].size(1)
+
+         batch['support_imgs'][0][0] = batch['query_img'][0]
+         batch['support_masks'][0][0] = batch['query_mask'][0]
+
+         batch = utils.to_cuda(batch)
+         pred_mask, simi, simi_map = model.module.predict_mask_nshot(batch, nshot=nshot)
+
+         assert pred_mask.size() == batch['query_mask'].size()
+
+         # 2. Evaluate prediction
+         area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
+
+         ## TODO:
+         iou = area_inter[1] / area_union[1]
+
+         '''
+         cv2.imwrite('debug/query.png', cv2.imread("/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}".format(batch['query_name'][0])))
+         cv2.imwrite('debug/query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
+         cv2.imwrite('debug/support_{:.3}.png'.format(iou.item()), cv2.imread('/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}'.format(batch['support_names'][0][0])))
+         cv2.imwrite('debug/support_mask_{:.3}.png'.format(iou.item()), (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
+         simi_map = simi_map - simi_map.min()
+         simi_map = (simi_map / simi_map.max() * 255).detach().cpu().numpy().astype(np.uint8)
+         cv2.imwrite('debug/simi_map_{:.3}.png'.format(iou.item()), simi_map)
+
+         if os.path.exists('debug/stats.txt'):
+             with open('debug/stats.txt', "a") as f:
+                 f.write("{} {}\n".format(simi.item(), iou.item()))
+         else:
+             with open('debug/stats.txt', 'w') as f:
+                 f.write('{} {}\n'.format(simi.item(), iou.item()))
+         '''
+
+         average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
+         average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)
+
+         # Visualize predictions
+         if Visualizer.visualize:
+             Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
+                                                   batch['query_img'], batch['query_mask'],
+                                                   pred_mask, batch['class_id'], idx,
+                                                   iou_b=area_inter[1].float() / area_union[1].float())
+
+     # Write evaluation results
+     average_meter.write_result('Test', 0)
+     miou, fb_iou = average_meter.compute_iou()
+
+     return miou, fb_iou
+
+
+ if __name__ == '__main__':
+
+     # Arguments parsing
+     args = parse_opts()
+
+     Logger.initialize(args, training=False)
+
+     # Model initialization
+     model = DCAMA(args.backbone, args.feature_extractor_path, args.use_original_imgsize)
+     model.eval()
+
+     # Device setup
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     Logger.info('# available GPUs: %d' % torch.cuda.device_count())
+     model = nn.DataParallel(model)
+     model.to(device)
+
+     # Load trained model
+     if args.load == '': raise Exception('Pretrained model not specified.')
+     params = model.state_dict()
+     state_dict = torch.load(args.load)
+
+     if 'state_dict' in state_dict.keys():
+         state_dict = state_dict['state_dict']
+
+     state_dict2 = {}
+     for k, v in state_dict.items():
+         if 'scorer' not in k:
+             state_dict2[k] = v
+     state_dict = state_dict2
+     for k1, k2 in zip(list(state_dict.keys()), params.keys()):
+         state_dict[k2] = state_dict.pop(k1)
+
+
+     try:
+         model.load_state_dict(state_dict, strict=True)
+     except:
+         for k in params.keys():
+             if k not in state_dict.keys():
+                 state_dict[k] = params[k]
+         model.load_state_dict(state_dict, strict=True)
+
+     # Helper classes (for testing) initialization
+     Evaluator.initialize()
+     Visualizer.initialize(args.visualize, args.vispath)
+
+     # Dataset initialization
+     FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
+     dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
+
+     # Test
+     with torch.no_grad():
+         test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
+         Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
+     Logger.info('==================== Finished Testing ====================')
model/DCAMA.py ADDED
@@ -0,0 +1,625 @@
1
+ r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
2
+ from functools import reduce
3
+ from operator import add
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torchvision.models import resnet
9
+
10
+ from .base.swin_transformer import SwinTransformer
11
+ from model.base.transformer import MultiHeadedAttention, PositionalEncoding
12
+ import copy
13
+
14
+ class Flatten(nn.Module):
15
+ def forward(self, x):
16
+ return x.view(x.size(0), x.size(1), -1).contiguous()
17
+
18
+ def reshape(x, size):
19
+ size1 = torch.tensor(x.size()).float().cuda()
20
+ # x = torch.logical_not(x.cuda())
21
+ yxs = torch.stack(torch.where(x), dim=-1)
22
+ ratio = size[0] / size1[0]
23
+ yxs2 = (yxs * ratio).long()
24
+ x2 = torch.zeros((size[0], size[1])).float().cuda()
25
+ return yxs2
26
+
27
+
28
+ class DCAMA(nn.Module):
29
+
30
+ def __init__(self, backbone, pretrained_path, use_original_imgsize):
31
+ super(DCAMA, self).__init__()
32
+
33
+ self.backbone = backbone
34
+ self.use_original_imgsize = use_original_imgsize
35
+
36
+ # feature extractor initialization
37
+ if backbone == 'resnet50':
38
+ self.feature_extractor = resnet.resnet50()
39
+ self.feature_extractor.load_state_dict(torch.load(pretrained_path))
40
+ self.feat_channels = [256, 512, 1024, 2048]
41
+ self.nlayers = [3, 4, 6, 3]
42
+ self.feat_ids = list(range(0, 17))
43
+ self.last_feat_size = [12, 12]
44
+ elif backbone == 'resnet101':
45
+ self.feature_extractor = resnet.resnet101()
46
+ self.feature_extractor.load_state_dict(torch.load(pretrained_path))
47
+ self.feat_channels = [256, 512, 1024, 2048]
48
+ self.nlayers = [3, 4, 23, 3]
49
+ self.feat_ids = list(range(0, 34))
50
+ elif backbone == 'swin':
51
+ self.feature_extractor = SwinTransformer(img_size=384, patch_size=4, window_size=12, embed_dim=128,
52
+ depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32])
53
+ self.feature_extractor.load_state_dict(torch.load(pretrained_path)['model'])
54
+ self.feat_channels = [128, 256, 512, 1024]
55
+ self.nlayers = [2, 2, 18, 2]
56
+ else:
57
+ raise Exception('Unavailable backbone: %s' % backbone)
58
+ self.feature_extractor.eval()
59
+
60
+ # define model
61
+ self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(self.nlayers)])
62
+ self.stack_ids = torch.tensor(self.lids).bincount()[-4:].cumsum(dim=0)
63
+ self.model = DCAMA_model(in_channels=self.feat_channels, stack_ids=self.stack_ids)
64
+
65
+ ## TODO:
66
+
67
+ self.scorer2 = nn.ModuleList()
68
+ for layer_idx in range(len(self.nlayers)):
69
+ layer_num = self.nlayers[layer_idx]
70
+ for idx in range(layer_num):
71
+ self.scorer2.append(
72
+ nn.Sequential(
73
+ nn.Conv2d(256 * 2 ** layer_idx, 256 * 2 ** layer_idx, 1, 1),
74
+ # nn.ReLU(),
75
+ # nn.InstanceNorm2d(256 * 2 ** layer_idx),
76
+ # nn.Conv2d(256 * 2 ** layer_idx, 256 * 2 ** layer_idx, 1, 1),
77
+ )
78
+ )
79
+ self.scorer1 = nn.Sequential(
80
+ nn.Linear(sum(self.nlayers) - self.nlayers[0], 1)
81
+ )
82
+
83
+ self.cross_entropy_loss = nn.CrossEntropyLoss()
84
+
85
+ def forward(self, query_img, support_img, support_mask, nshot, predict_score=False):
86
+ n_support_feats = []
87
+ with torch.no_grad():
88
+ for k in range(nshot):
89
+ support_feats_= self.extract_feats(support_img[:, k])
90
+ support_feats = copy.deepcopy(support_feats_)
91
+ del support_feats_
92
+ torch.cuda.empty_cache()
93
+ n_support_feats.append(support_feats)
94
+ query_feats = self.extract_feats(query_img)
95
+
96
+
97
+ logit_mask = self.model(query_feats, n_support_feats, support_mask.clone(), nshot=nshot)
98
+ ## TODO:
99
+ MAX_SHOTS = 1
100
+ if len(n_support_feats) >= MAX_SHOTS:
101
+ nshot = MAX_SHOTS
102
+ n_support_query_f = []
103
+ n_simi = []
104
+ for i in range(len(n_support_feats)):
105
+ support_f = n_support_feats[i]
106
+ support_query_f = []
107
+ simi_l = []
108
+ simi_sum = []
109
+ for l in range(len(query_feats)):
110
+ if l < self.stack_ids[0]:
111
+ continue
112
+ elif l < self.stack_ids[1]:
113
+ DCAMA_blocks = self.model.DCAMA_blocks[0]
114
+ pe = self.model.pe[0]
115
+ elif l < self.stack_ids[2]:
116
+ DCAMA_blocks = self.model.DCAMA_blocks[1]
117
+ pe = self.model.pe[1]
118
+ else:
119
+ DCAMA_blocks = self.model.DCAMA_blocks[2]
120
+ pe = self.model.pe[2]
121
+ a_support_f = support_f[l].clone()
122
+ coords = reshape(support_mask[0, i], a_support_f.size()[-2:])
123
+ b, ch, w, h = a_support_f.size()
124
+ a_support_f = a_support_f.view(b, ch, -1)
125
+ a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
126
+ a_support_f = a_support_f.view(b, ch, w, h)
127
+ a_support_f = self.scorer2[l](a_support_f)
128
+ a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
129
+ # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.
130
+ query_feat = query_feats[l].view(b, ch, -1)
131
+ query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
132
+ query_feat = query_feat.view(b, ch, w, h)
133
+ query_feat = self.scorer2[l](query_feat)
134
+ simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
135
+ simi_sum.append(simi)
136
+ # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
137
+ if l == 6:
138
+ simi_map = simi.clone()
139
+ simi = simi.view(-1).mean()
140
+ simi_l.append(simi)
141
+ # simi_l = self.scorer1(torch.stack(simi_l, dim=0).unsqueeze(0)).squeeze(0)[0]
142
+ n_simi.append(torch.stack(simi_l, dim=0).mean())
143
+
144
+ n_simi = torch.stack(n_simi, dim=0)
145
+ args = n_simi.argsort(descending=True)[:MAX_SHOTS]
146
+ support_mask = support_mask[:, args, :, :]
147
+ # n_support_feats = [n_support_feats[arg] for arg in args]
148
+ n_simis = n_simi[args].max()
149
+ else:
150
+ n_simis = torch.tensor(0.).float().cuda()
151
+ return logit_mask, n_simis
152
+
153
+ def extract_feats(self, img):
154
+ r""" Extract input image features """
155
+ feats = []
156
+
157
+ if self.backbone == 'swin':
158
+ _ = self.feature_extractor.forward_features(img)
159
+ for feat in self.feature_extractor.feat_maps:
160
+ bsz, hw, c = feat.size()
161
+ h = int(hw ** 0.5)
162
+ feat = feat.view(bsz, h, h, c).permute(0, 3, 1, 2).contiguous()
163
+ feats.append(feat)
164
+ elif self.backbone == 'resnet50' or self.backbone == 'resnet101':
165
+ bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), self.nlayers)))
166
+ # Layer 0
167
+ feat = self.feature_extractor.conv1.forward(img)
168
+ feat = self.feature_extractor.bn1.forward(feat)
169
+ feat = self.feature_extractor.relu.forward(feat)
170
+ feat = self.feature_extractor.maxpool.forward(feat)
171
+
172
+ # Layer 1-4
173
+ for hid, (bid, lid) in enumerate(zip(bottleneck_ids, self.lids)):
174
+ res = feat
175
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
176
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
177
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
178
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
179
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
180
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
181
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
182
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
183
+
184
+ if bid == 0:
185
+ res = self.feature_extractor.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
186
+
187
+ feat += res
188
+
189
+ if hid + 1 in self.feat_ids:
190
+ feats.append(feat.clone())
191
+
192
+ feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
193
+
194
+ return feats
195
+
196
+ def predict_mask_nshot(self, batch, nshot):
197
+ r""" n-shot inference """
198
+ query_img = batch['query_img']
199
+ support_imgs = batch['support_imgs']
200
+ support_masks = batch['support_masks']
201
+
202
+ if nshot == 1:
203
+ with torch.no_grad():
204
+ query_feats = self.extract_feats(query_img)
205
+ n_support_feats = []
206
+ for k in range(nshot):
207
+ support_feats = self.extract_feats(support_imgs[:, k])
208
+ n_support_feats.append(support_feats)
209
+
210
+ n_simis = []
211
+ simi_map = None
212
+ for i in range(len(n_support_feats)):
213
+ support_f = n_support_feats[i]
214
+ support_query_f = []
215
+ simi_l = []
216
+ for l in range(len(query_feats)):
217
+ if l < self.stack_ids[0]:
218
+ continue
219
+ elif l < self.stack_ids[1]:
220
+ DCAMA_blocks = self.model.DCAMA_blocks[0]
221
+ pe = self.model.pe[0]
222
+ elif l < self.stack_ids[2]:
223
+ DCAMA_blocks = self.model.DCAMA_blocks[1]
224
+ pe = self.model.pe[1]
225
+ else:
226
+ DCAMA_blocks = self.model.DCAMA_blocks[2]
227
+ pe = self.model.pe[2]
228
+ a_support_f = support_f[l].clone()
229
+ coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
230
+ b, ch, w, h = a_support_f.size()
231
+ a_support_f = a_support_f.view(b, ch, -1)
232
+ a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
233
+ a_support_f = a_support_f.view(b, ch, w, h)
234
+ a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
235
+ # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.
236
+ query_feat = query_feats[l].view(b, ch, -1)
237
+ query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
238
+ query_feat = query_feat.view(b, ch, w, h)
239
+ simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
240
+ # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
241
+ if l == 13:
242
+ simi_map = simi.clone()
243
+ simi = simi.view(-1).max()
244
+ simi_l.append(simi)
245
+ simi_l = torch.stack(simi_l, dim=0).mean()
246
+ n_simis.append(simi_l)
247
+ n_simis = torch.stack(n_simis, dim=0)
248
+ logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
249
+
250
+ else:
251
+ with torch.no_grad():
252
+ query_feats = self.extract_feats(query_img)
253
+ n_support_feats = []
254
+ for k in range(nshot):
255
+ support_feats = self.extract_feats(support_imgs[:, k])
256
+ n_support_feats.append(support_feats)
257
+
258
+ ## TODO: retrieval V1 ##
259
+ MAX_SHOTS = 200
260
+ '''
261
+ if len(n_support_feats) > MAX_SHOTS:
262
+ nshot = MAX_SHOTS
263
+ n_support_query_f = []
264
+ n_simis = []
265
+ for i in range(len(n_support_feats)):
266
+ support_f = n_support_feats[i]
267
+ support_query_f = []
268
+ simi_l = []
269
+ simi_sum = []
270
+ for l in range(len(query_feats)):
271
+ if l < self.stack_ids[0]:
272
+ continue
273
+ elif l < self.stack_ids[1]:
274
+ DCAMA_blocks = self.model.DCAMA_blocks[0]
275
+ pe = self.model.pe[0]
276
+ elif l < self.stack_ids[2]:
277
+ DCAMA_blocks = self.model.DCAMA_blocks[1]
278
+ pe = self.model.pe[1]
279
+ else:
280
+ DCAMA_blocks = self.model.DCAMA_blocks[2]
281
+ pe = self.model.pe[2]
282
+ a_support_f = support_f[l].clone()
283
+ coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
284
+ b, ch, w, h = a_support_f.size()
285
+ a_support_f = a_support_f.view(b, ch, -1)
286
+ a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
287
+ a_support_f = a_support_f.view(b, ch, w, h)
288
+ a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
289
+ # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.
290
+ query_feat = query_feats[l].view(b, ch, -1)
291
+ query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
292
+ query_feat = query_feat.view(b, ch, w, h)
293
+ simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
294
+ simi_sum.append(simi)
295
+ # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
296
+ if l == 6:
297
+ simi_map = simi.clone().detach().cpu().numpy()
298
+ simi = simi.view(-1).max()
299
+ simi_l.append(simi)
300
+ simi_l = torch.stack(simi_l, dim=0).mean()
301
+ n_simis.append(simi_l)
302
+ n_simis = torch.stack(n_simis, dim=0)
303
+ # nshot = max((n_simis > 0.).sum(), 1)
304
+ nshot = len(n_simis)
305
+ support_masks = support_masks[:, n_simis.argsort(descending=True)[:nshot], :, :]
306
+ n_support_feats = [n_support_feats[i] for i in n_simis.argsort(descending=True)[:nshot]]
307
+ else:
308
+ n_simis = torch.tensor(0.).float().cuda()
309
+ simi_map = None
310
+ ## TODO: retriever V2
311
+ '''
312
+ '''
313
+ MAX_SHOTS = 30
314
+ if len(n_support_feats) > MAX_SHOTS:
315
+ nshot = MAX_SHOTS
316
+ n_support_query_f = []
317
+ n_simis = []
318
+ support_f_list = []
319
+ n_support_feats2 = []
320
+ query_feats2 = []
321
+ for i in range(len(n_support_feats)):
322
+ support_f = n_support_feats[i]
323
+ n_support_feats2_l = []
324
+ query_feats2_l = []
325
+ for l in range(len(query_feats)):
326
+ if l < self.stack_ids[0]:
327
+ continue
328
+ elif l < self.stack_ids[1]:
329
+ DCAMA_blocks = self.model.DCAMA_blocks[0]
330
+ pe = self.model.pe[0]
331
+ elif l < self.stack_ids[2]:
332
+ DCAMA_blocks = self.model.DCAMA_blocks[1]
333
+ pe = self.model.pe[1]
334
+ else:
335
+ DCAMA_blocks = self.model.DCAMA_blocks[2]
336
+ pe = self.model.pe[2]
337
+ a_support_f = support_f[l].clone()
338
+ coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
339
+ b, ch, w, h = a_support_f.size()
340
+ a_support_f = a_support_f.view(b, ch, -1)
341
+ a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
342
+ a_support_f = a_support_f.view(b, ch, w, h)
343
+ a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1)
344
+ n_support_feats2_l.append(a_support_f)
345
+ query_feat = query_feats[l].view(b, ch, -1)
346
+ query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
347
+ query_feat = query_feat.view(b, ch, w, h)
348
+ query_feats2_l.append(query_feat)
349
+ n_support_feats2.append(n_support_feats2_l)
350
+ query_feats2.append(query_feats2_l)
351
+ n_support_feats3 = [[] for _ in range(len(query_feats2[0]))]
352
+ selected = []
353
+ for i in range(MAX_SHOTS):
354
+ simi_min = -100
355
+ idx_min = -1
356
+ for idx in range(len(n_support_feats2)):
357
+ if idx in selected:
358
+ continue
359
+ support_feats2 = n_support_feats2[idx]
360
+ simi = []
361
+ for l in range(len(query_feats2[i])):
362
+ support_feats_avg = torch.stack(n_support_feats3[l] + [support_feats2[l]], dim=0).mean(0)
363
+ query_feat = query_feats2[i][l]
364
+ a_support_f = support_feats_avg.unsqueeze(-1).unsqueeze(-1).repeat(
365
+ (1, 1, query_feat.size(-2), query_feat.size(-1)))
366
+ simi_l = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(
367
+ a_support_f, dim=1))[0].view(-1).max()
368
+ simi.append(simi_l)
369
+ simi = torch.stack(simi, dim=0).mean()
370
+ if simi > simi_min:
371
+ simi_min = simi
372
+ idx_min = idx
373
+ support_feats2_argmin = n_support_feats2[idx]
374
+ for l2 in range(len(query_feats2[0])):
375
+ n_support_feats3[l2].append(n_support_feats2[idx_min][l2])
376
+ selected.append(idx_min)
377
+ n_support_feats4 = []
378
+ for idx in selected:
379
+ n_support_feats4.append(n_support_feats[idx])
380
+ support_masks = support_masks[:, torch.tensor(selected).long().cuda(), :, :]
381
+ n_support_feats = n_support_feats4
382
+ simi_map = None
383
+ else:
384
+ n_simis = torch.tensor(0.).float().cuda()
385
+ simi_map = None
386
+ '''
387
+ ## TODO: v3
388
+
389
+ MAX_SHOTS = 200
390
+ if len(n_support_feats) > MAX_SHOTS:
391
+ nshot = MAX_SHOTS
392
+ n_support_query_f = []
393
+ n_simis = []
394
+ support_f_list = []
395
+ n_support_feats2 = []
396
+ query_feats2 = []
397
+ for i in range(len(n_support_feats)):
398
+ support_f = n_support_feats[i]
399
+ n_support_feats2_l = []
400
+ query_feats2_l = []
401
+ for l in range(len(query_feats)):
402
+ if l < self.stack_ids[0]:
403
+ continue
404
+ elif l < self.stack_ids[1]:
405
+ DCAMA_blocks = self.model.DCAMA_blocks[0]
406
+ pe = self.model.pe[0]
407
+ elif l < self.stack_ids[2]:
408
+ DCAMA_blocks = self.model.DCAMA_blocks[1]
409
+ pe = self.model.pe[1]
410
+ else:
411
+ DCAMA_blocks = self.model.DCAMA_blocks[2]
412
+ pe = self.model.pe[2]
413
+ a_support_f = support_f[l].clone()
414
+ coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
415
+ b, ch, w, h = a_support_f.size()
416
+ a_support_f = a_support_f.view(b, ch, -1)
417
+ a_support_f_tmp = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
418
+ a_support_f = a_support_f_tmp / a_support_f_tmp.norm(dim=1, keepdim=True) * DCAMA_blocks.linears[1](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
419
+ a_support_f = a_support_f.view(b, ch, w, h)
420
+ a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1)
421
+ n_support_feats2_l.append(a_support_f)
422
+ query_feat = query_feats[l].view(b, ch, -1)
423
+ query_feat_tmp = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
424
+ query_feat = query_feat_tmp / query_feat_tmp.norm(dim=1, keepdim=True) * DCAMA_blocks.linears[1](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
425
+ query_feat = query_feat.view(b, ch, w, h)
426
+ query_feats2_l.append(query_feat)
427
+ n_support_feats2.append(n_support_feats2_l)
428
+ query_feats2.append(query_feats2_l)
429
+ n_support_feats3 = [[] for _ in range(len(query_feats2[0]))]
430
+ selected = []
431
+ for i in range(MAX_SHOTS):
432
+ simi_min = -100
433
+ idx_min = -1
434
+ for idx in range(len(n_support_feats2)):
435
+ if idx in selected:
436
+ continue
437
+ support_feats2 = n_support_feats2[idx]
438
+ simi = []
439
+ for l in range(len(query_feats2[i])):
440
+ support_feats_avg = torch.stack(n_support_feats3[l] + [support_feats2[l]], dim=0).mean(0)
441
+ query_feat = query_feats2[i][l]
442
+ a_support_f = support_feats_avg.unsqueeze(-1).unsqueeze(-1).repeat(
443
+ (1, 1, query_feat.size(-2), query_feat.size(-1)))
444
+ simi_l = ((query_feat * a_support_f).sum(1))[0].view(-1).max()
445
+ simi.append(simi_l)
446
+ simi = torch.stack(simi, dim=0).mean()
447
+ if simi > simi_min:
448
+ simi_min = simi
449
+ idx_min = idx
450
+ support_feats2_argmin = n_support_feats2[idx]
451
+ for l2 in range(len(query_feats2[0])):
452
+ n_support_feats3[l2].append(n_support_feats2[idx_min][l2])
453
+ selected.append(idx_min)
454
+ n_support_feats4 = []
455
+ for idx in selected:
456
+ n_support_feats4.append(n_support_feats[idx])
457
+ support_masks = support_masks[:, torch.tensor(selected).long().cuda(), :, :]
458
+ n_support_feats = n_support_feats4
459
+ simi_map = None
460
+ else:
461
+ n_simis = torch.tensor(0.).float().cuda()
462
+ simi_map = None
463
+
464
+ logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
465
+
466
+ if self.use_original_imgsize:
467
+ org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()])
468
+ logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
469
+ else:
470
+ logit_mask = F.interpolate(logit_mask, support_imgs[0].size()[2:], mode='bilinear', align_corners=True)
471
+
472
+ return logit_mask.argmax(dim=1), n_simis, simi_map
473
+
474
+ def compute_objective(self, logit_mask, gt_mask):
475
+ bsz = logit_mask.size(0)
476
+ logit_mask = logit_mask.view(bsz, 2, -1)
477
+ gt_mask = gt_mask.view(bsz, -1).long()
478
+
479
+ return self.cross_entropy_loss(logit_mask, gt_mask)
480
+
481
+ def train_mode(self):
482
+ self.train()
483
+ self.feature_extractor.eval()
484
+
485
+
486
+ class DCAMA_model(nn.Module):
487
+ def __init__(self, in_channels, stack_ids):
488
+ super(DCAMA_model, self).__init__()
489
+
490
+ self.stack_ids = stack_ids
491
+
492
+ # DCAMA blocks
493
+ self.DCAMA_blocks = nn.ModuleList()
494
+ self.pe = nn.ModuleList()
495
+ for inch in in_channels[1:]:
496
+ self.DCAMA_blocks.append(MultiHeadedAttention(h=8, d_model=inch, dropout=0.5))
497
+ self.pe.append(PositionalEncoding(d_model=inch, dropout=0.5))
498
+
499
+ outch1, outch2, outch3 = 16, 64, 128
500
+
501
+ # conv blocks
502
+ self.conv1 = self.build_conv_block(stack_ids[3]-stack_ids[2], [outch1, outch2, outch3], [3, 3, 3], [1, 1, 1]) # 1/32
503
+ self.conv2 = self.build_conv_block(stack_ids[2]-stack_ids[1], [outch1, outch2, outch3], [5, 3, 3], [1, 1, 1]) # 1/16
504
+ self.conv3 = self.build_conv_block(stack_ids[1]-stack_ids[0], [outch1, outch2, outch3], [5, 5, 3], [1, 1, 1]) # 1/8
505
+
506
+ self.conv4 = self.build_conv_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) # 1/32 + 1/16
507
+ self.conv5 = self.build_conv_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1]) # 1/16 + 1/8
508
+
509
+ # mixer blocks
510
+ self.mixer1 = nn.Sequential(nn.Conv2d(outch3+2*in_channels[1]+2*in_channels[0], outch3, (3, 3), padding=(1, 1), bias=True),
511
+ nn.ReLU(),
512
+ nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
513
+ nn.ReLU())
514
+
515
+ self.mixer2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
516
+ nn.ReLU(),
517
+ nn.Conv2d(outch2, outch1, (3, 3), padding=(1, 1), bias=True),
518
+ nn.ReLU())
519
+
520
+ self.mixer3 = nn.Sequential(nn.Conv2d(outch1, outch1, (3, 3), padding=(1, 1), bias=True),
521
+ nn.ReLU(),
522
+ nn.Conv2d(outch1, 2, (3, 3), padding=(1, 1), bias=True))
523
+
524
+
525
+
526
+ def forward(self, query_feats, support_feats, support_mask, nshot=1):
527
+ coarse_masks = []
528
+ for idx, query_feat in enumerate(query_feats):
529
+ # 1/4 scale feature only used in skip connect
530
+ if idx < self.stack_ids[0]: continue
531
+
532
+ bsz, ch, ha, wa = query_feat.size()
533
+
534
+ # reshape the input feature and mask
535
+ query = query_feat.view(bsz, ch, -1).permute(0, 2, 1).contiguous()
536
+ # if nshot == 1:
537
+ # support_feat = support_feats[idx]
538
+ # mask = F.interpolate(support_mask.unsqueeze(1).float(), support_feat.size()[2:], mode='bilinear',
539
+ # align_corners=True).view(support_feat.size()[0], -1)
540
+ # support_feat = support_feat.view(support_feat.size()[0], support_feat.size()[1], -1).permute(0, 2, 1).contiguous()
541
+ # else:
542
+ support_feat = torch.stack([support_feats[k][idx] for k in range(nshot)])
543
+ support_feat = support_feat.view(-1, ch, ha * wa).permute(0, 2, 1).contiguous()
544
+ mask = torch.stack([F.interpolate(k.unsqueeze(1).float(), (ha, wa), mode='bilinear', align_corners=True)
545
+ for k in support_mask])
546
+ mask = mask.view(bsz, -1)
547
+
548
+ # DCAMA blocks forward
549
+ DCAMA_blocks = None
550
+ pe = None
551
+ if idx < self.stack_ids[1]:
552
+ DCAMA_blocks = self.DCAMA_blocks[0]
553
+ pe = self.pe[0]
554
+ elif idx < self.stack_ids[2]:
555
+ DCAMA_blocks = self.DCAMA_blocks[1]
556
+ pe = self.pe[1]
557
+ else:
558
+ DCAMA_blocks = self.DCAMA_blocks[2]
559
+ pe = self.pe[2]
560
+ coarse_mask = DCAMA_blocks(pe(query), pe(support_feat), mask)
561
+ coarse_masks.append(coarse_mask.permute(0, 2, 1).contiguous().view(bsz, 1, ha, wa))
562
+
563
+
564
+ # multi-scale conv blocks forward
565
+ bsz, ch, ha, wa = coarse_masks[self.stack_ids[3]-1-self.stack_ids[0]].size()
566
+ coarse_masks1 = torch.stack(coarse_masks[self.stack_ids[2]-self.stack_ids[0]:self.stack_ids[3]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
567
+ bsz, ch, ha, wa = coarse_masks[self.stack_ids[2]-1-self.stack_ids[0]].size()
568
+ coarse_masks2 = torch.stack(coarse_masks[self.stack_ids[1]-self.stack_ids[0]:self.stack_ids[2]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
569
+ bsz, ch, ha, wa = coarse_masks[self.stack_ids[1]-1-self.stack_ids[0]].size()
570
+ coarse_masks3 = torch.stack(coarse_masks[0:self.stack_ids[1]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
571
+
572
+ coarse_masks1 = self.conv1(coarse_masks1)
573
+ coarse_masks2 = self.conv2(coarse_masks2)
574
+ coarse_masks3 = self.conv3(coarse_masks3)
575
+
576
+ # multi-scale cascade (pixel-wise addition)
577
+ coarse_masks1 = F.interpolate(coarse_masks1, coarse_masks2.size()[-2:], mode='bilinear', align_corners=True)
578
+ mix = coarse_masks1 + coarse_masks2
579
+ mix = self.conv4(mix)
580
+
581
+ mix = F.interpolate(mix, coarse_masks3.size()[-2:], mode='bilinear', align_corners=True)
582
+ mix = mix + coarse_masks3
583
+ mix = self.conv5(mix)
584
+
585
+ # skip connect 1/8 and 1/4 features (concatenation)
586
+ # if nshot == 1:
587
+ # support_feat = support_feats[self.stack_ids[1] - 1]
588
+ # else:
589
+ support_feat = torch.stack([support_feats[k][self.stack_ids[1] - 1] for k in range(nshot)]).max(dim=0).values
590
+ mix = torch.cat((mix, query_feats[self.stack_ids[1] - 1], support_feat), 1)
591
+
592
+ upsample_size = (mix.size(-1) * 2,) * 2
593
+ mix = F.interpolate(mix, upsample_size, mode='bilinear', align_corners=True)
594
+ # if nshot == 1:
595
+ # support_feat = support_feats[self.stack_ids[0] - 1]
596
+ # else:
597
+ support_feat = torch.stack([support_feats[k][self.stack_ids[0] - 1] for k in range(nshot)]).max(dim=0).values
598
+ mix = torch.cat((mix, query_feats[self.stack_ids[0] - 1], support_feat), 1)
599
+
600
+ # mixer blocks forward
601
+ out = self.mixer1(mix)
602
+ upsample_size = (out.size(-1) * 2,) * 2
603
+ out = F.interpolate(out, upsample_size, mode='bilinear', align_corners=True)
604
+ out = self.mixer2(out)
605
+ upsample_size = (out.size(-1) * 2,) * 2
606
+ out = F.interpolate(out, upsample_size, mode='bilinear', align_corners=True)
607
+ logit_mask = self.mixer3(out)
608
+
609
+ return logit_mask
610
+
611
+ def build_conv_block(self, in_channel, out_channels, kernel_sizes, spt_strides, group=4):
612
+ r""" bulid conv blocks """
613
+ assert len(out_channels) == len(kernel_sizes) == len(spt_strides)
614
+
615
+ building_block_layers = []
616
+ for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)):
617
+ inch = in_channel if idx == 0 else out_channels[idx - 1]
618
+ pad = ksz // 2
619
+
620
+ building_block_layers.append(nn.Conv2d(in_channels=inch, out_channels=outch,
621
+ kernel_size=ksz, stride=stride, padding=pad))
622
+ building_block_layers.append(nn.GroupNorm(group, outch))
623
+ building_block_layers.append(nn.ReLU(inplace=True))
624
+
625
+ return nn.Sequential(*building_block_layers)
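The support-ranking logic added in forward() and predict_mask_nshot() reduces to a cosine similarity between each query location and the masked average of the (projected) support feature. A minimal sketch of that core step on dummy tensors, with the positional-encoding and linear-projection steps omitted and all shapes assumed for illustration:

# Sketch: cosine similarity between a masked-average support feature and every query location.
import torch

b, ch, h, w = 1, 64, 12, 12                                             # dummy feature shape
query_feat = torch.randn(b, ch, h, w)
support_feat = torch.randn(b, ch, h, w)
support_mask = torch.zeros(h, w)
support_mask[3:8, 3:8] = 1                                              # dummy foreground region

coords = torch.stack(torch.where(support_mask > 0), dim=-1)             # (N, 2) foreground coords
support_vec = support_feat[:, :, coords[:, 0], coords[:, 1]].mean(-1)   # (b, ch) masked average
support_map = support_vec[:, :, None, None].expand(-1, -1, h, w)
simi = (query_feat * support_map).sum(1) / (query_feat.norm(dim=1) * support_map.norm(dim=1))
score = simi.view(-1).max()                                             # per-shot score used to rank supports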
model/__pycache__/DCAMA.cpython-38.pyc ADDED
Binary file (13.8 kB).
 
model/__pycache__/DCAMA.cpython-39.pyc ADDED
Binary file (13.3 kB).
 
model/base/__pycache__/swin_transformer.cpython-38.pyc ADDED
Binary file (20.6 kB).
 
model/base/__pycache__/swin_transformer.cpython-39.pyc ADDED
Binary file (20.5 kB).
 
model/base/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (3.61 kB).
 
model/base/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (3.68 kB).
 
model/base/swin_transformer.py ADDED
@@ -0,0 +1,605 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x):
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
31
+
32
+
33
+ def window_partition(x, window_size):
34
+ """
35
+ Args:
36
+ x: (B, H, W, C)
37
+ window_size (int): window size
38
+
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+
56
+ Returns:
57
+ x: (B, H, W, C)
58
+ """
59
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
60
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
61
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
62
+ return x
63
+
64
+
65
+ class WindowAttention(nn.Module):
66
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
67
+ It supports both of shifted and non-shifted window.
68
+
69
+ Args:
70
+ dim (int): Number of input channels.
71
+ window_size (tuple[int]): The height and width of the window.
72
+ num_heads (int): Number of attention heads.
73
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
74
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
75
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
76
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
77
+ """
78
+
79
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
80
+
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.window_size = window_size # Wh, Ww
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ self.scale = qk_scale or head_dim ** -0.5
87
+
88
+ # define a parameter table of relative position bias
89
+ self.relative_position_bias_table = nn.Parameter(
90
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
91
+
92
+ # get pair-wise relative position index for each token inside the window
93
+ coords_h = torch.arange(self.window_size[0])
94
+ coords_w = torch.arange(self.window_size[1])
95
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
96
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
97
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
98
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
99
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
100
+ relative_coords[:, :, 1] += self.window_size[1] - 1
101
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
102
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ self.register_buffer("relative_position_index", relative_position_index)
104
+
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ trunc_normal_(self.relative_position_bias_table, std=.02)
111
+ self.softmax = nn.Softmax(dim=-1)
112
+
113
+ def forward(self, x, mask=None):
114
+ """
115
+ Args:
116
+ x: input features with shape of (num_windows*B, N, C)
117
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
118
+ """
119
+ B_, N, C = x.shape
120
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
122
+
123
+ q = q * self.scale
124
+ attn = (q @ k.transpose(-2, -1))
125
+
126
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
127
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
128
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
129
+ attn = attn + relative_position_bias.unsqueeze(0)
130
+
131
+ if mask is not None:
132
+ nW = mask.shape[0]
133
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
134
+ attn = attn.view(-1, self.num_heads, N, N)
135
+ attn = self.softmax(attn)
136
+ else:
137
+ attn = self.softmax(attn)
138
+
139
+ attn = self.attn_drop(attn)
140
+
141
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
142
+ x = self.proj(x)
143
+ x = self.proj_drop(x)
144
+ return x
145
+
146
+ def extra_repr(self) -> str:
147
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
148
+
149
+ def flops(self, N):
150
+ # calculate flops for 1 window with token length of N
151
+ flops = 0
152
+ # qkv = self.qkv(x)
153
+ flops += N * self.dim * 3 * self.dim
154
+ # attn = (q @ k.transpose(-2, -1))
155
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
156
+ # x = (attn @ v)
157
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
158
+ # x = self.proj(x)
159
+ flops += N * self.dim * self.dim
160
+ return flops
161
+
162
+
163
+ class SwinTransformerBlock(nn.Module):
164
+ r""" Swin Transformer Block.
165
+
166
+ Args:
167
+ dim (int): Number of input channels.
168
+ input_resolution (tuple[int]): Input resolution.
169
+ num_heads (int): Number of attention heads.
170
+ window_size (int): Window size.
171
+ shift_size (int): Shift size for SW-MSA.
172
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
173
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
174
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
175
+ drop (float, optional): Dropout rate. Default: 0.0
176
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
177
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
178
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
179
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
180
+ """
181
+
182
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
183
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
184
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.input_resolution = input_resolution
188
+ self.num_heads = num_heads
189
+ self.window_size = window_size
190
+ self.shift_size = shift_size
191
+ self.mlp_ratio = mlp_ratio
192
+ if min(self.input_resolution) <= self.window_size:
193
+ # if window size is larger than input resolution, we don't partition windows
194
+ self.shift_size = 0
195
+ self.window_size = min(self.input_resolution)
196
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
197
+
198
+ self.norm1 = norm_layer(dim)
199
+ self.attn = WindowAttention(
200
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
201
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
202
+
203
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
204
+ self.norm2 = norm_layer(dim)
205
+ mlp_hidden_dim = int(dim * mlp_ratio)
206
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
207
+
208
+ if self.shift_size > 0:
209
+ # calculate attention mask for SW-MSA
210
+ H, W = self.input_resolution
211
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
212
+ h_slices = (slice(0, -self.window_size),
213
+ slice(-self.window_size, -self.shift_size),
214
+ slice(-self.shift_size, None))
215
+ w_slices = (slice(0, -self.window_size),
216
+ slice(-self.window_size, -self.shift_size),
217
+ slice(-self.shift_size, None))
218
+ cnt = 0
219
+ for h in h_slices:
220
+ for w in w_slices:
221
+ img_mask[:, h, w, :] = cnt
222
+ cnt += 1
223
+
224
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
225
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
226
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
227
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
228
+ else:
229
+ attn_mask = None
230
+
231
+ self.register_buffer("attn_mask", attn_mask)
232
+
233
+ def forward(self, x):
234
+ H, W = self.input_resolution
235
+ B, L, C = x.shape
236
+ assert L == H * W, "input feature has wrong size"
237
+
238
+ shortcut = x
239
+ x = self.norm1(x)
240
+ x = x.view(B, H, W, C)
241
+
242
+ # cyclic shift
243
+ if self.shift_size > 0:
244
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
245
+ else:
246
+ shifted_x = x
247
+
248
+ # partition windows
249
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
250
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
251
+
252
+ # W-MSA/SW-MSA
253
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
254
+
255
+ # merge windows
256
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
257
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
258
+
259
+ # reverse cyclic shift
260
+ if self.shift_size > 0:
261
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
262
+ else:
263
+ x = shifted_x
264
+ x = x.view(B, H * W, C)
265
+
266
+ # FFN
267
+ x = shortcut + self.drop_path(x)
268
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
269
+
270
+ return x
271
+
272
+ def extra_repr(self) -> str:
273
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
274
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
275
+
276
+ def flops(self):
277
+ flops = 0
278
+ H, W = self.input_resolution
279
+ # norm1
280
+ flops += self.dim * H * W
281
+ # W-MSA/SW-MSA
282
+ nW = H * W / self.window_size / self.window_size
283
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
284
+ # mlp
285
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
286
+ # norm2
287
+ flops += self.dim * H * W
288
+ return flops
289
+
290
+
291
+ class PatchMerging(nn.Module):
292
+ r""" Patch Merging Layer.
293
+
294
+ Args:
295
+ input_resolution (tuple[int]): Resolution of input feature.
296
+ dim (int): Number of input channels.
297
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
298
+ """
299
+
300
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
301
+ super().__init__()
302
+ self.input_resolution = input_resolution
303
+ self.dim = dim
304
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
305
+ self.norm = norm_layer(4 * dim)
306
+
307
+ def forward(self, x):
308
+ """
309
+ x: B, H*W, C
310
+ """
311
+ H, W = self.input_resolution
312
+ B, L, C = x.shape
313
+ assert L == H * W, "input feature has wrong size"
314
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
315
+
316
+ x = x.view(B, H, W, C)
317
+
318
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
319
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
320
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
321
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
322
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
323
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
324
+
325
+ x = self.norm(x)
326
+ x = self.reduction(x)
327
+
328
+ return x
329
+
330
+ def extra_repr(self) -> str:
331
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
332
+
333
+ def flops(self):
334
+ H, W = self.input_resolution
335
+ flops = H * W * self.dim
336
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
337
+ return flops
338
+
339
+
340
+ class BasicLayer(nn.Module):
341
+ """ A basic Swin Transformer layer for one stage.
342
+
343
+ Args:
344
+ dim (int): Number of input channels.
345
+ input_resolution (tuple[int]): Input resolution.
346
+ depth (int): Number of blocks.
347
+ num_heads (int): Number of attention heads.
348
+ window_size (int): Local window size.
349
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
350
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
351
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
352
+ drop (float, optional): Dropout rate. Default: 0.0
353
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
354
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
355
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
356
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
357
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
358
+ """
359
+
360
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
361
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
362
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
363
+
364
+ super().__init__()
365
+ self.dim = dim
366
+ self.input_resolution = input_resolution
367
+ self.depth = depth
368
+ self.use_checkpoint = use_checkpoint
369
+
370
+ # build blocks
371
+ self.blocks = nn.ModuleList([
372
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
373
+ num_heads=num_heads, window_size=window_size,
374
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
375
+ mlp_ratio=mlp_ratio,
376
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
377
+ drop=drop, attn_drop=attn_drop,
378
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
379
+ norm_layer=norm_layer)
380
+ for i in range(depth)])
381
+
382
+ # patch merging layer
383
+ if downsample is not None:
384
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
385
+ else:
386
+ self.downsample = None
387
+
388
+ def forward(self, x):
389
+ feats = []
390
+ for blk in self.blocks:
391
+ if self.use_checkpoint:
392
+ x = checkpoint.checkpoint(blk, x)
393
+ else:
394
+ x = blk(x)
395
+ feats.append(x.clone().detach())
396
+ if self.downsample is not None:
397
+ x = self.downsample(x)
398
+ return feats, x
399
+
400
+ def extra_repr(self) -> str:
401
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
402
+
403
+ def flops(self):
404
+ flops = 0
405
+ for blk in self.blocks:
406
+ flops += blk.flops()
407
+ if self.downsample is not None:
408
+ flops += self.downsample.flops()
409
+ return flops
410
+
411
+
412
+ class PatchEmbed(nn.Module):
413
+ r""" Image to Patch Embedding
414
+
415
+ Args:
416
+ img_size (int): Image size. Default: 224.
417
+ patch_size (int): Patch token size. Default: 4.
418
+ in_chans (int): Number of input image channels. Default: 3.
419
+ embed_dim (int): Number of linear projection output channels. Default: 96.
420
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
421
+ """
422
+
423
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
424
+ super().__init__()
425
+ img_size = to_2tuple(img_size)
426
+ patch_size = to_2tuple(patch_size)
427
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
428
+ self.img_size = img_size
429
+ self.patch_size = patch_size
430
+ self.patches_resolution = patches_resolution
431
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
432
+
433
+ self.in_chans = in_chans
434
+ self.embed_dim = embed_dim
435
+
436
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
437
+ if norm_layer is not None:
438
+ self.norm = norm_layer(embed_dim)
439
+ else:
440
+ self.norm = None
441
+
442
+ def forward(self, x):
443
+ B, C, H, W = x.shape
444
+ # FIXME look at relaxing size constraints
445
+ assert H == self.img_size[0] and W == self.img_size[1], \
446
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
447
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
448
+ if self.norm is not None:
449
+ x = self.norm(x)
450
+ return x
451
+
452
+ def flops(self):
453
+ Ho, Wo = self.patches_resolution
454
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
455
+ if self.norm is not None:
456
+ flops += Ho * Wo * self.embed_dim
457
+ return flops
458
+
459
+
460
+ class SwinTransformer(nn.Module):
461
+ r""" Swin Transformer
462
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
463
+ https://arxiv.org/pdf/2103.14030
464
+
465
+ Args:
466
+ img_size (int | tuple(int)): Input image size. Default 224
467
+ patch_size (int | tuple(int)): Patch size. Default: 4
468
+ in_chans (int): Number of input image channels. Default: 3
469
+ num_classes (int): Number of classes for classification head. Default: 1000
470
+ embed_dim (int): Patch embedding dimension. Default: 96
471
+ depths (tuple(int)): Depth of each Swin Transformer layer.
472
+ num_heads (tuple(int)): Number of attention heads in different layers.
473
+ window_size (int): Window size. Default: 7
474
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
475
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
476
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
477
+ drop_rate (float): Dropout rate. Default: 0
478
+ attn_drop_rate (float): Attention dropout rate. Default: 0
479
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
480
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
481
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
482
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
483
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
484
+ """
485
+
486
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
487
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
488
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
489
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
490
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
491
+ use_checkpoint=False, feat_ids=[1, 2, 3, 4], **kwargs):
492
+ super().__init__()
493
+
494
+ self.num_classes = num_classes
495
+ self.num_layers = len(depths)
496
+ self.embed_dim = embed_dim
497
+ self.ape = ape
498
+ self.patch_norm = patch_norm
499
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
500
+ self.mlp_ratio = mlp_ratio
501
+
502
+ # split image into non-overlapping patches
503
+ self.patch_embed = PatchEmbed(
504
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
505
+ norm_layer=norm_layer if self.patch_norm else None)
506
+ num_patches = self.patch_embed.num_patches
507
+ patches_resolution = self.patch_embed.patches_resolution
508
+ self.patches_resolution = patches_resolution
509
+
510
+ # absolute position embedding
511
+ if self.ape:
512
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
513
+ trunc_normal_(self.absolute_pos_embed, std=.02)
514
+
515
+ self.pos_drop = nn.Dropout(p=drop_rate)
516
+
517
+ # stochastic depth
518
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
519
+
520
+ # build layers
521
+ self.layers = nn.ModuleList()
522
+ for i_layer in range(self.num_layers):
523
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
524
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
525
+ patches_resolution[1] // (2 ** i_layer)),
526
+ depth=depths[i_layer],
527
+ num_heads=num_heads[i_layer],
528
+ window_size=window_size,
529
+ mlp_ratio=self.mlp_ratio,
530
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
531
+ drop=drop_rate, attn_drop=attn_drop_rate,
532
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
533
+ norm_layer=norm_layer,
534
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
535
+ use_checkpoint=use_checkpoint)
536
+ self.layers.append(layer)
537
+
538
+ self.norm = norm_layer(self.num_features)
539
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
540
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
541
+ self.feat_ids = feat_ids
542
+
543
+ self.apply(self._init_weights)
544
+
545
+ def _init_weights(self, m):
546
+ if isinstance(m, nn.Linear):
547
+ trunc_normal_(m.weight, std=.02)
548
+ if isinstance(m, nn.Linear) and m.bias is not None:
549
+ nn.init.constant_(m.bias, 0)
550
+ elif isinstance(m, nn.LayerNorm):
551
+ nn.init.constant_(m.bias, 0)
552
+ nn.init.constant_(m.weight, 1.0)
553
+
554
+ @torch.jit.ignore
555
+ def no_weight_decay(self):
556
+ return {'absolute_pos_embed'}
557
+
558
+ @torch.jit.ignore
559
+ def no_weight_decay_keywords(self):
560
+ return {'relative_position_bias_table'}
561
+
562
+ def forward_features(self, x):
563
+ x = self.patch_embed(x)
564
+ if self.ape:
565
+ x = x + self.absolute_pos_embed
566
+ x = self.pos_drop(x)
567
+
568
+ self.feat_maps = []
569
+ for i, layer in enumerate(self.layers):
570
+ feats, x = layer(x)
571
+ if i+1 in self.feat_ids:
572
+ self.feat_maps += feats
573
+
574
+ x = self.norm(x) # B L C
575
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
576
+ x = torch.flatten(x, 1)
577
+ return x
578
+
579
+ def forward(self, x):
580
+ x = self.forward_features(x)
581
+ x = self.head(x)
582
+ return x
583
+
584
+ def flops(self):
585
+ flops = 0
586
+ flops += self.patch_embed.flops()
587
+ for i, layer in enumerate(self.layers):
588
+ flops += layer.flops()
589
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
590
+ flops += self.num_features * self.num_classes
591
+ return flops
592
+
593
+
594
+ if __name__ == '__main__':
595
+ input = torch.randn(2, 3, 384, 384).cuda()
596
+
597
+ net = SwinTransformer(img_size=384, patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
598
+ net.load_state_dict(torch.load("/apdcephfs/share_1290796/shixinyu/checkpoints/swin_base_patch4_window12_384_22kto1k.pth")['model'])
599
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
600
+ net.to(device)
601
+
602
+ out = net.forward_features(input)
603
+ feat = net.feat_maps
604
+ for x in feat:
605
+ print(x.shape)
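A quick sanity-check sketch (illustrative shapes, not part of this commit) for the window helpers defined above: window_partition followed by window_reverse reconstructs the original feature map exactly.

    import torch
    from model.base.swin_transformer import window_partition, window_reverse

    x = torch.randn(2, 12, 12, 96)                        # (B, H, W, C); H and W divisible by the window size
    windows = window_partition(x, window_size=4)          # (2 * 3 * 3, 4, 4, 96)
    x_back = window_reverse(windows, window_size=4, H=12, W=12)
    assert torch.equal(x, x_back)                         # partition + reverse is lossless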
model/base/transformer.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import math, copy
6
+ from torch.autograd import Variable
7
+
8
+ TRAIN = False
9
+
10
+ class MultiHeadedAttention(nn.Module):
11
+ def __init__(self, h, d_model, dropout=0.1):
12
+ "Take in model size and number of heads."
13
+ super(MultiHeadedAttention, self).__init__()
14
+ assert d_model % h == 0
15
+ # We assume d_v always equals d_k
16
+ self.d_k = d_model // h
17
+ self.h = h
18
+ self.linears = clones(nn.Linear(d_model, d_model), 2)
19
+ self.attn = None
20
+ self.dropout = nn.Dropout(p=dropout)
21
+
22
+ def forward(self, query, key, value, mask=None):
23
+ if mask is not None:
24
+ # Same mask applied to all h heads.
25
+ mask = mask.unsqueeze(1)
26
+ nbatches = query.size(0)
27
+
28
+ # 1) Do all the linear projections in batch from d_model => h x d_k
29
+ query, key = \
30
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
31
+ for l, x in zip(self.linears, (query, key))]
32
+
33
+ value = value.repeat(self.h, 1, 1).transpose(0, 1).contiguous().unsqueeze(-1)
34
+
35
+ # query_dir, key_dir = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip([self.linears[0], self.linears[0]], (query, key))]
36
+ # query_norm = self.linears[1](query)[:, :, :self.h].view(nbatches, -1, self.h).transpose(1, 2)
37
+ # key_norm = self.linears[1](key)[:, :, :self.h].view(nbatches, -1, self.h).transpose(1, 2)
38
+ # query = query_dir / query_dir.norm(dim=-1).unsqueeze(-1) * 10 * query_norm.unsqueeze(-1)
39
+ # key = key_dir / key_dir.norm(dim=-1).unsqueeze(-1) * 10 * key_norm.unsqueeze(-1)
40
+
41
+ if not TRAIN:
42
+ query = query.detach().cpu()
43
+ key = key.detach().cpu()
44
+ value = value.detach().cpu()
45
+ # 2) Apply attention on all the projected vectors in batch.
46
+ x, self.attn = attention(query, key, value, mask=mask,
47
+ dropout=self.dropout)
48
+ if not TRAIN:
49
+ x = x.cuda()
50
+
51
+ # 3) "Concat" using a view and apply a final linear.
52
+ return torch.mean(x, -3)
53
+
54
+
55
+ class PositionalEncoding(nn.Module):
56
+ "Implement the PE function."
57
+
58
+ def __init__(self, d_model, dropout, max_len=10000):
59
+ super(PositionalEncoding, self).__init__()
60
+ self.dropout = nn.Dropout(p=dropout)
61
+
62
+ # Compute the positional encodings once in log space.
63
+ pe = torch.zeros(max_len, d_model)
64
+ position = torch.arange(0, max_len).unsqueeze(1)
65
+ div_term = torch.exp(torch.arange(0, d_model, 2) *
66
+ -(math.log(10000.0) / d_model))
67
+ pe[:, 0::2] = torch.sin(position * div_term)
68
+ pe[:, 1::2] = torch.cos(position * div_term)
69
+ pe = pe.unsqueeze(0)
70
+ self.register_buffer('pe', pe)
71
+
72
+ def forward(self, x):
73
+ x = x + Variable(self.pe[:, :x.size(1)],
74
+ requires_grad=False)
75
+ return self.dropout(x)
76
+
77
+ importance = torch.tensor(0.).float().cuda()
78
+ cnt = 0
79
+
80
+ def attention(query, key, value, mask=None, dropout=None):
81
+ "Compute 'Scaled Dot Product Attention'"
82
+ d_k = query.size(-1)
83
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
84
+ / math.sqrt(d_k)
85
+ if mask is not None:
86
+ scores = scores.masked_fill(mask == 0, -1e9)
87
+ p_attn = F.softmax(scores, dim=-1)
88
+ # global importance, cnt
89
+ # im = p_attn[:, :, :, :query.size(2)].max(2)[0].mean()
90
+ # importance += im
91
+ # cnt += 1
92
+ if dropout is not None:
93
+ p_attn = dropout(p_attn)
94
+ return torch.matmul(p_attn, value), p_attn
95
+
96
+
97
+ def clones(module, N):
98
+ "Produce N identical layers."
99
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
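A small sketch (illustrative tensors, not part of this commit) of the attention() helper above, in the shape regime MultiHeadedAttention uses it: queries and keys are projected features, while the values carry one mask score per support token.

    import torch
    from model.base.transformer import attention

    q = torch.randn(1, 2, 5, 8)        # (batch, heads, query tokens, d_k)
    k = torch.randn(1, 2, 7, 8)        # (batch, heads, support tokens, d_k)
    v = torch.randn(1, 2, 7, 1)        # one mask score per support token
    out, p_attn = attention(q, k, v)
    print(out.shape, p_attn.shape)     # torch.Size([1, 2, 5, 1]) torch.Size([1, 2, 5, 7])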
modelsize_estimate.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ def modelsize(model, input, type_size=4):
7
+ para = sum([np.prod(list(p.size())) for p in model.parameters()])
8
+ # print('Model {} : Number of params: {}'.format(model._get_name(), para))
9
+ print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))
10
+
11
+ input_ = input.clone()
12
+ input_.requires_grad_(requires_grad=False)
13
+
14
+ mods = list(model.modules())
15
+ out_sizes = []
16
+
17
+ for i in range(1, len(mods)):
18
+ m = mods[i]
19
+ if isinstance(m, nn.ReLU):
20
+ if m.inplace:
21
+ continue
22
+ out = m(input_)
23
+ out_sizes.append(np.array(out.size()))
24
+ input_ = out
25
+
26
+ total_nums = 0
27
+ for i in range(len(out_sizes)):
28
+ s = out_sizes[i]
29
+ nums = np.prod(np.array(s))
30
+ total_nums += nums
31
+
32
+ # print('Model {} : Number of intermediate variables without backward: {}'.format(model._get_name(), total_nums))
33
+ # print('Model {} : Number of intermediate variables with backward: {}'.format(model._get_name(), total_nums*2))
34
+ print('Model {} : intermediate variables: {:3f} M (without backward)'
35
+ .format(model._get_name(), total_nums * type_size / 1000 / 1000))
36
+ print('Model {} : intermediate variables: {:3f} M (with backward)'
37
+ .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))
38
+
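A usage sketch for modelsize() with a hypothetical toy network (not part of this commit). Because the function applies the sub-modules one after another to a running activation, it is only meaningful for purely sequential models without skip connections.

    import torch
    import torch.nn as nn
    from modelsize_estimate import modelsize

    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=3, padding=1),
    )
    x = torch.randn(1, 3, 64, 64)
    modelsize(net, x)   # prints parameter size and intermediate-activation size in MB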
scripts/importance_analysis.sh ADDED
@@ -0,0 +1,16 @@
1
+ python ./importance_analysis.py --datapath "/research/d4/gds/wltang21/data" \
2
+ --benchmark coco \
3
+ --fold 0 \
4
+ --bsz 1 \
5
+ --nworker 1 \
6
+ --backbone resnet50 \
7
+ --feature_extractor_path "/research/d4/gds/wltang21/logistic_project/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
8
+ --logpath "./logs" \
9
+ --load "/research/d4/gds/wltang21/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
10
+ --nshot 10
11
+
12
+
13
+ # --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
14
+ # --visualize
15
+ #checkpoint/coco-20i/resnet50_fold0.pt
16
+ # log/train/fold_0_ft_v0/best_model.pt
scripts/test.sh ADDED
@@ -0,0 +1,15 @@
1
+ python ./test.py --datapath "/research/d4/gds/wltang21/data" \
2
+ --benchmark coco \
3
+ --fold 0 \
4
+ --bsz 1 \
5
+ --nworker 1 \
6
+ --backbone resnet50 \
7
+ --feature_extractor_path "/research/d4/gds/wltang21/logistic_project/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
8
+ --logpath "./logs" \
9
+ --load "/research/d4/gds/wltang21/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
10
+ --nshot 30
11
+
12
+ # --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
13
+ # --visualize
14
+ #checkpoint/coco-20i/resnet50_fold0.pt
15
+ # log/train/fold_0_ft_v0/best_model.pt
scripts/train.sh ADDED
@@ -0,0 +1,11 @@
1
+ python -u -m torch.distributed.launch --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=16006 \
2
+ ./train.py --datapath "../datasets" \
3
+ --benchmark coco \
4
+ --fold 0 \
5
+ --bsz 12 \
6
+ --nworker 8 \
7
+ --backbone swin \
8
+ --feature_extractor_path "../backbones/swin_base_patch4_window12_384.pth" \
9
+ --logpath "./logs" \
10
+ --lr 1e-3 \
11
+ --nepoch 500
scripts/train_1gpu.sh ADDED
@@ -0,0 +1,12 @@
1
+ python ./train_1gpu.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
2
+ --benchmark coco \
3
+ --fold 0 \
4
+ --bsz 1 \
5
+ --nworker 0 \
6
+ --backbone resnet50 \
7
+ --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
8
+ --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
9
+ --lr 1e-4 \
10
+ --nepoch 50 \
11
+ --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log/resnet50_fold0.pt" \
12
+ --nshot 3
scripts/train_1gpu_retriver.sh ADDED
@@ -0,0 +1,12 @@
1
+ python ./train_1gpu_retriever.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
2
+ --benchmark coco \
3
+ --fold 1 \
4
+ --bsz 1 \
5
+ --nworker 0 \
6
+ --backbone resnet50 \
7
+ --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
8
+ --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
9
+ --lr 1e-4 \
10
+ --nepoch 50 \
11
+ --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log/fold_1_ft_v0/model_45.pt" \
12
+ --nshot 1
scripts/train_2gpu.sh ADDED
@@ -0,0 +1,14 @@
1
+ python -u -m torch.distributed.launch --nproc_per_node=2 --master_port=18024 \
2
+ ./train.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
3
+ --benchmark coco \
4
+ --fold 0 \
5
+ --bsz 1 \
6
+ --nworker 8 \
7
+ --backbone resnet50 \
8
+ --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
9
+ --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
10
+ --lr 1e-4 \
11
+ --nepoch 50 \
12
+ --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
13
+ --nshot 10
14
+ # --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
scripts/train_2gpu_retriever.sh ADDED
@@ -0,0 +1,14 @@
1
+ python -u -m torch.distributed.launch --nproc_per_node=2 --master_port=18024 \
2
+ ./train_retriever.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
3
+ --benchmark coco \
4
+ --fold 0 \
5
+ --bsz 1 \
6
+ --nworker 8 \
7
+ --backbone resnet50 \
8
+ --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
9
+ --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
10
+ --lr 1e-4 \
11
+ --nepoch 50 \
12
+ --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
13
+ --nshot 1
14
+ # --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
scripts/train_4gpu.sh ADDED
@@ -0,0 +1,14 @@
1
+ python -u -m torch.distributed.launch --nproc_per_node=4 --master_port=18024 \
2
+ ./train.py --datapath "~/MY_MOT/TWL/data" \
3
+ --benchmark coco \
4
+ --fold 0 \
5
+ --bsz 1 \
6
+ --nworker 8 \
7
+ --backbone resnet101 \
8
+ --feature_extractor_path "~/MY_MOT/TWL/logistic_project/DCAMA/backbones/swin_base_patch4_window12_384_22kto1k.pth" \
9
+ --logpath "~/MY_MOT/TWL/logistic_project/DCAMA/log" \
10
+ --lr 1e-4 \
11
+ --nepoch 50 \
12
+ --load "~/MY_MOT/TWL/logistic_project/DCAMA/checkpoint/coco-20i/swin_fold2.pt" \
13
+ --nshot 3
14
+ # --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
test.py ADDED
@@ -0,0 +1,132 @@
1
+ r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
2
+ import torch.nn as nn
3
+ import torch
4
+
5
+ from model.DCAMA import DCAMA
6
+ from common.logger import Logger, AverageMeter
7
+ from common.vis import Visualizer
8
+ from common.evaluation import Evaluator
9
+ from common.config import parse_opts
10
+ from common import utils
11
+ from data.dataset import FSSDataset
12
+ import cv2
13
+ import numpy as np
14
+ import os
15
+ # from gpu_mem_track import MemTracker
16
+
17
+ # gpu_tracker = MemTracker()
18
+ def test(model, dataloader, nshot):
19
+ r""" Test """
20
+
21
+ # Freeze randomness during testing for reproducibility
22
+ utils.fix_randseed(0)
23
+ average_meter = AverageMeter(dataloader.dataset)
24
+
25
+ for idx, batch in enumerate(dataloader):
26
+
27
+ # 1. forward pass
28
+ nshot = batch['support_imgs'].size(1)
29
+ ## TODO:
30
+
31
+ batch = utils.to_cuda(batch)
32
+ # gpu_tracker.track()
33
+ pred_mask, simi, simi_map = model.module.predict_mask_nshot(batch, nshot=nshot)
34
+ # gpu_tracker.track()
35
+ torch.cuda.synchronize()
36
+ assert pred_mask.size() == batch['query_mask'].size()
37
+
38
+ # 2. Evaluate prediction
39
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
40
+
41
+ ## TODO:
42
+ iou = area_inter[1] / area_union[1]
43
+
44
+ '''
45
+ cv2.imwrite('debug/query.png', cv2.imread("/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}".format(batch['query_name'][0])))
46
+ cv2.imwrite('debug/query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
47
+ cv2.imwrite('debug/support_{:.3}.png'.format(iou.item()), cv2.imread('/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}'.format(batch['support_names'][0][0])))
48
+ cv2.imwrite('debug/support_mask_{:.3}.png'.format(iou.item()), (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
49
+ simi_map = simi_map - simi_map.min()
50
+ simi_map = (simi_map / simi_map.max() * 255).detach().cpu().numpy().astype(np.uint8)
51
+ cv2.imwrite('debug/simi_map_{:.3}.png'.format(iou.item()), simi_map)
52
+
53
+ if os.path.exists('debug/stats.txt'):
54
+ with open('debug/stats.txt', "a") as f:
55
+ f.write("{} {}\n".format(simi.item(), iou.item()))
56
+ else:
57
+ with open('debug/stats.txt', 'w') as f:
58
+ f.write('{} {}\n'.format(simi.item(), iou.item()))
59
+ '''
60
+
61
+ average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
62
+ average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)
63
+
64
+ # Visualize predictions
65
+ if Visualizer.visualize:
66
+ Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
67
+ batch['query_img'], batch['query_mask'],
68
+ pred_mask, batch['class_id'], idx,
69
+ iou_b=area_inter[1].float() / area_union[1].float())
70
+
71
+ # Write evaluation results
72
+ average_meter.write_result('Test', 0)
73
+ miou, fb_iou = average_meter.compute_iou()
74
+
75
+ return miou, fb_iou
76
+
77
+
78
+ if __name__ == '__main__':
79
+
80
+ # Arguments parsing
81
+ args = parse_opts()
82
+
83
+ Logger.initialize(args, training=False)
84
+
85
+ # Model initialization
86
+ model = DCAMA(args.backbone, args.feature_extractor_path, args.use_original_imgsize)
87
+ model.eval()
88
+
89
+ # Device setup
90
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
91
+ Logger.info('# available GPUs: %d' % torch.cuda.device_count())
92
+ model = nn.DataParallel(model)
93
+ model.to(device)
94
+
95
+ # Load trained model
96
+ if args.load == '': raise Exception('Pretrained model not specified.')
97
+ params = model.state_dict()
98
+ state_dict = torch.load(args.load)
99
+
100
+ if 'state_dict' in state_dict.keys():
101
+ state_dict = state_dict['state_dict']
102
+ state_dict2 = {}
103
+ for k, v in state_dict.items():
104
+ if 'scorer' not in k:
105
+ state_dict2[k] = v
106
+ state_dict = state_dict2
107
+
108
+ for k1, k2 in zip(list(state_dict.keys()), params.keys()):
109
+ state_dict[k2] = state_dict.pop(k1)
110
+
111
+
112
+ try:
113
+ model.load_state_dict(state_dict, strict=True)
114
+ except:
115
+ for k in params.keys():
116
+ if k not in state_dict.keys():
117
+ state_dict[k] = params[k]
118
+ model.load_state_dict(state_dict, strict=True)
119
+
120
+ # Helper classes (for testing) initialization
121
+ Evaluator.initialize()
122
+ Visualizer.initialize(args.visualize, args.vispath)
123
+
124
+ # Dataset initialization
125
+ FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
126
+ dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
127
+
128
+ # Test
129
+ with torch.no_grad():
130
+ test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
131
+ Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
132
+ Logger.info('==================== Finished Testing ====================')
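The checkpoint loading in test.py above renames keys positionally: every key of the loaded state_dict is mapped, in order, onto the corresponding key of the current DataParallel-wrapped model. A standalone sketch of that trick with hypothetical keys (it assumes both state dicts enumerate parameters in the same order):

    from collections import OrderedDict

    ckpt = OrderedDict([('backbone.conv1.weight', 1), ('head.weight', 2)])
    model_keys = ['module.backbone.conv1.weight', 'module.head.weight']
    for k_old, k_new in zip(list(ckpt.keys()), model_keys):
        ckpt[k_new] = ckpt.pop(k_old)   # positional rename
    print(list(ckpt.keys()))            # ['module.backbone.conv1.weight', 'module.head.weight']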
train.py ADDED
@@ -0,0 +1,149 @@
1
+ r""" training (validation) code """
2
+ import torch.optim as optim
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ from model.DCAMA import DCAMA
7
+ from common.logger import Logger, AverageMeter
8
+ from common.evaluation import Evaluator
9
+ from common.config import parse_opts
10
+ from common import utils
11
+ from data.dataset import FSSDataset
12
+ import torch.nn.functional as F
13
+
14
+ average_loss = torch.tensor(0.).float().cuda()
15
+ global_idx = 0
16
+ def train(epoch, model, dataloader, optimizer, training):
17
+ r""" Train """
18
+
19
+ # Force randomness during training / freeze randomness during testing
20
+ utils.fix_randseed(None) if training else utils.fix_randseed(0)
21
+ model.module.train_mode() if training else model.module.eval()
22
+ average_meter = AverageMeter(dataloader.dataset)
23
+
24
+ global average_loss, global_idx
25
+ average_loss = average_loss.to("cuda:{}".format(torch.cuda.current_device()))
26
+ stats = [[], []]
27
+ for idx, batch in enumerate(dataloader):
28
+
29
+ # 1. forward pass
30
+ batch = utils.to_cuda(batch)
31
+ logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'], nshot=batch['support_imgs'].size(1))
32
+ pred_mask = logit_mask.argmax(dim=1)
33
+
34
+ # 2. Compute loss & update model parameters
35
+ loss = model.module.compute_objective(logit_mask, batch['query_mask'])
36
+ # loss_obj = loss.detach()
37
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
38
+ iou = area_inter[1] / area_union[1]
39
+ loss_obj = iou.detach()
40
+ score_loss = F.l1_loss(score_preds, loss_obj)
41
+ stats[0].append(score_preds.detach().cpu().numpy())
42
+ stats[1].append(loss_obj.detach().cpu().numpy()[0])
43
+ if global_idx == 0:
44
+ average_loss = loss_obj.detach()
45
+ global_idx += 1
46
+ else:
47
+ average_loss = loss_obj.detach() * 0.05 + 0.95 * average_loss
48
+ print(loss_obj.item(), " ", score_preds.item(), " ", score_loss.item())
49
+ loss += score_loss
50
+
51
+ if training:
52
+ optimizer.zero_grad()
53
+ loss.backward()
54
+ optimizer.step()
55
+
56
+ # 3. Evaluate prediction
57
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
58
+ average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
59
+ average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)
60
+
61
+ # Write evaluation results
62
+ average_meter.write_result('Training' if training else 'Validation', epoch)
63
+ avg_loss = utils.mean(average_meter.loss_buf)
64
+ miou, fb_iou = average_meter.compute_iou()
65
+ import matplotlib.pyplot as plt
66
+ idx = 0
67
+ plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.1)
68
+ plt.savefig('stat.png')
69
+ plt.close()
70
+
71
+ return avg_loss, miou, fb_iou
72
+
73
+
74
+ if __name__ == '__main__':
75
+
76
+ # Arguments parsing
77
+ args = parse_opts()
78
+
79
+ # ddp backend initialization
80
+ torch.distributed.init_process_group(backend='nccl')
81
+ torch.cuda.set_device(args.local_rank)
82
+
83
+ # Model initialization
84
+ model = DCAMA(args.backbone, args.feature_extractor_path, False)
85
+ device = torch.device("cuda", args.local_rank)
86
+ model.to(device)
87
+ model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank,
88
+ find_unused_parameters=True)
89
+
90
+
91
+ params = model.state_dict()
92
+ state_dict = torch.load(args.load)
93
+ state_dict2 = {}
94
+ for k in state_dict.keys():
95
+ if "scorer" in k:
96
+ continue
97
+ state_dict2[k] = state_dict[k]
98
+ state_dict = state_dict2
99
+ for k1, k2 in zip(list(state_dict.keys()), params.keys()):
100
+ state_dict[k2] = state_dict.pop(k1)
101
+
102
+ model.load_state_dict(state_dict, strict=False)
103
+
104
+ ## TODO:
105
+ for i in range(len(model.module.model.DCAMA_blocks)):
106
+ torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].weight, 0.)
107
+ torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].bias, 1.)
108
+
109
+
110
+ # Helper classes (for training) initialization
111
+ optimizer = optim.SGD([{"params": model.module.model.parameters(), "lr": args.lr,
112
+ "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
113
+ Evaluator.initialize()
114
+ if args.local_rank == 0:
115
+ Logger.initialize(args, training=True)
116
+ Logger.info('# available GPUs: %d' % torch.cuda.device_count())
117
+
118
+ # Dataset initialization
119
+ FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
120
+ dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', args.nshot)
121
+ if args.local_rank == 0:
122
+ dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', args.nshot)
123
+
124
+ # Train
125
+ best_val_miou = float('-inf')
126
+ best_val_loss = float('inf')
127
+ for epoch in range(args.nepoch):
128
+ dataloader_trn.sampler.set_epoch(epoch)
129
+ trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)
130
+
131
+ # evaluation
132
+ if args.local_rank == 0:
133
+ # with torch.no_grad():
134
+ # val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)
135
+
136
+ # Save the best model
137
+ # if val_miou > best_val_miou:
138
+ # best_val_miou = val_miou
139
+ # Logger.save_model_miou(model, epoch, val_miou)
140
+ Logger.save_model_miou(model, epoch , 1.)
141
+
142
+ # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
143
+ # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
144
+ # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
145
+ # Logger.tbd_writer.flush()
146
+
147
+ if args.local_rank == 0:
148
+ Logger.tbd_writer.close()
149
+ Logger.info('==================== Finished Training ====================')
train_1gpu.py ADDED
@@ -0,0 +1,170 @@
1
+ r""" training (validation) code """
2
+ import torch.optim as optim
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ from model.DCAMA import DCAMA
7
+ from common.logger import Logger, AverageMeter
8
+ from common.evaluation import Evaluator
9
+ from common.config import parse_opts
10
+ from common import utils
11
+ from data.dataset import FSSDataset # FSDataset4SAM
12
+ # from transformers import SamProcessor
13
+ from PIL import Image
14
+ import numpy as np
15
+ import torch.nn.functional as F
16
+ from torchvision import transforms
17
+ import pickle
18
+ import pycocotools.coco as COCO
19
+ import cv2
20
+
21
+ def train(epoch, model, dataloader, optimizer, training, shot=1):
22
+ r""" Train """
23
+
24
+ # Force randomness during training / freeze randomness during testing
25
+ utils.fix_randseed(None) if training else utils.fix_randseed(0)
26
+
27
+ if hasattr(model, "module"):
28
+ model.module.train_mode() if training else model.module.eval()
29
+ else:
30
+ model.train_mode() if training else model.eval()
31
+ average_meter = AverageMeter(dataloader.dataset)
32
+ average_loss = torch.tensor(0.).float().cuda()
33
+ stats = [[], []]
34
+ criterion_score = nn.BCEWithLogitsLoss()
35
+ for idx, batch in enumerate(dataloader):
36
+
37
+ # batch = process_batch4SAM(batch)
38
+ shot = batch['support_imgs'].size(1)
39
+ # 1. forward pass
40
+ batch = utils.to_cuda(batch)
41
+ logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'], nshot=shot)
42
+ pred_mask = logit_mask.argmax(dim=1)
43
+ # 2. Compute loss & update model parameters
44
+ loss = model.compute_objective(logit_mask, batch['query_mask'])
45
+ # loss_obj = loss.detach()
46
+
47
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
48
+
49
+ iou = (area_inter[1] / area_union[1]).float()
50
+ if iou > 0.7 or iou < 0.1:
51
+ '''
52
+ if iou < 0.1:
53
+ img = batch['query_img'][0].permute(1, 2, 0).detach().cpu().numpy()
54
+ img = img - img.min()
55
+ img = img / img.max()
56
+ cv2.imwrite('query_image.png', (img * 255).astype(np.uint8))
57
+ img = batch['support_imgs'][0][0].permute(1, 2, 0).detach().cpu().numpy()
58
+ img = img - img.min()
59
+ img = img / img.max()
60
+ cv2.imwrite('support_image.png', (img * 255).astype(np.uint8))
61
+ cv2.imwrite('query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
62
+ cv2.imwrite('pred_mask.png', (pred_mask[0] * 255).detach().cpu().numpy().astype(np.uint8))
63
+ cv2.imwrite('support_mask.png', (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
64
+ '''
65
+ if iou > 0.7:
66
+ iou = torch.tensor(1.).float().cuda()
67
+ else:
68
+ iou = torch.tensor(0.).float().cuda()
69
+ score_loss = criterion_score(score_preds, iou)
70
+ stats[0].append(score_preds.detach().cpu().numpy())
71
+ stats[1].append((area_inter[1] / area_union[1]).detach().cpu().numpy())
72
+ print(score_preds, (area_inter[1] / area_union[1]))
73
+
74
+
75
+ if training:
76
+ optimizer.zero_grad()
77
+ loss.backward()
78
+ optimizer.step()
79
+ # 3. Evaluate prediction
80
+
81
+ # img = batch['support_imgs'][0][0].permute(1, 2, 0)
82
+ # img = img - img.min()
83
+ # img /= img.max()
84
+ # import cv2
85
+ # cv2.imwrite("debug.png", (img * 255).detach().cpu().numpy())
86
+ # cv2.imwrite("debug2.png", (batch['support_masks'][0][0] * 255).detach().cpu().numpy())
87
+ # import ipdb;ipdb.set_trace()
88
+
89
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
90
+ average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
91
+ average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)
92
+
93
+ # Write evaluation results
94
+ average_meter.write_result('Training' if training else 'Validation', epoch)
95
+ avg_loss = utils.mean(average_meter.loss_buf)
96
+ miou, fb_iou = average_meter.compute_iou()
97
+
98
+ import matplotlib.pyplot as plt
99
+ plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.02)
100
+ plt.savefig("stats.png")
101
+ return avg_loss, miou, fb_iou
102
+
103
+
104
+ if __name__ == '__main__':
105
+
106
+ # Arguments parsing
107
+ args = parse_opts()
108
+
109
+ # Model initialization
110
+ model = DCAMA(args.backbone, args.feature_extractor_path, False)
111
+ device = torch.device("cuda", args.local_rank)
112
+ model.to(device)
113
+
114
+ params = model.state_dict()
115
+ state_dict = torch.load(args.load)
116
+ if 'state_dict' in state_dict.keys():
117
+ state_dict = state_dict['state_dict']
118
+ state_dict2 = {}
119
+ for k in state_dict.keys():
120
+ if "scorer" in k:
121
+ continue
122
+ state_dict2[k] = state_dict[k]
123
+ state_dict = state_dict2
124
+ for k1, k2 in zip(list(state_dict.keys()), params.keys()):
125
+ state_dict[k2] = state_dict.pop(k1)
126
+ model.load_state_dict(state_dict, strict=False)
127
+
128
+ ## TODO:
129
+ for i in range(len(model.model.DCAMA_blocks)):
130
+ torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].weight, 0.)
131
+ torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].bias, 1.)
132
+ # Helper classes (for training) initialization
133
+ optimizer = optim.SGD([{"params": model.parameters(), "lr": args.lr,
134
+ "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
135
+ Evaluator.initialize()
136
+ if args.local_rank == 0:
137
+ Logger.initialize(args, training=True)
138
+ Logger.info('# available GPUs: %d' % torch.cuda.device_count())
139
+
140
+ # Dataset initialization
141
+ FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
142
+ dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', shot=args.nshot)
143
+ if args.local_rank == 0:
144
+ dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', shot=args.nshot)
145
+
146
+ # Train
147
+ best_val_miou = float('-inf')
148
+ best_val_loss = float('inf')
149
+
150
+ for epoch in range(args.nepoch):
151
+ trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True, shot=args.nshot)
152
+
153
+ # evaluation
154
+ if args.local_rank == 0:
155
+ # with torch.no_grad():
156
+ # val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)
157
+
158
+ # Save the best model
159
+ # if val_miou > best_val_miou:
160
+ # best_val_miou = val_miou
161
+ Logger.save_model_miou(model, epoch, 1.)
162
+
163
+ # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
164
+ # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
165
+ # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
166
+ # Logger.tbd_writer.flush()
167
+
168
+ if args.local_rank == 0:
169
+ Logger.tbd_writer.close()
170
+ Logger.info('==================== Finished Training ====================')
train_1gpu_retriever.py ADDED
@@ -0,0 +1,172 @@
1
+ r""" training (validation) code """
2
+ import torch.optim as optim
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ from model.DCAMA import DCAMA
7
+ from common.logger import Logger, AverageMeter
8
+ from common.evaluation import Evaluator
9
+ from common.config import parse_opts
10
+ from common import utils
11
+ from data.dataset import FSSDataset # FSDataset4SAM
12
+ # from transformers import SamProcessor
13
+ from PIL import Image
14
+ import numpy as np
15
+ import torch.nn.functional as F
16
+ from torchvision import transforms
17
+ import pickle
18
+ import pycocotools.coco as COCO
19
+ import cv2
20
+ import torchvision
21
+
22
+ def train(epoch, model, dataloader, optimizer, training, shot=1):
23
+ r""" Train """
24
+
25
+ # Force randomness during training / freeze randomness during testing
26
+ utils.fix_randseed(None) if training else utils.fix_randseed(0)
27
+
28
+ if hasattr(model, "module"):
29
+ model.module.train_mode() if training else model.module.eval()
30
+ else:
31
+ model.train_mode() if training else model.eval()
32
+ average_meter = AverageMeter(dataloader.dataset)
33
+ average_loss = torch.tensor(0.).float().cuda()
34
+ stats = [[], []]
35
+ criterion_score = nn.BCEWithLogitsLoss()
36
+ for idx, batch in enumerate(dataloader):
37
+
38
+ # batch = process_batch4SAM(batch)
39
+ shot = batch['support_imgs'].size(1)
40
+ # 1. forward pass
41
+ batch = utils.to_cuda(batch)
42
+ logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'], nshot=shot, predict_score=True)
43
+ pred_mask = logit_mask.argmax(dim=1)
44
+ # 2. Compute loss & update model parameters
45
+ loss = model.compute_objective(logit_mask, batch['query_mask'])
46
+ # loss_obj = loss.detach()
47
+
48
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
49
+
50
+ iou = (area_inter[1] / area_union[1]).float()
51
+ if iou > 0.7 or iou < 0.05:
52
+ '''
53
+ if iou < 0.1:
54
+ img = batch['query_img'][0].permute(1, 2, 0).detach().cpu().numpy()
55
+ img = img - img.min()
56
+ img = img / img.max()
57
+ cv2.imwrite('query_image.png', (img * 255).astype(np.uint8))
58
+ img = batch['support_imgs'][0][0].permute(1, 2, 0).detach().cpu().numpy()
59
+ img = img - img.min()
60
+ img = img / img.max()
61
+ cv2.imwrite('support_image.png', (img * 255).astype(np.uint8))
62
+ cv2.imwrite('query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
63
+ cv2.imwrite('pred_mask.png', (pred_mask[0] * 255).detach().cpu().numpy().astype(np.uint8))
64
+ cv2.imwrite('support_mask.png', (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
65
+ '''
66
+ if iou > 0.7:
67
+ iou = torch.tensor(1.).float().cuda()
68
+ else:
69
+ iou = torch.tensor(0.).float().cuda()
70
+ score_loss = torchvision.ops.sigmoid_focal_loss(score_preds, iou)
71
+ # score_loss = F.l1_loss(score_preds, iou)
72
+ stats[0].append(score_preds.detach().cpu().numpy())
73
+ stats[1].append((area_inter[1] / area_union[1]).detach().cpu().numpy())
74
+ print(score_preds, (area_inter[1] / area_union[1]))
75
+ loss = score_loss
76
+
77
+ if training:
78
+ optimizer.zero_grad()
79
+ loss.backward()
80
+ optimizer.step()
81
+ # 3. Evaluate prediction
82
+
83
+ # img = batch['support_imgs'][0][0].permute(1, 2, 0)
84
+ # img = img - img.min()
85
+ # img /= img.max()
86
+ # import cv2
87
+ # cv2.imwrite("debug.png", (img * 255).detach().cpu().numpy())
88
+ # cv2.imwrite("debug2.png", (batch['support_masks'][0][0] * 255).detach().cpu().numpy())
89
+ # import ipdb;ipdb.set_trace()
90
+
91
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
92
+ average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
93
+ average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)
94
+
95
+ # Write evaluation results
96
+ average_meter.write_result('Training' if training else 'Validation', epoch)
97
+ avg_loss = utils.mean(average_meter.loss_buf)
98
+ miou, fb_iou = average_meter.compute_iou()
99
+
100
+ import matplotlib.pyplot as plt
101
+ plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.02)
102
+ plt.savefig("stats.png")
103
+ return avg_loss, miou, fb_iou
104
+
105
+
106
+ if __name__ == '__main__':
107
+
108
+ # Arguments parsing
109
+ args = parse_opts()
110
+
111
+ # Model initialization
112
+ model = DCAMA(args.backbone, args.feature_extractor_path, False)
113
+ device = torch.device("cuda", args.local_rank)
114
+ model.to(device)
115
+
116
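+ # Load the pretrained checkpoint, drop any scorer weights it may contain, and remap the remaining keys positionally onto this model's parameter names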
+ params = model.state_dict()
117
+ state_dict = torch.load(args.load)
118
+ if 'state_dict' in state_dict.keys():
119
+ state_dict = state_dict['state_dict']
120
+ state_dict2 = {}
121
+ for k in state_dict.keys():
122
+ if "scorer" in k:
123
+ continue
124
+ state_dict2[k] = state_dict[k]
125
+ state_dict = state_dict2
126
+ for k1, k2 in zip(list(state_dict.keys()), params.keys()):
127
+ state_dict[k2] = state_dict.pop(k1)
128
+ model.load_state_dict(state_dict, strict=False)
129
+
130
+ ## TODO:
131
+ # for i in range(len(model.model.DCAMA_blocks)):
132
+ # torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].weight, 0.)
133
+ # torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].bias, 1.)
134
+ # Helper classes (for training) initialization
135
+ optimizer = optim.SGD([{"params": model.parameters(), "lr": args.lr,
136
+ "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
137
+ Evaluator.initialize()
138
+ if args.local_rank == 0:
139
+ Logger.initialize(args, training=True)
140
+ Logger.info('# available GPUs: %d' % torch.cuda.device_count())
141
+
142
+ # Dataset initialization
143
+ FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
144
+ dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', shot=args.nshot)
145
+ if args.local_rank == 0:
146
+ dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', shot=args.nshot)
147
+
148
+ # Train
149
+ best_val_miou = float('-inf')
150
+ best_val_loss = float('inf')
151
+
152
+ for epoch in range(args.nepoch):
153
+ trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True, shot=args.nshot)
154
+
155
+ # evaluation
156
+ if args.local_rank == 0:
157
+ # with torch.no_grad():
158
+ # val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)
159
+
160
+ # Save the best model
161
+ # if val_miou > best_val_miou:
162
+ # best_val_miou = val_miou
163
+ Logger.save_model_miou(model, epoch, 1.)
164
+
165
+ # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
166
+ # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
167
+ # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
168
+ # Logger.tbd_writer.flush()
169
+
170
+ if args.local_rank == 0:
171
+ Logger.tbd_writer.close()
172
+ Logger.info('==================== Finished Training ====================')
train_retriever.py ADDED
@@ -0,0 +1,164 @@
1
+ r""" training (validation) code """
2
+ import torch.optim as optim
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ from model.DCAMA import DCAMA
7
+ from common.logger import Logger, AverageMeter
8
+ from common.evaluation import Evaluator
9
+ from common.config import parse_opts
10
+ from common import utils
11
+ from data.dataset import FSSDataset
12
+ import torch.nn.functional as F
13
+
14
+ average_loss = torch.tensor(0.).float().cuda()
15
+ global_idx = 0
16
+ def train(epoch, model, dataloader, optimizer, training):
17
+ r""" Train """
18
+
19
+ # Force randomness during training / freeze randomness during testing
20
+ utils.fix_randseed(None) if training else utils.fix_randseed(0)
21
+ model.module.train_mode() if training else model.module.eval()
22
+ average_meter = AverageMeter(dataloader.dataset)
23
+
24
+ global average_loss, global_idx
25
+ average_loss = average_loss.to("cuda:{}".format(torch.cuda.current_device()))
26
+ stats = [[], []]
27
+ criterion_score = nn.BCEWithLogitsLoss()
28
+ for idx, batch in enumerate(dataloader):
29
+
30
+ # 1. forward pass
31
+ batch = utils.to_cuda(batch)
32
+ logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'], nshot=batch['support_imgs'].size(1))
33
+ pred_mask = logit_mask.argmax(dim=1)
34
+
35
+ # 2. Compute loss & update model parameters
36
+ loss = model.module.compute_objective(logit_mask, batch['query_mask'])
37
+ # loss_obj = loss.detach()
38
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
39
+ iou = area_inter[1] / area_union[1]
40
+
41
+ if iou > 0.7 or iou < 0.1:
42
+ '''
43
+ if iou < 0.1:
44
+ img = batch['query_img'][0].permute(1, 2, 0).detach().cpu().numpy()
45
+ img = img - img.min()
46
+ img = img / img.max()
47
+ cv2.imwrite('query_image.png', (img * 255).astype(np.uint8))
48
+ img = batch['support_imgs'][0][0].permute(1, 2, 0).detach().cpu().numpy()
49
+ img = img - img.min()
50
+ img = img / img.max()
51
+ cv2.imwrite('support_image.png', (img * 255).astype(np.uint8))
52
+ cv2.imwrite('query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
53
+ cv2.imwrite('pred_mask.png', (pred_mask[0] * 255).detach().cpu().numpy().astype(np.uint8))
54
+ cv2.imwrite('support_mask.png', (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
55
+ '''
56
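+ # Binarize the IoU into a 0/1 target (>0.7 positive, <0.1 negative) for the BCE-with-logits scorer criterion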
+ if iou > 0.7:
57
+ iou = torch.tensor(1.).float().cuda()
58
+ else:
59
+ iou = torch.tensor(0.).float().cuda()
60
+ score_loss = criterion_score(score_preds, iou)
61
+ stats[0].append(score_preds.detach().cpu().numpy())
62
+ stats[1].append((area_inter[1] / area_union[1]).detach().cpu().numpy())
63
+ print(score_preds, (area_inter[1] / area_union[1]))
64
+ loss = score_loss
65
+
66
+ if training:
67
+ optimizer.zero_grad()
68
+ loss.backward()
69
+ optimizer.step()
70
+
71
+ # 3. Evaluate prediction
72
+ area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
73
+ average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
74
+ average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)
75
+
76
+ # Write evaluation results
77
+ average_meter.write_result('Training' if training else 'Validation', epoch)
78
+ avg_loss = utils.mean(average_meter.loss_buf)
79
+ miou, fb_iou = average_meter.compute_iou()
80
+ import matplotlib.pyplot as plt
81
+ idx = 0
82
+ plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.1)
83
+ plt.savefig('stat.png')
84
+ plt.close()
85
+
86
+ return avg_loss, miou, fb_iou
87
+
88
+
89
+ if __name__ == '__main__':
90
+
91
+ # Arguments parsing
92
+ args = parse_opts()
93
+
94
+ # ddp backend initialization
95
+ torch.distributed.init_process_group(backend='nccl')
96
+ torch.cuda.set_device(args.local_rank)
97
+
98
+ # Model initialization
99
+ model = DCAMA(args.backbone, args.feature_extractor_path, False)
100
+ device = torch.device("cuda", args.local_rank)
101
+ model.to(device)
102
+ model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank,
103
+ find_unused_parameters=True)
104
+
105
+
106
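+ # Load the pretrained checkpoint, skip scorer weights, and remap the remaining keys positionally onto the DDP-wrapped model's parameter names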
+ params = model.state_dict()
107
+ state_dict = torch.load(args.load)
108
+ state_dict2 = {}
109
+ for k in state_dict.keys():
110
+ if "scorer" in k:
111
+ continue
112
+ state_dict2[k] = state_dict[k]
113
+ state_dict = state_dict2
114
+ for k1, k2 in zip(list(state_dict.keys()), params.keys()):
115
+ state_dict[k2] = state_dict.pop(k1)
116
+
117
+ model.load_state_dict(state_dict, strict=False)
118
+
119
+ ## TODO:
120
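+ # Re-initialize the second linear layer of every DCAMA block to a constant mapping (weight 0, bias 1)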
+ for i in range(len(model.module.model.DCAMA_blocks)):
121
+ torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].weight, 0.)
122
+ torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].bias, 1.)
123
+
124
+
125
+ # Helper classes (for training) initialization
126
+ optimizer = optim.SGD([{"params": model.module.model.parameters(), "lr": args.lr,
127
+ "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
128
+ Evaluator.initialize()
129
+ if args.local_rank == 0:
130
+ Logger.initialize(args, training=True)
131
+ Logger.info('# available GPUs: %d' % torch.cuda.device_count())
132
+
133
+ # Dataset initialization
134
+ FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
135
+ dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', args.nshot)
136
+ if args.local_rank == 0:
137
+ dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', args.nshot)
138
+
139
+ # Train
140
+ best_val_miou = float('-inf')
141
+ best_val_loss = float('inf')
142
+ for epoch in range(args.nepoch):
143
+ dataloader_trn.sampler.set_epoch(epoch)
144
+ trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)
145
+
146
+ # evaluation
147
+ if args.local_rank == 0:
148
+ # with torch.no_grad():
149
+ # val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)
150
+
151
+ # Save the best model
152
+ # if val_miou > best_val_miou:
153
+ # best_val_miou = val_miou
154
+ # Logger.save_model_miou(model, epoch, val_miou)
155
+ Logger.save_model_miou(model, epoch, 1.)
156
+
157
+ # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
158
+ # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
159
+ # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
160
+ # Logger.tbd_writer.flush()
161
+
162
+ if args.local_rank == 0:
163
+ Logger.tbd_writer.close()
164
+ Logger.info('==================== Finished Training ====================')