Commit c70812a
1 Parent(s): 214c299
- analyze.py +11 -0
- common/__pycache__/config.cpython-38.pyc +0 -0
- common/__pycache__/config.cpython-39.pyc +0 -0
- common/__pycache__/evaluation.cpython-38.pyc +0 -0
- common/__pycache__/evaluation.cpython-39.pyc +0 -0
- common/__pycache__/logger.cpython-38.pyc +0 -0
- common/__pycache__/logger.cpython-39.pyc +0 -0
- common/__pycache__/utils.cpython-38.pyc +0 -0
- common/__pycache__/utils.cpython-39.pyc +0 -0
- common/__pycache__/vis.cpython-38.pyc +0 -0
- common/__pycache__/vis.cpython-39.pyc +0 -0
- common/config.py +31 -0
- common/evaluation.py +39 -0
- common/logger.py +117 -0
- common/utils.py +32 -0
- common/vis.py +106 -0
- gpu_mem_track.py +113 -0
- importance_analysis.py +130 -0
- model/DCAMA.py +625 -0
- model/__pycache__/DCAMA.cpython-38.pyc +0 -0
- model/__pycache__/DCAMA.cpython-39.pyc +0 -0
- model/base/__pycache__/swin_transformer.cpython-38.pyc +0 -0
- model/base/__pycache__/swin_transformer.cpython-39.pyc +0 -0
- model/base/__pycache__/transformer.cpython-38.pyc +0 -0
- model/base/__pycache__/transformer.cpython-39.pyc +0 -0
- model/base/swin_transformer.py +605 -0
- model/base/transformer.py +99 -0
- modelsize_estimate.py +38 -0
- scripts/importance_analysis.sh +16 -0
- scripts/test.sh +15 -0
- scripts/train.sh +11 -0
- scripts/train_1gpu.sh +12 -0
- scripts/train_1gpu_retriver.sh +12 -0
- scripts/train_2gpu.sh +14 -0
- scripts/train_2gpu_retriever.sh +14 -0
- scripts/train_4gpu.sh +14 -0
- test.py +132 -0
- train.py +149 -0
- train_1gpu.py +170 -0
- train_1gpu_retriever.py +172 -0
- train_retriever.py +164 -0
analyze.py
ADDED
@@ -0,0 +1,11 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import cv2
+import os
+
+with open('debug/stats.txt', 'r') as f:
+    stats = f.readlines()
+for stat in stats:
+    plt.scatter(float(stat.split(" ")[0]), float(stat.split(' ')[1]), alpha=0.1, s=10, c='red')
+    print(stat)
+plt.savefig('stats.png')
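A minimal companion sketch (not part of the commit): importance_analysis.py below writes debug/stats.txt as one "<similarity> <iou>" pair per line, so the trend behind the scatter plot above can also be summarized numerically.

import numpy as np

# Assumes the "<similarity> <iou>" per-line format written by the
# commented-out block in importance_analysis.py.
data = np.loadtxt('debug/stats.txt')          # shape (N, 2)
simi, iou = data[:, 0], data[:, 1]
r = np.corrcoef(simi, iou)[0, 1]
print('Pearson r between support-query similarity and IoU: %.3f' % r)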
common/__pycache__/config.cpython-38.pyc
ADDED
Binary file (1.28 kB)

common/__pycache__/config.cpython-39.pyc
ADDED
Binary file (1.28 kB)

common/__pycache__/evaluation.cpython-38.pyc
ADDED
Binary file (1.42 kB)

common/__pycache__/evaluation.cpython-39.pyc
ADDED
Binary file (1.39 kB)

common/__pycache__/logger.cpython-38.pyc
ADDED
Binary file (4.34 kB)

common/__pycache__/logger.cpython-39.pyc
ADDED
Binary file (4.32 kB)

common/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (1.12 kB)

common/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.1 kB)

common/__pycache__/vis.cpython-38.pyc
ADDED
Binary file (4.72 kB)

common/__pycache__/vis.cpython-39.pyc
ADDED
Binary file (4.68 kB)
common/config.py
ADDED
@@ -0,0 +1,31 @@
+r"""config"""
+import argparse
+
+def parse_opts():
+    r"""arguments"""
+    parser = argparse.ArgumentParser(description='Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation')
+
+    # common
+    parser.add_argument('--datapath', type=str, default='./datasets')
+    parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss'])
+    parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
+    parser.add_argument('--bsz', type=int, default=20)
+    parser.add_argument('--nworker', type=int, default=8)
+    parser.add_argument('--backbone', type=str, default='swin', choices=['resnet50', 'resnet101', 'swin'])
+    parser.add_argument('--feature_extractor_path', type=str, default='')
+    parser.add_argument('--logpath', type=str, default='./logs')
+
+    # for train
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--nepoch', type=int, default=1000)
+    parser.add_argument('--local-rank', default=0, type=int, help='node rank for distributed training')
+
+    # for test
+    parser.add_argument('--load', type=str, default='')
+    parser.add_argument('--nshot', type=int, default=1)
+    parser.add_argument('--visualize', action='store_true')
+    parser.add_argument('--vispath', type=str, default='./vis')
+    parser.add_argument('--use_original_imgsize', action='store_true')
+
+    args = parser.parse_args()
+    return args
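A hypothetical usage sketch (flag values are placeholders, not from the commit): parse_opts() reads sys.argv, which is how the launchers under scripts/ presumably drive it.

import sys
from common.config import parse_opts

# Placeholder arguments for illustration only.
sys.argv = ['test.py', '--benchmark', 'coco', '--fold', '0',
            '--backbone', 'swin', '--nshot', '1',
            '--load', 'logs/train/fold_0/model.pt']
args = parse_opts()
print(args.benchmark, args.backbone, args.nshot)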
common/evaluation.py
ADDED
@@ -0,0 +1,39 @@
+r""" Evaluate mask prediction """
+import torch
+
+
+class Evaluator:
+    r""" Computes intersection and union between prediction and ground-truth """
+    @classmethod
+    def initialize(cls):
+        cls.ignore_index = 255
+
+    @classmethod
+    def classify_prediction(cls, pred_mask, batch):
+        gt_mask = batch.get('query_mask')
+
+        # Apply ignore_index in PASCAL-5i masks (following evaluation scheme in PFE-Net (TPAMI 2020))
+        query_ignore_idx = batch.get('query_ignore_idx')
+        if query_ignore_idx is not None:
+            assert torch.logical_and(query_ignore_idx, gt_mask).sum() == 0
+            query_ignore_idx *= cls.ignore_index
+            gt_mask = gt_mask + query_ignore_idx
+            pred_mask[gt_mask == cls.ignore_index] = cls.ignore_index
+
+        # compute intersection and union of each episode in a batch
+        area_inter, area_pred, area_gt = [], [], []
+        for _pred_mask, _gt_mask in zip(pred_mask, gt_mask):
+            _inter = _pred_mask[_pred_mask == _gt_mask]
+            if _inter.size(0) == 0:  # as torch.histc returns error if it gets empty tensor (pytorch 1.5.1)
+                _area_inter = torch.tensor([0, 0], device=_pred_mask.device)
+            else:
+                _area_inter = torch.histc(_inter, bins=2, min=0, max=1)
+            area_inter.append(_area_inter)
+            area_pred.append(torch.histc(_pred_mask, bins=2, min=0, max=1))
+            area_gt.append(torch.histc(_gt_mask, bins=2, min=0, max=1))
+        area_inter = torch.stack(area_inter).t()
+        area_pred = torch.stack(area_pred).t()
+        area_gt = torch.stack(area_gt).t()
+        area_union = area_pred + area_gt - area_inter
+
+        return area_inter, area_union
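A worked example of the histc-based accounting above (CPU tensors for illustration; the real code runs on CUDA): torch.histc over {0, 1} yields [background, foreground] pixel counts, from which per-class IoU follows.

import torch

pred = torch.tensor([[0., 1., 1.], [0., 1., 0.]])
gt   = torch.tensor([[0., 1., 0.], [0., 1., 1.]])

inter = pred[pred == gt]                               # pixels where they agree
area_inter = torch.histc(inter, bins=2, min=0, max=1)  # [2., 2.]
area_pred  = torch.histc(pred,  bins=2, min=0, max=1)  # [3., 3.]
area_gt    = torch.histc(gt,    bins=2, min=0, max=1)  # [3., 3.]
area_union = area_pred + area_gt - area_inter          # [4., 4.]
print(area_inter / area_union)   # IoU per class: bg 0.50, fg 0.50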
common/logger.py
ADDED
@@ -0,0 +1,117 @@
+r""" Logging during training/testing """
+import datetime
+import logging
+import os
+
+from tensorboardX import SummaryWriter
+import torch
+
+
+class AverageMeter:
+    r""" Stores loss, evaluation results """
+    def __init__(self, dataset):
+        self.benchmark = dataset.benchmark
+        self.class_ids_interest = dataset.class_ids
+        self.class_ids_interest = torch.tensor(self.class_ids_interest).cuda()
+
+        if self.benchmark == 'pascal':
+            self.nclass = 20
+        elif self.benchmark == 'coco':
+            self.nclass = 80
+        elif self.benchmark == 'fss':
+            self.nclass = 1000
+
+        self.intersection_buf = torch.zeros([2, self.nclass]).float().cuda()
+        self.union_buf = torch.zeros([2, self.nclass]).float().cuda()
+        self.ones = torch.ones_like(self.union_buf)
+        self.loss_buf = []
+
+    def update(self, inter_b, union_b, class_id, loss):
+        self.intersection_buf.index_add_(1, class_id, inter_b.float())
+        self.union_buf.index_add_(1, class_id, union_b.float())
+        if loss is None:
+            loss = torch.tensor(0.0)
+        self.loss_buf.append(loss)
+
+    def compute_iou(self):
+        iou = self.intersection_buf.float() / \
+              torch.max(torch.stack([self.union_buf, self.ones]), dim=0)[0]
+        iou = iou.index_select(1, self.class_ids_interest)
+        miou = iou[1].mean() * 100
+
+        fb_iou = (self.intersection_buf.index_select(1, self.class_ids_interest).sum(dim=1) /
+                  self.union_buf.index_select(1, self.class_ids_interest).sum(dim=1)).mean() * 100
+
+        return miou, fb_iou
+
+    def write_result(self, split, epoch):
+        iou, fb_iou = self.compute_iou()
+
+        loss_buf = torch.stack(self.loss_buf)
+        msg = '\n*** %s ' % split
+        msg += '[@Epoch %02d] ' % epoch
+        msg += 'Avg L: %6.5f ' % loss_buf.mean()
+        msg += 'mIoU: %5.2f ' % iou
+        msg += 'FB-IoU: %5.2f ' % fb_iou
+
+        msg += '***\n'
+        Logger.info(msg)
+
+    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
+        if batch_idx % write_batch_idx == 0:
+            msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
+            msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
+            iou, fb_iou = self.compute_iou()
+            if epoch != -1:
+                loss_buf = torch.stack(self.loss_buf)
+                msg += 'L: %6.5f ' % loss_buf[-1]
+                msg += 'Avg L: %6.5f ' % loss_buf.mean()
+            msg += 'mIoU: %5.2f | ' % iou
+            msg += 'FB-IoU: %5.2f' % fb_iou
+            Logger.info(msg)
+
+
+class Logger:
+    r""" Writes evaluation results of training/testing """
+    @classmethod
+    def initialize(cls, args, training):
+        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
+        logpath = os.path.join(args.logpath, 'train/fold_' + str(args.fold) + logtime) if training \
+            else os.path.join(args.logpath, 'test/fold_' + args.load.split('/')[-2].split('.')[0] + logtime)
+        if logpath == '': logpath = logtime
+
+        cls.logpath = logpath
+        cls.benchmark = args.benchmark
+        if not os.path.exists(cls.logpath): os.makedirs(cls.logpath)
+
+        logging.basicConfig(filemode='w',
+                            filename=os.path.join(cls.logpath, 'log.txt'),
+                            level=logging.INFO,
+                            format='%(message)s',
+                            datefmt='%m-%d %H:%M:%S')
+
+        # Console log config
+        console = logging.StreamHandler()
+        console.setLevel(logging.INFO)
+        formatter = logging.Formatter('%(message)s')
+        console.setFormatter(formatter)
+        logging.getLogger('').addHandler(console)
+
+        # Tensorboard writer
+        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
+
+        # Log arguments
+        logging.info('\n:==================== Start =====================')
+        for arg_key in args.__dict__:
+            logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
+        logging.info(':================================================\n')
+
+    @classmethod
+    def info(cls, msg):
+        r""" Writes log message to log.txt """
+        logging.info(msg)
+
+    @classmethod
+    def save_model_miou(cls, model, epoch, val_miou):
+        torch.save(model.state_dict(), os.path.join(cls.logpath, "model_{}.pt".format(epoch)))
+        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))
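A minimal usage sketch of AverageMeter (assumes a CUDA device, since the buffers above are allocated with .cuda(); the dataset stub stands in for the two attributes AverageMeter reads from an FSSDataset):

import torch
from types import SimpleNamespace
from common.logger import AverageMeter

dataset = SimpleNamespace(benchmark='pascal', class_ids=list(range(15)))
meter = AverageMeter(dataset)
inter = torch.tensor([[5.], [3.]]).cuda()   # [2 x bsz] areas (bg, fg), bsz=1
union = torch.tensor([[6.], [4.]]).cuda()
meter.update(inter, union, torch.tensor([2]).cuda(), loss=None)
miou, fb_iou = meter.compute_iou()
print(miou.item(), fb_iou.item())   # fg IoU averaged over the fold's classes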
common/utils.py
ADDED
@@ -0,0 +1,32 @@
+r""" Helper functions """
+import random
+
+import torch
+import numpy as np
+
+
+def fix_randseed(seed):
+    r""" Set random seeds for reproducibility """
+    if seed is None:
+        seed = int(random.random() * 1e5)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
+
+def mean(x):
+    return sum(x) / len(x) if len(x) > 0 else 0.0
+
+
+def to_cuda(batch):
+    for key, value in batch.items():
+        if isinstance(value, torch.Tensor):
+            batch[key] = value.cuda()
+    return batch
+
+
+def to_cpu(tensor):
+    return tensor.detach().clone().cpu()
common/vis.py
ADDED
@@ -0,0 +1,106 @@
+r""" Visualize model predictions """
+import os
+
+from PIL import Image
+import numpy as np
+import torchvision.transforms as transforms
+
+from . import utils
+
+
+class Visualizer:
+
+    @classmethod
+    def initialize(cls, visualize, vispath='./vis/'):
+        cls.visualize = visualize
+        if not visualize:
+            return
+
+        cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)}
+        for key, value in cls.colors.items():
+            cls.colors[key] = tuple([c / 255 for c in cls.colors[key]])
+
+        cls.mean_img = [0.485, 0.456, 0.406]
+        cls.std_img = [0.229, 0.224, 0.225]
+        cls.to_pil = transforms.ToPILImage()
+        cls.vis_path = vispath
+        if not os.path.exists(cls.vis_path): os.makedirs(cls.vis_path)
+
+    @classmethod
+    def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None):
+        spt_img_b = utils.to_cpu(spt_img_b)
+        spt_mask_b = utils.to_cpu(spt_mask_b)
+        qry_img_b = utils.to_cpu(qry_img_b)
+        qry_mask_b = utils.to_cpu(qry_mask_b)
+        pred_mask_b = utils.to_cpu(pred_mask_b)
+        cls_id_b = utils.to_cpu(cls_id_b)
+
+        for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \
+                enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)):
+            iou = iou_b[sample_idx] if iou_b is not None else None
+            cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou)
+
+    @classmethod
+    def to_numpy(cls, tensor, type):
+        if type == 'img':
+            return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8)
+        elif type == 'mask':
+            return np.array(tensor).astype(np.uint8)
+        else:
+            raise Exception('Undefined tensor type: %s' % type)
+
+    @classmethod
+    def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None):
+
+        spt_color = cls.colors['blue']
+        qry_color = cls.colors['red']
+        pred_color = cls.colors['red']
+
+        spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs]
+        spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs]
+        spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks]
+        spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)]
+
+        qry_img = cls.to_numpy(qry_img, 'img')
+        qry_pil = cls.to_pil(qry_img)
+        qry_mask = cls.to_numpy(qry_mask, 'mask')
+        pred_mask = cls.to_numpy(pred_mask, 'mask')
+        pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color))
+        qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color))
+
+        merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil])
+
+        iou = iou.item() if iou else 0.0
+        merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg')
+
+    @classmethod
+    def merge_image_pair(cls, pil_imgs):
+        r""" Horizontally aligns a pair of pytorch tensor images (3, H, W) and returns PIL object """
+
+        canvas_width = sum([pil.size[0] for pil in pil_imgs])
+        canvas_height = max([pil.size[1] for pil in pil_imgs])
+        canvas = Image.new('RGB', (canvas_width, canvas_height))
+
+        xpos = 0
+        for pil in pil_imgs:
+            canvas.paste(pil, (xpos, 0))
+            xpos += pil.size[0]
+
+        return canvas
+
+    @classmethod
+    def apply_mask(cls, image, mask, color, alpha=0.5):
+        r""" Apply mask to the given image. """
+        for c in range(3):
+            image[:, :, c] = np.where(mask == 1,
+                                      image[:, :, c] *
+                                      (1 - alpha) + alpha * color[c] * 255,
+                                      image[:, :, c])
+        return image
+
+    @classmethod
+    def unnormalize(cls, img):
+        img = img.clone()
+        for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img):
+            im_channel.mul_(std).add_(mean)
+        return img
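A tiny worked example of the alpha blend in apply_mask (values chosen for illustration): with alpha=0.5, masked pixels become 0.5*pixel + 0.5*color*255.

import numpy as np

img = np.full((2, 2, 3), 200, dtype=np.uint8)
mask = np.array([[1, 0], [0, 1]], dtype=np.uint8)
color = (1.0, 0.2, 0.2)   # normalized red, as set in initialize()
for c in range(3):
    img[:, :, c] = np.where(mask == 1,
                            img[:, :, c] * 0.5 + 0.5 * color[c] * 255,
                            img[:, :, c])
print(img[0, 0])   # blended pixel, approx [227, 125, 125]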
gpu_mem_track.py
ADDED
@@ -0,0 +1,113 @@
+import gc
+import datetime
+import inspect
+
+import torch
+import numpy as np
+
+dtype_memory_size_dict = {
+    torch.float64: 64/8,
+    torch.double: 64/8,
+    torch.float32: 32/8,
+    torch.float: 32/8,
+    torch.float16: 16/8,
+    torch.half: 16/8,
+    torch.int64: 64/8,
+    torch.long: 64/8,
+    torch.int32: 32/8,
+    torch.int: 32/8,
+    torch.int16: 16/8,
+    torch.short: 16/8,
+    torch.uint8: 8/8,
+    torch.int8: 8/8,
+}
+# compatibility of torch1.0
+if getattr(torch, "bfloat16", None) is not None:
+    dtype_memory_size_dict[torch.bfloat16] = 16/8
+if getattr(torch, "bool", None) is not None:
+    dtype_memory_size_dict[torch.bool] = 8/8  # pytorch uses 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571
+
+def get_mem_space(x):
+    try:
+        ret = dtype_memory_size_dict[x]
+    except KeyError:
+        raise KeyError(f"dtype {x} is not supported!")
+    return ret
+
+class MemTracker(object):
+    """
+    Class used to track pytorch memory usage
+    Arguments:
+        detail(bool, default True): whether the function shows the detail gpu memory usage
+        path(str): where to save log file
+        verbose(bool, default False): whether show the trivial exception
+        device(int): GPU number, default is 0
+    """
+    def __init__(self, detail=True, path='', verbose=False, device=0):
+        self.print_detail = detail
+        self.last_tensor_sizes = set()
+        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
+        self.verbose = verbose
+        self.begin = True
+        self.device = device
+
+    def get_tensors(self):
+        for obj in gc.get_objects():
+            try:
+                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
+                    tensor = obj
+                else:
+                    continue
+                if tensor.is_cuda:
+                    yield tensor
+            except Exception as e:
+                if self.verbose:
+                    print('A trivial exception occurred: {}'.format(e))
+
+    def get_tensor_usage(self):
+        sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()]
+        return np.sum(sizes) / 1024**2
+
+    def get_allocate_usage(self):
+        return torch.cuda.memory_allocated() / 1024**2
+
+    def clear_cache(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def print_all_gpu_tensor(self, file=None):
+        for x in self.get_tensors():
+            print(x.size(), x.dtype, np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2, file=file)
+
+    def track(self):
+        """
+        Track the GPU memory usage
+        """
+        frameinfo = inspect.stack()[1]
+        where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function
+
+        with open(self.gpu_profile_fn, 'a+') as f:
+
+            if self.begin:
+                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
+                        f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
+                        f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
+                self.begin = False
+
+            if self.print_detail is True:
+                ts_list = [(tensor.size(), tensor.dtype) for tensor in self.get_tensors()]
+                new_tensor_sizes = {(type(x),
+                                     tuple(x.size()),
+                                     ts_list.count((x.size(), x.dtype)),
+                                     np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2,
+                                     x.dtype) for x in self.get_tensors()}
+                for t, s, n, m, data_type in new_tensor_sizes - self.last_tensor_sizes:
+                    f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n')
+                for t, s, n, m, data_type in self.last_tensor_sizes - new_tensor_sizes:
+                    f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n')
+
+                self.last_tensor_sizes = new_tensor_sizes
+
+            f.write(f"\nAt {where_str:<50}"
+                    f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
+                    f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
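A hedged usage sketch (assumes a CUDA device): instrument code with MemTracker.track() before and after suspect allocations; each call appends a snapshot to the generated *-gpu_mem_track.txt file.

import torch
from gpu_mem_track import MemTracker

tracker = MemTracker()
tracker.track()                                 # baseline snapshot
x = torch.randn(1024, 1024, device='cuda')      # ~4 MB float32 tensor
tracker.track()                                 # logs the new tensor ("+" line)
del x
tracker.clear_cache()
tracker.track()                                 # logs the release ("-" line)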
importance_analysis.py
ADDED
@@ -0,0 +1,130 @@
+r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
+import torch.nn as nn
+import torch
+
+from model.DCAMA import DCAMA
+from common.logger import Logger, AverageMeter
+from common.vis import Visualizer
+from common.evaluation import Evaluator
+from common.config import parse_opts
+from common import utils
+from data.dataset import FSSDataset
+import cv2
+import numpy as np
+import os
+
+def test(model, dataloader, nshot):
+    r""" Test """
+
+    # Freeze randomness during testing for reproducibility
+    utils.fix_randseed(0)
+    average_meter = AverageMeter(dataloader.dataset)
+
+    for idx, batch in enumerate(dataloader):
+
+        # 1. forward pass
+        nshot = batch['support_imgs'].size(1)
+
+        batch['support_imgs'][0][0] = batch['query_img'][0]
+        batch['support_masks'][0][0] = batch['query_mask'][0]
+
+        batch = utils.to_cuda(batch)
+        pred_mask, simi, simi_map = model.module.predict_mask_nshot(batch, nshot=nshot)
+
+        assert pred_mask.size() == batch['query_mask'].size()
+
+        # 2. Evaluate prediction
+        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
+
+        ## TODO:
+        iou = area_inter[1] / area_union[1]
+
+        '''
+        cv2.imwrite('debug/query.png', cv2.imread("/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}".format(batch['query_name'][0])))
+        cv2.imwrite('debug/query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
+        cv2.imwrite('debug/support_{:.3}.png'.format(iou.item()), cv2.imread('/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}'.format(batch['support_names'][0][0])))
+        cv2.imwrite('debug/support_mask_{:.3}.png'.format(iou.item()), (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
+        simi_map = simi_map - simi_map.min()
+        simi_map = (simi_map / simi_map.max() * 255).detach().cpu().numpy().astype(np.uint8)
+        cv2.imwrite('debug/simi_map_{:.3}.png'.format(iou.item()), simi_map)
+
+        if os.path.exists('debug/stats.txt'):
+            with open('debug/stats.txt', "a") as f:
+                f.write("{} {}\n".format(simi.item(), iou.item()))
+        else:
+            with open('debug/stats.txt', 'w') as f:
+                f.write('{} {}\n'.format(simi.item(), iou.item()))
+        '''
+
+        average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
+        average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)
+
+        # Visualize predictions
+        if Visualizer.visualize:
+            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
+                                                  batch['query_img'], batch['query_mask'],
+                                                  pred_mask, batch['class_id'], idx,
+                                                  iou_b=area_inter[1].float() / area_union[1].float())
+
+    # Write evaluation results
+    average_meter.write_result('Test', 0)
+    miou, fb_iou = average_meter.compute_iou()
+
+    return miou, fb_iou
+
+
+if __name__ == '__main__':
+
+    # Arguments parsing
+    args = parse_opts()
+
+    Logger.initialize(args, training=False)
+
+    # Model initialization
+    model = DCAMA(args.backbone, args.feature_extractor_path, args.use_original_imgsize)
+    model.eval()
+
+    # Device setup
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    Logger.info('# available GPUs: %d' % torch.cuda.device_count())
+    model = nn.DataParallel(model)
+    model.to(device)
+
+    # Load trained model
+    if args.load == '': raise Exception('Pretrained model not specified.')
+    params = model.state_dict()
+    state_dict = torch.load(args.load)
+
+    if 'state_dict' in state_dict.keys():
+        state_dict = state_dict['state_dict']
+
+    state_dict2 = {}
+    for k, v in state_dict.items():
+        if 'scorer' not in k:
+            state_dict2[k] = v
+    state_dict = state_dict2
+    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
+        state_dict[k2] = state_dict.pop(k1)
+
+
+    try:
+        model.load_state_dict(state_dict, strict=True)
+    except:
+        for k in params.keys():
+            if k not in state_dict.keys():
+                state_dict[k] = params[k]
+        model.load_state_dict(state_dict, strict=True)
+
+    # Helper classes (for testing) initialization
+    Evaluator.initialize()
+    Visualizer.initialize(args.visualize, args.vispath)
+
+    # Dataset initialization
+    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
+    dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)
+
+    # Test
+    with torch.no_grad():
+        test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
+    Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
+    Logger.info('==================== Finished Testing ====================')
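An illustration of the positional key remap used in the loading code above (a standalone sketch, not from the commit): the zip over both key lists relies on Python dicts preserving insertion order, so it effectively renames checkpoint keys to the DataParallel-wrapped model's names.

# Hypothetical key names for illustration.
ckpt   = {'backbone.w': 1, 'head.w': 2}
target = {'module.backbone.w': None, 'module.head.w': None}
for k1, k2 in zip(list(ckpt.keys()), target.keys()):
    ckpt[k2] = ckpt.pop(k1)       # move value from old key to new key
print(ckpt)  # {'module.backbone.w': 1, 'module.head.w': 2}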
model/DCAMA.py
ADDED
@@ -0,0 +1,625 @@
+r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
+from functools import reduce
+from operator import add
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models import resnet
+
+from .base.swin_transformer import SwinTransformer
+from model.base.transformer import MultiHeadedAttention, PositionalEncoding
+import copy
+
+class Flatten(nn.Module):
+    def forward(self, x):
+        return x.view(x.size(0), x.size(1), -1).contiguous()
+
+def reshape(x, size):
+    size1 = torch.tensor(x.size()).float().cuda()
+    # x = torch.logical_not(x.cuda())
+    yxs = torch.stack(torch.where(x), dim=-1)
+    ratio = size[0] / size1[0]
+    yxs2 = (yxs * ratio).long()
+    x2 = torch.zeros((size[0], size[1])).float().cuda()
+    return yxs2
+
+
+class DCAMA(nn.Module):
+
+    def __init__(self, backbone, pretrained_path, use_original_imgsize):
+        super(DCAMA, self).__init__()
+
+        self.backbone = backbone
+        self.use_original_imgsize = use_original_imgsize
+
+        # feature extractor initialization
+        if backbone == 'resnet50':
+            self.feature_extractor = resnet.resnet50()
+            self.feature_extractor.load_state_dict(torch.load(pretrained_path))
+            self.feat_channels = [256, 512, 1024, 2048]
+            self.nlayers = [3, 4, 6, 3]
+            self.feat_ids = list(range(0, 17))
+            self.last_feat_size = [12, 12]
+        elif backbone == 'resnet101':
+            self.feature_extractor = resnet.resnet101()
+            self.feature_extractor.load_state_dict(torch.load(pretrained_path))
+            self.feat_channels = [256, 512, 1024, 2048]
+            self.nlayers = [3, 4, 23, 3]
+            self.feat_ids = list(range(0, 34))
+        elif backbone == 'swin':
+            self.feature_extractor = SwinTransformer(img_size=384, patch_size=4, window_size=12, embed_dim=128,
+                                                     depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32])
+            self.feature_extractor.load_state_dict(torch.load(pretrained_path)['model'])
+            self.feat_channels = [128, 256, 512, 1024]
+            self.nlayers = [2, 2, 18, 2]
+        else:
+            raise Exception('Unavailable backbone: %s' % backbone)
+        self.feature_extractor.eval()
+
+        # define model
+        self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(self.nlayers)])
+        self.stack_ids = torch.tensor(self.lids).bincount()[-4:].cumsum(dim=0)
+        self.model = DCAMA_model(in_channels=self.feat_channels, stack_ids=self.stack_ids)
+
+        ## TODO:
+
+        self.scorer2 = nn.ModuleList()
+        for layer_idx in range(len(self.nlayers)):
+            layer_num = self.nlayers[layer_idx]
+            for idx in range(layer_num):
+                self.scorer2.append(
+                    nn.Sequential(
+                        nn.Conv2d(256 * 2 ** layer_idx, 256 * 2 ** layer_idx, 1, 1),
+                        # nn.ReLU(),
+                        # nn.InstanceNorm2d(256 * 2 ** layer_idx),
+                        # nn.Conv2d(256 * 2 ** layer_idx, 256 * 2 ** layer_idx, 1, 1),
+                    )
+                )
+        self.scorer1 = nn.Sequential(
+            nn.Linear(sum(self.nlayers) - self.nlayers[0], 1)
+        )
+
+        self.cross_entropy_loss = nn.CrossEntropyLoss()
+
+    def forward(self, query_img, support_img, support_mask, nshot, predict_score=False):
+        n_support_feats = []
+        with torch.no_grad():
+            for k in range(nshot):
+                support_feats_ = self.extract_feats(support_img[:, k])
+                support_feats = copy.deepcopy(support_feats_)
+                del support_feats_
+                torch.cuda.empty_cache()
+                n_support_feats.append(support_feats)
+            query_feats = self.extract_feats(query_img)
+
+
+        logit_mask = self.model(query_feats, n_support_feats, support_mask.clone(), nshot=nshot)
+        ## TODO:
+        MAX_SHOTS = 1
+        if len(n_support_feats) >= MAX_SHOTS:
+            nshot = MAX_SHOTS
+            n_support_query_f = []
+            n_simi = []
+            for i in range(len(n_support_feats)):
+                support_f = n_support_feats[i]
+                support_query_f = []
+                simi_l = []
+                simi_sum = []
+                for l in range(len(query_feats)):
+                    if l < self.stack_ids[0]:
+                        continue
+                    elif l < self.stack_ids[1]:
+                        DCAMA_blocks = self.model.DCAMA_blocks[0]
+                        pe = self.model.pe[0]
+                    elif l < self.stack_ids[2]:
+                        DCAMA_blocks = self.model.DCAMA_blocks[1]
+                        pe = self.model.pe[1]
+                    else:
+                        DCAMA_blocks = self.model.DCAMA_blocks[2]
+                        pe = self.model.pe[2]
+                    a_support_f = support_f[l].clone()
+                    coords = reshape(support_mask[0, i], a_support_f.size()[-2:])
+                    b, ch, w, h = a_support_f.size()
+                    a_support_f = a_support_f.view(b, ch, -1)
+                    a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
+                    a_support_f = a_support_f.view(b, ch, w, h)
+                    a_support_f = self.scorer2[l](a_support_f)
+                    a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
+                    # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.
+                    query_feat = query_feats[l].view(b, ch, -1)
+                    query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
+                    query_feat = query_feat.view(b, ch, w, h)
+                    query_feat = self.scorer2[l](query_feat)
+                    simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
+                    simi_sum.append(simi)
+                    # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
+                    if l == 6:
+                        simi_map = simi.clone()
+                    simi = simi.view(-1).mean()
+                    simi_l.append(simi)
+                # simi_l = self.scorer1(torch.stack(simi_l, dim=0).unsqueeze(0)).squeeze(0)[0]
+                n_simi.append(torch.stack(simi_l, dim=0).mean())
+
+            n_simi = torch.stack(n_simi, dim=0)
+            args = n_simi.argsort(descending=True)[:MAX_SHOTS]
+            support_mask = support_mask[:, args, :, :]
+            # n_support_feats = [n_support_feats[arg] for arg in args]
+            n_simis = n_simi[args].max()
+        else:
+            n_simis = torch.tensor(0.).float().cuda()
+        return logit_mask, n_simis
+
+    def extract_feats(self, img):
+        r""" Extract input image features """
+        feats = []
+
+        if self.backbone == 'swin':
+            _ = self.feature_extractor.forward_features(img)
+            for feat in self.feature_extractor.feat_maps:
+                bsz, hw, c = feat.size()
+                h = int(hw ** 0.5)
+                feat = feat.view(bsz, h, h, c).permute(0, 3, 1, 2).contiguous()
+                feats.append(feat)
+        elif self.backbone == 'resnet50' or self.backbone == 'resnet101':
+            bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), self.nlayers)))
+            # Layer 0
+            feat = self.feature_extractor.conv1.forward(img)
+            feat = self.feature_extractor.bn1.forward(feat)
+            feat = self.feature_extractor.relu.forward(feat)
+            feat = self.feature_extractor.maxpool.forward(feat)
+
+            # Layer 1-4
+            for hid, (bid, lid) in enumerate(zip(bottleneck_ids, self.lids)):
+                res = feat
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv1.forward(feat)
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn1.forward(feat)
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv2.forward(feat)
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn2.forward(feat)
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].conv3.forward(feat)
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].bn3.forward(feat)
+
+                if bid == 0:
+                    res = self.feature_extractor.__getattr__('layer%d' % lid)[bid].downsample.forward(res)
+
+                feat += res
+
+                if hid + 1 in self.feat_ids:
+                    feats.append(feat.clone())
+
+                feat = self.feature_extractor.__getattr__('layer%d' % lid)[bid].relu.forward(feat)
+
+        return feats
+
+    def predict_mask_nshot(self, batch, nshot):
+        r""" n-shot inference """
+        query_img = batch['query_img']
+        support_imgs = batch['support_imgs']
+        support_masks = batch['support_masks']
+
+        if nshot == 1:
+            with torch.no_grad():
+                query_feats = self.extract_feats(query_img)
+                n_support_feats = []
+                for k in range(nshot):
+                    support_feats = self.extract_feats(support_imgs[:, k])
+                    n_support_feats.append(support_feats)
+
+                n_simis = []
+                simi_map = None
+                for i in range(len(n_support_feats)):
+                    support_f = n_support_feats[i]
+                    support_query_f = []
+                    simi_l = []
+                    for l in range(len(query_feats)):
+                        if l < self.stack_ids[0]:
+                            continue
+                        elif l < self.stack_ids[1]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[0]
+                            pe = self.model.pe[0]
+                        elif l < self.stack_ids[2]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[1]
+                            pe = self.model.pe[1]
+                        else:
+                            DCAMA_blocks = self.model.DCAMA_blocks[2]
+                            pe = self.model.pe[2]
+                        a_support_f = support_f[l].clone()
+                        coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
+                        b, ch, w, h = a_support_f.size()
+                        a_support_f = a_support_f.view(b, ch, -1)
+                        a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
+                        a_support_f = a_support_f.view(b, ch, w, h)
+                        a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
+                        # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.
+                        query_feat = query_feats[l].view(b, ch, -1)
+                        query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
+                        query_feat = query_feat.view(b, ch, w, h)
+                        simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
+                        # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
+                        if l == 13:
+                            simi_map = simi.clone()
+                        simi = simi.view(-1).max()
+                        simi_l.append(simi)
+                    simi_l = torch.stack(simi_l, dim=0).mean()
+                    n_simis.append(simi_l)
+                n_simis = torch.stack(n_simis, dim=0)
+                logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
+
+        else:
+            with torch.no_grad():
+                query_feats = self.extract_feats(query_img)
+                n_support_feats = []
+                for k in range(nshot):
+                    support_feats = self.extract_feats(support_imgs[:, k])
+                    n_support_feats.append(support_feats)
+
+            ## TODO: retrieval V1 ##
+            MAX_SHOTS = 200
+            '''
+            if len(n_support_feats) > MAX_SHOTS:
+                nshot = MAX_SHOTS
+                n_support_query_f = []
+                n_simis = []
+                for i in range(len(n_support_feats)):
+                    support_f = n_support_feats[i]
+                    support_query_f = []
+                    simi_l = []
+                    simi_sum = []
+                    for l in range(len(query_feats)):
+                        if l < self.stack_ids[0]:
+                            continue
+                        elif l < self.stack_ids[1]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[0]
+                            pe = self.model.pe[0]
+                        elif l < self.stack_ids[2]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[1]
+                            pe = self.model.pe[1]
+                        else:
+                            DCAMA_blocks = self.model.DCAMA_blocks[2]
+                            pe = self.model.pe[2]
+                        a_support_f = support_f[l].clone()
+                        coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
+                        b, ch, w, h = a_support_f.size()
+                        a_support_f = a_support_f.view(b, ch, -1)
+                        a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
+                        a_support_f = a_support_f.view(b, ch, w, h)
+                        a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
+                        # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.
+                        query_feat = query_feats[l].view(b, ch, -1)
+                        query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
+                        query_feat = query_feat.view(b, ch, w, h)
+                        simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
+                        simi_sum.append(simi)
+                        # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
+                        if l == 6:
+                            simi_map = simi.clone().detach().cpu().numpy()
+                        simi = simi.view(-1).max()
+                        simi_l.append(simi)
+                    simi_l = torch.stack(simi_l, dim=0).mean()
+                    n_simis.append(simi_l)
+                n_simis = torch.stack(n_simis, dim=0)
+                # nshot = max((n_simis > 0.).sum(), 1)
+                nshot = len(n_simis)
+                support_masks = support_masks[:, n_simis.argsort(descending=True)[:nshot], :, :]
+                n_support_feats = [n_support_feats[i] for i in n_simis.argsort(descending=True)[:nshot]]
+            else:
+                n_simis = torch.tensor(0.).float().cuda()
+                simi_map = None
+            ## TODO: retriever V2
+            '''
+            '''
+            MAX_SHOTS = 30
+            if len(n_support_feats) > MAX_SHOTS:
+                nshot = MAX_SHOTS
+                n_support_query_f = []
+                n_simis = []
+                support_f_list = []
+                n_support_feats2 = []
+                query_feats2 = []
+                for i in range(len(n_support_feats)):
+                    support_f = n_support_feats[i]
+                    n_support_feats2_l = []
+                    query_feats2_l = []
+                    for l in range(len(query_feats)):
+                        if l < self.stack_ids[0]:
+                            continue
+                        elif l < self.stack_ids[1]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[0]
+                            pe = self.model.pe[0]
+                        elif l < self.stack_ids[2]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[1]
+                            pe = self.model.pe[1]
+                        else:
+                            DCAMA_blocks = self.model.DCAMA_blocks[2]
+                            pe = self.model.pe[2]
+                        a_support_f = support_f[l].clone()
+                        coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
+                        b, ch, w, h = a_support_f.size()
+                        a_support_f = a_support_f.view(b, ch, -1)
+                        a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
+                        a_support_f = a_support_f.view(b, ch, w, h)
+                        a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1)
+                        n_support_feats2_l.append(a_support_f)
+                        query_feat = query_feats[l].view(b, ch, -1)
+                        query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
+                        query_feat = query_feat.view(b, ch, w, h)
+                        query_feats2_l.append(query_feat)
+                    n_support_feats2.append(n_support_feats2_l)
+                    query_feats2.append(query_feats2_l)
+                n_support_feats3 = [[] for _ in range(len(query_feats2[0]))]
+                selected = []
+                for i in range(MAX_SHOTS):
+                    simi_min = -100
+                    idx_min = -1
+                    for idx in range(len(n_support_feats2)):
+                        if idx in selected:
+                            continue
+                        support_feats2 = n_support_feats2[idx]
+                        simi = []
+                        for l in range(len(query_feats2[i])):
+                            support_feats_avg = torch.stack(n_support_feats3[l] + [support_feats2[l]], dim=0).mean(0)
+                            query_feat = query_feats2[i][l]
+                            a_support_f = support_feats_avg.unsqueeze(-1).unsqueeze(-1).repeat(
+                                (1, 1, query_feat.size(-2), query_feat.size(-1)))
+                            simi_l = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(
+                                a_support_f, dim=1))[0].view(-1).max()
+                            simi.append(simi_l)
+                        simi = torch.stack(simi, dim=0).mean()
+                        if simi > simi_min:
+                            simi_min = simi
+                            idx_min = idx
+                            support_feats2_argmin = n_support_feats2[idx]
+                    for l2 in range(len(query_feats2[0])):
+                        n_support_feats3[l2].append(n_support_feats2[idx_min][l2])
+                    selected.append(idx_min)
+                n_support_feats4 = []
+                for idx in selected:
+                    n_support_feats4.append(n_support_feats[idx])
+                support_masks = support_masks[:, torch.tensor(selected).long().cuda(), :, :]
+                n_support_feats = n_support_feats4
+                simi_map = None
+            else:
+                n_simis = torch.tensor(0.).float().cuda()
+                simi_map = None
+            '''
+            ## TODO: v3
+
+            MAX_SHOTS = 200
+            if len(n_support_feats) > MAX_SHOTS:
+                nshot = MAX_SHOTS
+                n_support_query_f = []
+                n_simis = []
+                support_f_list = []
+                n_support_feats2 = []
+                query_feats2 = []
+                for i in range(len(n_support_feats)):
+                    support_f = n_support_feats[i]
+                    n_support_feats2_l = []
+                    query_feats2_l = []
+                    for l in range(len(query_feats)):
+                        if l < self.stack_ids[0]:
+                            continue
+                        elif l < self.stack_ids[1]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[0]
+                            pe = self.model.pe[0]
+                        elif l < self.stack_ids[2]:
+                            DCAMA_blocks = self.model.DCAMA_blocks[1]
+                            pe = self.model.pe[1]
+                        else:
+                            DCAMA_blocks = self.model.DCAMA_blocks[2]
+                            pe = self.model.pe[2]
+                        a_support_f = support_f[l].clone()
+                        coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
+                        b, ch, w, h = a_support_f.size()
+                        a_support_f = a_support_f.view(b, ch, -1)
+                        a_support_f_tmp = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
+                        a_support_f = a_support_f_tmp / a_support_f_tmp.norm(dim=1, keepdim=True) * DCAMA_blocks.linears[1](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
+                        a_support_f = a_support_f.view(b, ch, w, h)
+                        a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1)
+                        n_support_feats2_l.append(a_support_f)
+                        query_feat = query_feats[l].view(b, ch, -1)
+                        query_feat_tmp = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
+                        query_feat = query_feat_tmp / query_feat_tmp.norm(dim=1, keepdim=True) * DCAMA_blocks.linears[1](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
+                        query_feat = query_feat.view(b, ch, w, h)
+                        query_feats2_l.append(query_feat)
+                    n_support_feats2.append(n_support_feats2_l)
+                    query_feats2.append(query_feats2_l)
+                n_support_feats3 = [[] for _ in range(len(query_feats2[0]))]
+                selected = []
+                for i in range(MAX_SHOTS):
+                    simi_min = -100
+                    idx_min = -1
+                    for idx in range(len(n_support_feats2)):
+                        if idx in selected:
+                            continue
+                        support_feats2 = n_support_feats2[idx]
+                        simi = []
+                        for l in range(len(query_feats2[i])):
+                            support_feats_avg = torch.stack(n_support_feats3[l] + [support_feats2[l]], dim=0).mean(0)
+                            query_feat = query_feats2[i][l]
+                            a_support_f = support_feats_avg.unsqueeze(-1).unsqueeze(-1).repeat(
+                                (1, 1, query_feat.size(-2), query_feat.size(-1)))
+                            simi_l = ((query_feat * a_support_f).sum(1))[0].view(-1).max()
+                            simi.append(simi_l)
+                        simi = torch.stack(simi, dim=0).mean()
+                        if simi > simi_min:
+                            simi_min = simi
+                            idx_min = idx
+                            support_feats2_argmin = n_support_feats2[idx]
+                    for l2 in range(len(query_feats2[0])):
+                        n_support_feats3[l2].append(n_support_feats2[idx_min][l2])
+                    selected.append(idx_min)
+                n_support_feats4 = []
+                for idx in selected:
+                    n_support_feats4.append(n_support_feats[idx])
+                support_masks = support_masks[:, torch.tensor(selected).long().cuda(), :, :]
+                n_support_feats = n_support_feats4
+                simi_map = None
+            else:
+                n_simis = torch.tensor(0.).float().cuda()
+                simi_map = None
+
+            logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
+
+        if self.use_original_imgsize:
+            org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()])
+            logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
+        else:
+            logit_mask = F.interpolate(logit_mask, support_imgs[0].size()[2:], mode='bilinear', align_corners=True)
+
+        return logit_mask.argmax(dim=1), n_simis, simi_map
+
+    def compute_objective(self, logit_mask, gt_mask):
+        bsz = logit_mask.size(0)
+        logit_mask = logit_mask.view(bsz, 2, -1)
+        gt_mask = gt_mask.view(bsz, -1).long()
+
+        return self.cross_entropy_loss(logit_mask, gt_mask)
+
+    def train_mode(self):
+        self.train()
+        self.feature_extractor.eval()
+
+
+class DCAMA_model(nn.Module):
+    def __init__(self, in_channels, stack_ids):
+        super(DCAMA_model, self).__init__()
+
+        self.stack_ids = stack_ids
+
+        # DCAMA blocks
+        self.DCAMA_blocks = nn.ModuleList()
+        self.pe = nn.ModuleList()
+        for inch in in_channels[1:]:
+            self.DCAMA_blocks.append(MultiHeadedAttention(h=8, d_model=inch, dropout=0.5))
+            self.pe.append(PositionalEncoding(d_model=inch, dropout=0.5))
+
+        outch1, outch2, outch3 = 16, 64, 128
+
+        # conv blocks
+        self.conv1 = self.build_conv_block(stack_ids[3]-stack_ids[2], [outch1, outch2, outch3], [3, 3, 3], [1, 1, 1])  # 1/32
+        self.conv2 = self.build_conv_block(stack_ids[2]-stack_ids[1], [outch1, outch2, outch3], [5, 3, 3], [1, 1, 1])  # 1/16
+        self.conv3 = self.build_conv_block(stack_ids[1]-stack_ids[0], [outch1, outch2, outch3], [5, 5, 3], [1, 1, 1])  # 1/8
+
+        self.conv4 = self.build_conv_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])  # 1/32 + 1/16
+        self.conv5 = self.build_conv_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])  # 1/16 + 1/8
+
+        # mixer blocks
+        self.mixer1 = nn.Sequential(nn.Conv2d(outch3+2*in_channels[1]+2*in_channels[0], outch3, (3, 3), padding=(1, 1), bias=True),
+                                    nn.ReLU(),
+                                    nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
+                                    nn.ReLU())
+
+        self.mixer2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
+                                    nn.ReLU(),
+                                    nn.Conv2d(outch2, outch1, (3, 3), padding=(1, 1), bias=True),
+                                    nn.ReLU())
+
+        self.mixer3 = nn.Sequential(nn.Conv2d(outch1, outch1, (3, 3), padding=(1, 1), bias=True),
+                                    nn.ReLU(),
+                                    nn.Conv2d(outch1, 2, (3, 3), padding=(1, 1), bias=True))
+
+
+
+    def forward(self, query_feats, support_feats, support_mask, nshot=1):
+        coarse_masks = []
+        for idx, query_feat in enumerate(query_feats):
+            # 1/4 scale feature only used in skip connect
+            if idx < self.stack_ids[0]: continue
+
+            bsz, ch, ha, wa = query_feat.size()
+
+            # reshape the input feature and mask
|
| 535 |
+
query = query_feat.view(bsz, ch, -1).permute(0, 2, 1).contiguous()
|
| 536 |
+
# if nshot == 1:
|
| 537 |
+
# support_feat = support_feats[idx]
|
| 538 |
+
# mask = F.interpolate(support_mask.unsqueeze(1).float(), support_feat.size()[2:], mode='bilinear',
|
| 539 |
+
# align_corners=True).view(support_feat.size()[0], -1)
|
| 540 |
+
# support_feat = support_feat.view(support_feat.size()[0], support_feat.size()[1], -1).permute(0, 2, 1).contiguous()
|
| 541 |
+
# else:
|
| 542 |
+
support_feat = torch.stack([support_feats[k][idx] for k in range(nshot)])
|
| 543 |
+
support_feat = support_feat.view(-1, ch, ha * wa).permute(0, 2, 1).contiguous()
|
| 544 |
+
mask = torch.stack([F.interpolate(k.unsqueeze(1).float(), (ha, wa), mode='bilinear', align_corners=True)
|
| 545 |
+
for k in support_mask])
|
| 546 |
+
mask = mask.view(bsz, -1)
|
| 547 |
+
|
| 548 |
+
# DCAMA blocks forward
|
| 549 |
+
DCAMA_blocks = None
|
| 550 |
+
pe = None
|
| 551 |
+
if idx < self.stack_ids[1]:
|
| 552 |
+
DCAMA_blocks = self.DCAMA_blocks[0]
|
| 553 |
+
pe = self.pe[0]
|
| 554 |
+
elif idx < self.stack_ids[2]:
|
| 555 |
+
DCAMA_blocks = self.DCAMA_blocks[1]
|
| 556 |
+
pe = self.pe[1]
|
| 557 |
+
else:
|
| 558 |
+
DCAMA_blocks = self.DCAMA_blocks[2]
|
| 559 |
+
pe = self.pe[2]
|
| 560 |
+
coarse_mask = DCAMA_blocks(pe(query), pe(support_feat), mask)
|
| 561 |
+
coarse_masks.append(coarse_mask.permute(0, 2, 1).contiguous().view(bsz, 1, ha, wa))
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# multi-scale conv blocks forward
|
| 565 |
+
bsz, ch, ha, wa = coarse_masks[self.stack_ids[3]-1-self.stack_ids[0]].size()
|
| 566 |
+
coarse_masks1 = torch.stack(coarse_masks[self.stack_ids[2]-self.stack_ids[0]:self.stack_ids[3]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
|
| 567 |
+
bsz, ch, ha, wa = coarse_masks[self.stack_ids[2]-1-self.stack_ids[0]].size()
|
| 568 |
+
coarse_masks2 = torch.stack(coarse_masks[self.stack_ids[1]-self.stack_ids[0]:self.stack_ids[2]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
|
| 569 |
+
bsz, ch, ha, wa = coarse_masks[self.stack_ids[1]-1-self.stack_ids[0]].size()
|
| 570 |
+
coarse_masks3 = torch.stack(coarse_masks[0:self.stack_ids[1]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
|
| 571 |
+
|
| 572 |
+
coarse_masks1 = self.conv1(coarse_masks1)
|
| 573 |
+
coarse_masks2 = self.conv2(coarse_masks2)
|
| 574 |
+
coarse_masks3 = self.conv3(coarse_masks3)
|
| 575 |
+
|
| 576 |
+
# multi-scale cascade (pixel-wise addition)
|
| 577 |
+
coarse_masks1 = F.interpolate(coarse_masks1, coarse_masks2.size()[-2:], mode='bilinear', align_corners=True)
|
| 578 |
+
mix = coarse_masks1 + coarse_masks2
|
| 579 |
+
mix = self.conv4(mix)
|
| 580 |
+
|
| 581 |
+
mix = F.interpolate(mix, coarse_masks3.size()[-2:], mode='bilinear', align_corners=True)
|
| 582 |
+
mix = mix + coarse_masks3
|
| 583 |
+
mix = self.conv5(mix)
|
| 584 |
+
|
| 585 |
+
# skip connect 1/8 and 1/4 features (concatenation)
|
| 586 |
+
# if nshot == 1:
|
| 587 |
+
# support_feat = support_feats[self.stack_ids[1] - 1]
|
| 588 |
+
# else:
|
| 589 |
+
support_feat = torch.stack([support_feats[k][self.stack_ids[1] - 1] for k in range(nshot)]).max(dim=0).values
|
| 590 |
+
mix = torch.cat((mix, query_feats[self.stack_ids[1] - 1], support_feat), 1)
|
| 591 |
+
|
| 592 |
+
upsample_size = (mix.size(-1) * 2,) * 2
|
| 593 |
+
mix = F.interpolate(mix, upsample_size, mode='bilinear', align_corners=True)
|
| 594 |
+
# if nshot == 1:
|
| 595 |
+
# support_feat = support_feats[self.stack_ids[0] - 1]
|
| 596 |
+
# else:
|
| 597 |
+
support_feat = torch.stack([support_feats[k][self.stack_ids[0] - 1] for k in range(nshot)]).max(dim=0).values
|
| 598 |
+
mix = torch.cat((mix, query_feats[self.stack_ids[0] - 1], support_feat), 1)
|
| 599 |
+
|
| 600 |
+
# mixer blocks forward
|
| 601 |
+
out = self.mixer1(mix)
|
| 602 |
+
upsample_size = (out.size(-1) * 2,) * 2
|
| 603 |
+
out = F.interpolate(out, upsample_size, mode='bilinear', align_corners=True)
|
| 604 |
+
out = self.mixer2(out)
|
| 605 |
+
upsample_size = (out.size(-1) * 2,) * 2
|
| 606 |
+
out = F.interpolate(out, upsample_size, mode='bilinear', align_corners=True)
|
| 607 |
+
logit_mask = self.mixer3(out)
|
| 608 |
+
|
| 609 |
+
return logit_mask
|
| 610 |
+
|
| 611 |
+
def build_conv_block(self, in_channel, out_channels, kernel_sizes, spt_strides, group=4):
|
| 612 |
+
r""" bulid conv blocks """
|
| 613 |
+
assert len(out_channels) == len(kernel_sizes) == len(spt_strides)
|
| 614 |
+
|
| 615 |
+
building_block_layers = []
|
| 616 |
+
for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)):
|
| 617 |
+
inch = in_channel if idx == 0 else out_channels[idx - 1]
|
| 618 |
+
pad = ksz // 2
|
| 619 |
+
|
| 620 |
+
building_block_layers.append(nn.Conv2d(in_channels=inch, out_channels=outch,
|
| 621 |
+
kernel_size=ksz, stride=stride, padding=pad))
|
| 622 |
+
building_block_layers.append(nn.GroupNorm(group, outch))
|
| 623 |
+
building_block_layers.append(nn.ReLU(inplace=True))
|
| 624 |
+
|
| 625 |
+
return nn.Sequential(*building_block_layers)
|
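The MAX_SHOTS branch above is a greedy support-selection heuristic: each support image is reduced to per-layer prototypes (masked-average, norm-rescaled features), and shots are picked one at a time so that the running average of the chosen prototypes scores highest against the query features. A minimal, self-contained sketch of that selection loop; the names (greedy_select, prototypes) are illustrative, not part of this repo:

import torch

def greedy_select(query_feat, prototypes, k):
    # query_feat: (C, H, W) query feature map; prototypes: list of (C,) support
    # prototypes; k <= len(prototypes) is assumed.
    selected, running = [], []
    for _ in range(k):
        best_idx, best_sim = -1, float('-inf')
        for idx, proto in enumerate(prototypes):
            if idx in selected:
                continue
            # score the candidate jointly with the shots already picked
            cand = torch.stack(running + [proto]).mean(0)
            sim = (query_feat * cand[:, None, None]).sum(0).max()  # peak similarity over pixels
            if sim > best_sim:
                best_sim, best_idx = sim, idx
        running.append(prototypes[best_idx])
        selected.append(best_idx)
    return selected

# e.g. greedy_select(torch.randn(8, 4, 4), [torch.randn(8) for _ in range(10)], k=3)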
model/__pycache__/DCAMA.cpython-38.pyc
ADDED
Binary file (13.8 kB). View file

model/__pycache__/DCAMA.cpython-39.pyc
ADDED
Binary file (13.3 kB). View file

model/base/__pycache__/swin_transformer.cpython-38.pyc
ADDED
Binary file (20.6 kB). View file

model/base/__pycache__/swin_transformer.cpython-39.pyc
ADDED
Binary file (20.5 kB). View file

model/base/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (3.61 kB). View file

model/base/__pycache__/transformer.cpython-39.pyc
ADDED
Binary file (3.68 kB). View file
model/base/swin_transformer.py
ADDED
@@ -0,0 +1,605 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        feats = []
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
            feats.append(x.clone().detach())
        if self.downsample is not None:
            x = self.downsample(x)
        return feats, x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, feat_ids=[1, 2, 3, 4], **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.feat_ids = feat_ids

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        self.feat_maps = []
        for i, layer in enumerate(self.layers):
            feats, x = layer(x)
            if i+1 in self.feat_ids:
                self.feat_maps += feats

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


if __name__ == '__main__':
    input = torch.randn(2, 3, 384, 384).cuda()

    net = SwinTransformer(img_size=384, patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
    net.load_state_dict(torch.load("/apdcephfs/share_1290796/shixinyu/checkpoints/swin_base_patch4_window12_384_22kto1k.pth")['model'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)

    out = net.forward_features(input)
    feat = net.feat_maps
    for x in feat:
        print(x.shape)
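As a quick sanity check of the windowing utilities above, window_partition and window_reverse should be exact inverses whenever H and W are multiples of the window size; a small sketch with illustrative shapes:

import torch

B, H, W, C, ws = 2, 8, 8, 3, 4  # H and W must be multiples of ws
x = torch.randn(B, H, W, C)
wins = window_partition(x, ws)  # (B * (H//ws) * (W//ws), ws, ws, C) == (8, 4, 4, 3)
assert wins.shape == (8, 4, 4, 3)
assert torch.equal(window_reverse(wins, ws, H, W), x)  # exact round trip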
model/base/transformer.py
ADDED
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import math, copy
from torch.autograd import Variable

TRAIN = False

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 2)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key))]

        value = value.repeat(self.h, 1, 1).transpose(0, 1).contiguous().unsqueeze(-1)

        # query_dir, key_dir = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip([self.linears[0], self.linears[0]], (query, key))]
        # query_norm = self.linears[1](query)[:, :, :self.h].view(nbatches, -1, self.h).transpose(1, 2)
        # key_norm = self.linears[1](key)[:, :, :self.h].view(nbatches, -1, self.h).transpose(1, 2)
        # query = query_dir / query_dir.norm(dim=-1).unsqueeze(-1) * 10 * query_norm.unsqueeze(-1)
        # key = key_dir / key_dir.norm(dim=-1).unsqueeze(-1) * 10 * key_norm.unsqueeze(-1)

        if not TRAIN:
            query = query.detach().cpu()
            key = key.detach().cpu()
            value = value.detach().cpu()
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        if not TRAIN:
            x = x.cuda()

        # 3) "Concat" using a view and apply a final linear.
        return torch.mean(x, -3)


class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)],
                         requires_grad=False)
        return self.dropout(x)

importance = torch.tensor(0.).float().cuda()
cnt = 0

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    # global importance, cnt
    # im = p_attn[:, :, :, :query.size(2)].max(2)[0].mean()
    # importance += im
    # cnt += 1
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
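In this repo the `value` passed to MultiHeadedAttention is the flattened support mask rather than a feature projection, so the output is an attention-weighted aggregation of mask values per query token. A toy invocation showing the expected shapes (sizes are illustrative; a CUDA device is assumed because the module creates a CUDA tensor at import time and TRAIN = False routes the matmul through CPU and back):

import torch

mha = MultiHeadedAttention(h=8, d_model=64).cuda()
query = torch.randn(1, 100, 64).cuda()    # (B, N_query_tokens, d_model)
support = torch.randn(1, 196, 64).cuda()  # (B, N_support_tokens, d_model)
mask = torch.rand(1, 196).cuda()          # (B, N_support_tokens): flattened support mask as `value`
out = mha(query, support, mask)           # (1, 100, 1): one coarse mask score per query token
print(out.shape)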
modelsize_estimate.py
ADDED
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn
import numpy as np


def modelsize(model, input, type_size=4):
    para = sum([np.prod(list(p.size())) for p in model.parameters()])
    # print('Model {} : Number of params: {}'.format(model._get_name(), para))
    print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))

    input_ = input.clone()
    input_.requires_grad_(requires_grad=False)

    mods = list(model.modules())
    out_sizes = []

    for i in range(1, len(mods)):
        m = mods[i]
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out = m(input_)
        out_sizes.append(np.array(out.size()))
        input_ = out

    total_nums = 0
    for i in range(len(out_sizes)):
        s = out_sizes[i]
        nums = np.prod(np.array(s))
        total_nums += nums

    # print('Model {} : Number of intermediate variables without backward: {}'.format(model._get_name(), total_nums))
    # print('Model {} : Number of intermediate variables with backward: {}'.format(model._get_name(), total_nums*2))
    print('Model {} : intermediate variables: {:3f} M (without backward)'
          .format(model._get_name(), total_nums * type_size / 1000 / 1000))
    print('Model {} : intermediate variables: {:3f} M (with backward)'
          .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))
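A usage sketch for modelsize (model and input shapes are illustrative). Because the loop feeds each submodule the previous module's output, the estimate is only meaningful for purely sequential models like this one:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(16, 8, 3, padding=1))
modelsize(net, torch.randn(1, 3, 64, 64), type_size=4)  # 4 bytes per element for float32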
scripts/importance_analysis.sh
ADDED
@@ -0,0 +1,16 @@
python ./importance_analysis.py --datapath "/research/d4/gds/wltang21/data" \
                                --benchmark coco \
                                --fold 0 \
                                --bsz 1 \
                                --nworker 1 \
                                --backbone resnet50 \
                                --feature_extractor_path "/research/d4/gds/wltang21/logistic_project/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
                                --logpath "./logs" \
                                --load "/research/d4/gds/wltang21/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
                                --nshot 10


# --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
# --visualize
#checkpoint/coco-20i/resnet50_fold0.pt
# log/train/fold_0_ft_v0/best_model.pt
scripts/test.sh
ADDED
@@ -0,0 +1,15 @@
python ./test.py --datapath "/research/d4/gds/wltang21/data" \
                 --benchmark coco \
                 --fold 0 \
                 --bsz 1 \
                 --nworker 1 \
                 --backbone resnet50 \
                 --feature_extractor_path "/research/d4/gds/wltang21/logistic_project/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
                 --logpath "./logs" \
                 --load "/research/d4/gds/wltang21/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
                 --nshot 30

# --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
# --visualize
#checkpoint/coco-20i/resnet50_fold0.pt
# log/train/fold_0_ft_v0/best_model.pt
scripts/train.sh
ADDED
@@ -0,0 +1,11 @@
python -u -m torch.distributed.launch --nnodes=1 --nproc_per_node=4 --node_rank=0 --master_port=16006 \
        ./train.py --datapath "../datasets" \
                   --benchmark coco \
                   --fold 0 \
                   --bsz 12 \
                   --nworker 8 \
                   --backbone swin \
                   --feature_extractor_path "../backbones/swin_base_patch4_window12_384.pth" \
                   --logpath "./logs" \
                   --lr 1e-3 \
                   --nepoch 500
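Note that `torch.distributed.launch` is deprecated in recent PyTorch releases (1.10+); assuming `train.py` reads its rank from the `LOCAL_RANK` environment variable (torchrun no longer passes a `--local_rank` argument by default), the equivalent launcher for this script would be `torchrun --nnodes=1 --nproc_per_node=4 --master_port=16006 ./train.py ...` with the same training arguments.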
scripts/train_1gpu.sh
ADDED
@@ -0,0 +1,12 @@
python ./train_1gpu.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
                       --benchmark coco \
                       --fold 0 \
                       --bsz 1 \
                       --nworker 0 \
                       --backbone resnet50 \
                       --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
                       --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
                       --lr 1e-4 \
                       --nepoch 50 \
                       --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log/resnet50_fold0.pt" \
                       --nshot 3
scripts/train_1gpu_retriver.sh
ADDED
@@ -0,0 +1,12 @@
python ./train_1gpu_retriever.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
                                 --benchmark coco \
                                 --fold 1 \
                                 --bsz 1 \
                                 --nworker 0 \
                                 --backbone resnet50 \
                                 --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
                                 --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
                                 --lr 1e-4 \
                                 --nepoch 50 \
                                 --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log/fold_1_ft_v0/model_45.pt" \
                                 --nshot 1
scripts/train_2gpu.sh
ADDED
@@ -0,0 +1,14 @@
python -u -m torch.distributed.launch --nproc_per_node=2 --master_port=18024 \
        ./train.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
                   --benchmark coco \
                   --fold 0 \
                   --bsz 1 \
                   --nworker 8 \
                   --backbone resnet50 \
                   --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
                   --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
                   --lr 1e-4 \
                   --nepoch 50 \
                   --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
                   --nshot 10
# --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
scripts/train_2gpu_retriever.sh
ADDED
@@ -0,0 +1,14 @@
python -u -m torch.distributed.launch --nproc_per_node=2 --master_port=18024 \
        ./train_retriever.py --datapath "/home/bkdongxianchi/MY_MOT/TWL/data" \
                             --benchmark coco \
                             --fold 0 \
                             --bsz 1 \
                             --nworker 8 \
                             --backbone resnet50 \
                             --feature_extractor_path "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/backbones/resnet50_a1h-35c100f8.pth" \
                             --logpath "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/log" \
                             --lr 1e-4 \
                             --nepoch 50 \
                             --load "/home/bkdongxianchi/MY_MOT/TWL/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
                             --nshot 1
# --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
scripts/train_4gpu.sh
ADDED
@@ -0,0 +1,14 @@
python -u -m torch.distributed.launch --nproc_per_node=4 --master_port=18024 \
        ./train.py --datapath "~/MY_MOT/TWL/data" \
                   --benchmark coco \
                   --fold 0 \
                   --bsz 1 \
                   --nworker 8 \
                   --backbone resnet101 \
                   --feature_extractor_path "~/MY_MOT/TWL/logistic_project/DCAMA/backbones/swin_base_patch4_window12_384_22kto1k.pth" \
                   --logpath "~/MY_MOT/TWL/logistic_project/DCAMA/log" \
                   --lr 1e-4 \
                   --nepoch 50 \
                   --load "~/MY_MOT/TWL/logistic_project/DCAMA/checkpoint/coco-20i/swin_fold2.pt" \
                   --nshot 3
# --load "/research/d6/rshr/xjgao/twl/logistic_project/DCAMA/checkpoint/coco-20i/resnet50_fold0.pt" \
test.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
import torch.nn as nn
import torch

from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.vis import Visualizer
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset
import cv2
import numpy as np
import os
# from gpu_mem_track import MemTracker

# gpu_tracker = MemTracker()


def test(model, dataloader, nshot):
    r""" Test """

    # Freeze randomness during testing for reproducibility
    utils.fix_randseed(0)
    average_meter = AverageMeter(dataloader.dataset)

    for idx, batch in enumerate(dataloader):

        # 1. Forward pass; the shot count is read from the batch itself
        nshot = batch['support_imgs'].size(1)
        ## TODO:

        batch = utils.to_cuda(batch)
        # gpu_tracker.track()
        pred_mask, simi, simi_map = model.module.predict_mask_nshot(batch, nshot=nshot)
        # gpu_tracker.track()
        torch.cuda.synchronize()
        assert pred_mask.size() == batch['query_mask'].size()

        # 2. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)

        ## TODO:
        iou = area_inter[1] / area_union[1]

        # Debug dumps (disabled): writes query/support images, masks, the
        # similarity map, and per-episode (similarity, IoU) pairs to debug/.
        '''
        cv2.imwrite('debug/query.png', cv2.imread("/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}".format(batch['query_name'][0])))
        cv2.imwrite('debug/query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
        cv2.imwrite('debug/support_{:.3}.png'.format(iou.item()), cv2.imread('/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}'.format(batch['support_names'][0][0])))
        cv2.imwrite('debug/support_mask_{:.3}.png'.format(iou.item()), (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
        simi_map = simi_map - simi_map.min()
        simi_map = (simi_map / simi_map.max() * 255).detach().cpu().numpy().astype(np.uint8)
        cv2.imwrite('debug/simi_map_{:.3}.png'.format(iou.item()), simi_map)

        mode = 'a' if os.path.exists('debug/stats.txt') else 'w'
        with open('debug/stats.txt', mode) as f:
            f.write('{} {}\n'.format(simi.item(), iou.item()))
        '''

        average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
        average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)

        # Visualize predictions
        if Visualizer.visualize:
            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
                                                  batch['query_img'], batch['query_mask'],
                                                  pred_mask, batch['class_id'], idx,
                                                  iou_b=area_inter[1].float() / area_union[1].float())

    # Write evaluation results
    average_meter.write_result('Test', 0)
    miou, fb_iou = average_meter.compute_iou()

    return miou, fb_iou


if __name__ == '__main__':

    # Arguments parsing
    args = parse_opts()

    Logger.initialize(args, training=False)

    # Model initialization
    model = DCAMA(args.backbone, args.feature_extractor_path, args.use_original_imgsize)
    model.eval()

    # Device setup
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())
    model = nn.DataParallel(model)
    model.to(device)

    # Load trained model
    if args.load == '':
        raise Exception('Pretrained model not specified.')
    params = model.state_dict()
    state_dict = torch.load(args.load)

    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']
    # Drop the scorer head, then remap the remaining keys positionally
    state_dict = {k: v for k, v in state_dict.items() if 'scorer' not in k}
    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
        state_dict[k2] = state_dict.pop(k1)

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Fill any keys missing from the checkpoint with the model's own values
        for k in params.keys():
            if k not in state_dict.keys():
                state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)

    # Helper classes (for testing) initialization
    Evaluator.initialize()
    Visualizer.initialize(args.visualize, args.vispath)

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
    dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)

    # Test
    with torch.no_grad():
        test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
    Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
    Logger.info('==================== Finished Testing ====================')
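
Note on the checkpoint loading above: remapping keys with zip() pairs the two state dicts purely by position, so it silently assumes both enumerate parameters in the same order. A minimal name-based sketch, not part of this commit; remap_by_name is a hypothetical helper and the 'module.' prefix handling is an assumption about how the DataParallel keys differ:

    # Hypothetical helper: remap checkpoint keys by name rather than position.
    # Assumes (editor's assumption) keys differ only by a 'module.' prefix.
    def remap_by_name(state_dict, model_params):
        remapped = {}
        for k, v in state_dict.items():
            k2 = k if k in model_params else 'module.' + k
            if k2 in model_params:
                remapped[k2] = v
        return remapped
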
train.py
ADDED
@@ -0,0 +1,149 @@
r""" training (validation) code """
import torch.optim as optim
import torch.nn as nn
import torch

from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset
import torch.nn.functional as F

average_loss = torch.tensor(0.).float().cuda()
global_idx = 0


def train(epoch, model, dataloader, optimizer, training):
    r""" Train """

    # Force randomness during training / freeze randomness during testing
    utils.fix_randseed(None) if training else utils.fix_randseed(0)
    model.module.train_mode() if training else model.module.eval()
    average_meter = AverageMeter(dataloader.dataset)

    global average_loss, global_idx
    average_loss = average_loss.to("cuda:{}".format(torch.cuda.current_device()))
    stats = [[], []]
    for idx, batch in enumerate(dataloader):

        # 1. Forward pass
        batch = utils.to_cuda(batch)
        logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'],
                                        nshot=batch['support_imgs'].size(1))
        pred_mask = logit_mask.argmax(dim=1)

        # 2. Compute loss & update model parameters. The scorer head regresses
        # the foreground IoU of the episode, supervised with an L1 loss on the
        # detached IoU value.
        loss = model.module.compute_objective(logit_mask, batch['query_mask'])
        # loss_obj = loss.detach()
        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        iou = area_inter[1] / area_union[1]
        loss_obj = iou.detach()
        score_loss = F.l1_loss(score_preds, loss_obj)
        stats[0].append(score_preds.detach().cpu().numpy())
        stats[1].append(loss_obj.detach().cpu().numpy()[0])
        # Running exponential moving average of the IoU (diagnostic only)
        if global_idx == 0:
            average_loss = loss_obj.detach()
            global_idx += 1
        else:
            average_loss = loss_obj.detach() * 0.05 + 0.95 * average_loss
        print(loss_obj.item(), " ", score_preds.item(), " ", score_loss.item())
        loss += score_loss

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 3. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
        average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)

    # Write evaluation results
    average_meter.write_result('Training' if training else 'Validation', epoch)
    avg_loss = utils.mean(average_meter.loss_buf)
    miou, fb_iou = average_meter.compute_iou()

    # Scatter the predicted scores against the measured IoUs (diagnostic plot)
    import matplotlib.pyplot as plt
    plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.1)
    plt.savefig('stat.png')
    plt.close()

    return avg_loss, miou, fb_iou


if __name__ == '__main__':

    # Arguments parsing
    args = parse_opts()

    # DDP backend initialization
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    # Model initialization
    model = DCAMA(args.backbone, args.feature_extractor_path, False)
    device = torch.device("cuda", args.local_rank)
    model.to(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank,
                                                find_unused_parameters=True)

    # Load a pretrained checkpoint, dropping the scorer head and remapping
    # the remaining keys positionally onto the wrapped model
    params = model.state_dict()
    state_dict = torch.load(args.load)
    state_dict = {k: v for k, v in state_dict.items() if "scorer" not in k}
    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
        state_dict[k2] = state_dict.pop(k1)

    model.load_state_dict(state_dict, strict=False)

    ## TODO:
    # Re-initialize the second linear layer of each DCAMA block to a constant
    for i in range(len(model.module.model.DCAMA_blocks)):
        torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].weight, 0.)
        torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].bias, 1.)

    # Helper classes (for training) initialization
    optimizer = optim.SGD([{"params": model.module.model.parameters(), "lr": args.lr,
                            "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
    Evaluator.initialize()
    if args.local_rank == 0:
        Logger.initialize(args, training=True)
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
    dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', args.nshot)
    if args.local_rank == 0:
        dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', args.nshot)

    # Train
    best_val_miou = float('-inf')
    best_val_loss = float('inf')
    for epoch in range(args.nepoch):
        dataloader_trn.sampler.set_epoch(epoch)
        trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)

        # Evaluation (validation is currently disabled; every epoch is saved)
        if args.local_rank == 0:
            # with torch.no_grad():
            #     val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)

            # Save the best model
            # if val_miou > best_val_miou:
            #     best_val_miou = val_miou
            #     Logger.save_model_miou(model, epoch, val_miou)
            Logger.save_model_miou(model, epoch, 1.)

        # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
        # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
        # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
        # Logger.tbd_writer.flush()

    if args.local_rank == 0:
        Logger.tbd_writer.close()
        Logger.info('==================== Finished Training ====================')
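
train.py tracks a running average of the per-episode IoU with the update average_loss = 0.05 * new + 0.95 * old. A standalone sketch of that exponential moving average, with illustrative values only:

    # EMA with alpha = 0.05: each new point contributes 5%, so the
    # effective averaging window is roughly 1 / alpha = 20 batches.
    def ema(values, alpha=0.05):
        avg = None
        for v in values:
            avg = v if avg is None else alpha * v + (1 - alpha) * avg
        return avg

    print(ema([0.2, 0.4, 0.6]))  # 0.2295: new points move the average slowly
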
train_1gpu.py
ADDED
@@ -0,0 +1,170 @@
r""" training (validation) code """
import torch.optim as optim
import torch.nn as nn
import torch

from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset  # FSDataset4SAM
# from transformers import SamProcessor
from PIL import Image
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
import pickle
import pycocotools.coco as COCO
import cv2


def train(epoch, model, dataloader, optimizer, training, shot=1):
    r""" Train """

    # Force randomness during training / freeze randomness during testing
    utils.fix_randseed(None) if training else utils.fix_randseed(0)

    if hasattr(model, "module"):
        model.module.train_mode() if training else model.module.eval()
    else:
        model.train_mode() if training else model.eval()  # fixed: was model.module.eval(), which fails for unwrapped models
    average_meter = AverageMeter(dataloader.dataset)
    average_loss = torch.tensor(0.).float().cuda()
    stats = [[], []]
    criterion_score = nn.BCEWithLogitsLoss()
    for idx, batch in enumerate(dataloader):

        # batch = process_batch4SAM(batch)
        shot = batch['support_imgs'].size(1)

        # 1. Forward pass
        batch = utils.to_cuda(batch)
        logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'], nshot=shot)
        pred_mask = logit_mask.argmax(dim=1)

        # 2. Compute loss & update model parameters
        loss = model.compute_objective(logit_mask, batch['query_mask'])
        # loss_obj = loss.detach()

        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        iou = (area_inter[1] / area_union[1]).float()

        # Supervise the scorer head only on confident episodes: IoU above 0.7
        # becomes a positive (1) target, below 0.1 a negative (0) target.
        # Note: score_loss is computed and logged here but not added to loss.
        if iou > 0.7 or iou < 0.1:
            '''
            if iou < 0.1:
                img = batch['query_img'][0].permute(1, 2, 0).detach().cpu().numpy()
                img = img - img.min()
                img = img / img.max()
                cv2.imwrite('query_image.png', (img * 255).astype(np.uint8))
                img = batch['support_imgs'][0][0].permute(1, 2, 0).detach().cpu().numpy()
                img = img - img.min()
                img = img / img.max()
                cv2.imwrite('support_image.png', (img * 255).astype(np.uint8))
                cv2.imwrite('query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
                cv2.imwrite('pred_mask.png', (pred_mask[0] * 255).detach().cpu().numpy().astype(np.uint8))
                cv2.imwrite('support_mask.png', (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
            '''
            if iou > 0.7:
                iou = torch.tensor(1.).float().cuda()
            else:
                iou = torch.tensor(0.).float().cuda()
            score_loss = criterion_score(score_preds, iou)
            stats[0].append(score_preds.detach().cpu().numpy())
            stats[1].append((area_inter[1] / area_union[1]).detach().cpu().numpy())
            print(score_preds, (area_inter[1] / area_union[1]))

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 3. Evaluate prediction

        # img = batch['support_imgs'][0][0].permute(1, 2, 0)
        # img = img - img.min()
        # img /= img.max()
        # import cv2
        # cv2.imwrite("debug.png", (img * 255).detach().cpu().numpy())
        # cv2.imwrite("debug2.png", (batch['support_masks'][0][0] * 255).detach().cpu().numpy())
        # import ipdb; ipdb.set_trace()

        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
        average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)

    # Write evaluation results
    average_meter.write_result('Training' if training else 'Validation', epoch)
    avg_loss = utils.mean(average_meter.loss_buf)
    miou, fb_iou = average_meter.compute_iou()

    # Scatter the predicted scores against the measured IoUs (diagnostic plot)
    import matplotlib.pyplot as plt
    plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.02)
    plt.savefig("stats.png")
    return avg_loss, miou, fb_iou


if __name__ == '__main__':

    # Arguments parsing
    args = parse_opts()

    # Model initialization (single GPU, no DDP wrapper)
    model = DCAMA(args.backbone, args.feature_extractor_path, False)
    device = torch.device("cuda", args.local_rank)
    model.to(device)

    # Load a pretrained checkpoint, dropping the scorer head and remapping
    # the remaining keys positionally
    params = model.state_dict()
    state_dict = torch.load(args.load)
    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']
    state_dict = {k: v for k, v in state_dict.items() if "scorer" not in k}
    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
        state_dict[k2] = state_dict.pop(k1)
    model.load_state_dict(state_dict, strict=False)

    ## TODO:
    # Re-initialize the second linear layer of each DCAMA block to a constant
    for i in range(len(model.model.DCAMA_blocks)):
        torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].weight, 0.)
        torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].bias, 1.)

    # Helper classes (for training) initialization
    optimizer = optim.SGD([{"params": model.parameters(), "lr": args.lr,
                            "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
    Evaluator.initialize()
    if args.local_rank == 0:
        Logger.initialize(args, training=True)
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
    dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', shot=args.nshot)
    if args.local_rank == 0:
        dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', shot=args.nshot)

    # Train
    best_val_miou = float('-inf')
    best_val_loss = float('inf')

    for epoch in range(args.nepoch):
        trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True, shot=args.nshot)

        # Evaluation (validation is currently disabled; every epoch is saved)
        if args.local_rank == 0:
            # with torch.no_grad():
            #     val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)

            # Save the best model
            # if val_miou > best_val_miou:
            #     best_val_miou = val_miou
            Logger.save_model_miou(model, epoch, 1.)

        # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
        # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
        # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
        # Logger.tbd_writer.flush()

    if args.local_rank == 0:
        Logger.tbd_writer.close()
        Logger.info('==================== Finished Training ====================')
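
The scorer supervision in train_1gpu.py keeps only confident episodes and collapses their IoU to hard 0/1 targets for BCEWithLogitsLoss. A self-contained sketch of that labeling rule, with the 0.7/0.1 thresholds copied from the loop above; binarize_iou is a hypothetical helper, not part of this commit:

    import torch

    # Turn a scalar IoU into a binary retrieval target; ambiguous episodes
    # (0.1 <= IoU <= 0.7) yield no scorer supervision at all.
    def binarize_iou(iou, hi=0.7, lo=0.1):
        if iou > hi:
            return torch.tensor(1.)
        if iou < lo:
            return torch.tensor(0.)
        return None

    criterion = torch.nn.BCEWithLogitsLoss()
    target = binarize_iou(torch.tensor(0.85))
    if target is not None:
        loss = criterion(torch.tensor(2.0), target)  # logit 2.0 -> p ~ 0.88
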
train_1gpu_retriever.py
ADDED
@@ -0,0 +1,172 @@
r""" training (validation) code """
import torch.optim as optim
import torch.nn as nn
import torch

from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset  # FSDataset4SAM
# from transformers import SamProcessor
from PIL import Image
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
import pickle
import pycocotools.coco as COCO
import cv2
import torchvision


def train(epoch, model, dataloader, optimizer, training, shot=1):
    r""" Train """

    # Force randomness during training / freeze randomness during testing
    utils.fix_randseed(None) if training else utils.fix_randseed(0)

    if hasattr(model, "module"):
        model.module.train_mode() if training else model.module.eval()
    else:
        model.train_mode() if training else model.eval()  # fixed: was model.module.eval(), which fails for unwrapped models
    average_meter = AverageMeter(dataloader.dataset)
    average_loss = torch.tensor(0.).float().cuda()
    stats = [[], []]
    criterion_score = nn.BCEWithLogitsLoss()
    for idx, batch in enumerate(dataloader):

        # batch = process_batch4SAM(batch)
        shot = batch['support_imgs'].size(1)

        # 1. Forward pass (with the retriever score head enabled)
        batch = utils.to_cuda(batch)
        logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'], nshot=shot, predict_score=True)
        pred_mask = logit_mask.argmax(dim=1)

        # 2. Compute loss & update model parameters
        loss = model.compute_objective(logit_mask, batch['query_mask'])
        # loss_obj = loss.detach()

        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        iou = (area_inter[1] / area_union[1]).float()

        # Train only the retriever head, on confident episodes: IoU above 0.7
        # becomes a positive (1) target, below 0.05 a negative (0) target, and
        # the segmentation loss is replaced by the focal score loss.
        if iou > 0.7 or iou < 0.05:
            '''
            if iou < 0.1:
                img = batch['query_img'][0].permute(1, 2, 0).detach().cpu().numpy()
                img = img - img.min()
                img = img / img.max()
                cv2.imwrite('query_image.png', (img * 255).astype(np.uint8))
                img = batch['support_imgs'][0][0].permute(1, 2, 0).detach().cpu().numpy()
                img = img - img.min()
                img = img / img.max()
                cv2.imwrite('support_image.png', (img * 255).astype(np.uint8))
                cv2.imwrite('query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
                cv2.imwrite('pred_mask.png', (pred_mask[0] * 255).detach().cpu().numpy().astype(np.uint8))
                cv2.imwrite('support_mask.png', (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
            '''
            if iou > 0.7:
                iou = torch.tensor(1.).float().cuda()
            else:
                iou = torch.tensor(0.).float().cuda()
            score_loss = torchvision.ops.sigmoid_focal_loss(score_preds, iou)
            # score_loss = F.l1_loss(score_preds, iou)
            stats[0].append(score_preds.detach().cpu().numpy())
            stats[1].append((area_inter[1] / area_union[1]).detach().cpu().numpy())
            print(score_preds, (area_inter[1] / area_union[1]))
            loss = score_loss

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 3. Evaluate prediction

        # img = batch['support_imgs'][0][0].permute(1, 2, 0)
        # img = img - img.min()
        # img /= img.max()
        # import cv2
        # cv2.imwrite("debug.png", (img * 255).detach().cpu().numpy())
        # cv2.imwrite("debug2.png", (batch['support_masks'][0][0] * 255).detach().cpu().numpy())
        # import ipdb; ipdb.set_trace()

        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
        average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)

    # Write evaluation results
    average_meter.write_result('Training' if training else 'Validation', epoch)
    avg_loss = utils.mean(average_meter.loss_buf)
    miou, fb_iou = average_meter.compute_iou()

    # Scatter the predicted scores against the measured IoUs (diagnostic plot)
    import matplotlib.pyplot as plt
    plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.02)
    plt.savefig("stats.png")
    return avg_loss, miou, fb_iou


if __name__ == '__main__':

    # Arguments parsing
    args = parse_opts()

    # Model initialization (single GPU, no DDP wrapper)
    model = DCAMA(args.backbone, args.feature_extractor_path, False)
    device = torch.device("cuda", args.local_rank)
    model.to(device)

    # Load a pretrained checkpoint, dropping the scorer head and remapping
    # the remaining keys positionally
    params = model.state_dict()
    state_dict = torch.load(args.load)
    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']
    state_dict = {k: v for k, v in state_dict.items() if "scorer" not in k}
    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
        state_dict[k2] = state_dict.pop(k1)
    model.load_state_dict(state_dict, strict=False)

    ## TODO:
    # for i in range(len(model.model.DCAMA_blocks)):
    #     torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].weight, 0.)
    #     torch.nn.init.constant_(model.model.DCAMA_blocks[i].linears[1].bias, 1.)

    # Helper classes (for training) initialization
    optimizer = optim.SGD([{"params": model.parameters(), "lr": args.lr,
                            "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
    Evaluator.initialize()
    if args.local_rank == 0:
        Logger.initialize(args, training=True)
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
    dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', shot=args.nshot)
    if args.local_rank == 0:
        dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', shot=args.nshot)

    # Train
    best_val_miou = float('-inf')
    best_val_loss = float('inf')

    for epoch in range(args.nepoch):
        trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True, shot=args.nshot)

        # Evaluation (validation is currently disabled; every epoch is saved)
        if args.local_rank == 0:
            # with torch.no_grad():
            #     val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)

            # Save the best model
            # if val_miou > best_val_miou:
            #     best_val_miou = val_miou
            Logger.save_model_miou(model, epoch, 1.)

        # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
        # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
        # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
        # Logger.tbd_writer.flush()

    if args.local_rank == 0:
        Logger.tbd_writer.close()
        Logger.info('==================== Finished Training ====================')
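
A caution on the focal loss call above: torchvision.ops.sigmoid_focal_loss defaults to reduction='none', which only yields a usable scalar here because score_preds and the target are zero-dimensional. A minimal sketch with the reduction stated explicitly; the tensor values are illustrative only:

    import torch
    from torchvision.ops import sigmoid_focal_loss

    logit = torch.tensor([0.3], requires_grad=True)   # raw scorer output
    target = torch.tensor([1.0])                      # binarized IoU label
    loss = sigmoid_focal_loss(logit, target, alpha=0.25, gamma=2.0, reduction='mean')
    loss.backward()  # gradients flow to the scorer logit
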
train_retriever.py
ADDED
@@ -0,0 +1,164 @@
r""" training (validation) code """
import torch.optim as optim
import torch.nn as nn
import torch

from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset
import torch.nn.functional as F

average_loss = torch.tensor(0.).float().cuda()
global_idx = 0


def train(epoch, model, dataloader, optimizer, training):
    r""" Train """

    # Force randomness during training / freeze randomness during testing
    utils.fix_randseed(None) if training else utils.fix_randseed(0)
    model.module.train_mode() if training else model.module.eval()
    average_meter = AverageMeter(dataloader.dataset)

    global average_loss, global_idx
    average_loss = average_loss.to("cuda:{}".format(torch.cuda.current_device()))
    stats = [[], []]
    criterion_score = nn.BCEWithLogitsLoss()
    for idx, batch in enumerate(dataloader):

        # 1. Forward pass
        batch = utils.to_cuda(batch)
        logit_mask, score_preds = model(batch['query_img'], batch['support_imgs'], batch['support_masks'],
                                        nshot=batch['support_imgs'].size(1))
        pred_mask = logit_mask.argmax(dim=1)

        # 2. Compute loss & update model parameters
        loss = model.module.compute_objective(logit_mask, batch['query_mask'])
        # loss_obj = loss.detach()
        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        iou = area_inter[1] / area_union[1]

        # Train only the retriever head, on confident episodes: IoU above 0.7
        # becomes a positive (1) target, below 0.1 a negative (0) target, and
        # the segmentation loss is replaced by the BCE score loss.
        if iou > 0.7 or iou < 0.1:
            '''
            if iou < 0.1:
                img = batch['query_img'][0].permute(1, 2, 0).detach().cpu().numpy()
                img = img - img.min()
                img = img / img.max()
                cv2.imwrite('query_image.png', (img * 255).astype(np.uint8))
                img = batch['support_imgs'][0][0].permute(1, 2, 0).detach().cpu().numpy()
                img = img - img.min()
                img = img / img.max()
                cv2.imwrite('support_image.png', (img * 255).astype(np.uint8))
                cv2.imwrite('query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
                cv2.imwrite('pred_mask.png', (pred_mask[0] * 255).detach().cpu().numpy().astype(np.uint8))
                cv2.imwrite('support_mask.png', (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
            '''
            if iou > 0.7:
                iou = torch.tensor(1.).float().cuda()
            else:
                iou = torch.tensor(0.).float().cuda()
            score_loss = criterion_score(score_preds, iou)
            stats[0].append(score_preds.detach().cpu().numpy())
            stats[1].append((area_inter[1] / area_union[1]).detach().cpu().numpy())
            print(score_preds, (area_inter[1] / area_union[1]))
            loss = score_loss

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 3. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
        average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)

    # Write evaluation results
    average_meter.write_result('Training' if training else 'Validation', epoch)
    avg_loss = utils.mean(average_meter.loss_buf)
    miou, fb_iou = average_meter.compute_iou()

    # Scatter the predicted scores against the measured IoUs (diagnostic plot)
    import matplotlib.pyplot as plt
    plt.scatter(stats[0], stats[1], c="red", s=2, alpha=0.1)
    plt.savefig('stat.png')
    plt.close()

    return avg_loss, miou, fb_iou


if __name__ == '__main__':

    # Arguments parsing
    args = parse_opts()

    # DDP backend initialization
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    # Model initialization
    model = DCAMA(args.backbone, args.feature_extractor_path, False)
    device = torch.device("cuda", args.local_rank)
    model.to(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank,
                                                find_unused_parameters=True)

    # Load a pretrained checkpoint, dropping the scorer head and remapping
    # the remaining keys positionally onto the wrapped model
    params = model.state_dict()
    state_dict = torch.load(args.load)
    state_dict = {k: v for k, v in state_dict.items() if "scorer" not in k}
    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
        state_dict[k2] = state_dict.pop(k1)

    model.load_state_dict(state_dict, strict=False)

    ## TODO:
    # Re-initialize the second linear layer of each DCAMA block to a constant
    for i in range(len(model.module.model.DCAMA_blocks)):
        torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].weight, 0.)
        torch.nn.init.constant_(model.module.model.DCAMA_blocks[i].linears[1].bias, 1.)

    # Helper classes (for training) initialization
    optimizer = optim.SGD([{"params": model.module.model.parameters(), "lr": args.lr,
                            "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
    Evaluator.initialize()
    if args.local_rank == 0:
        Logger.initialize(args, training=True)
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
    dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', args.nshot)
    if args.local_rank == 0:
        dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', args.nshot)

    # Train
    best_val_miou = float('-inf')
    best_val_loss = float('inf')
    for epoch in range(args.nepoch):
        dataloader_trn.sampler.set_epoch(epoch)
        trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)

        # Evaluation (validation is currently disabled; every epoch is saved)
        if args.local_rank == 0:
            # with torch.no_grad():
            #     val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)

            # Save the best model
            # if val_miou > best_val_miou:
            #     best_val_miou = val_miou
            #     Logger.save_model_miou(model, epoch, val_miou)
            Logger.save_model_miou(model, epoch, 1.)

        # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
        # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
        # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
        # Logger.tbd_writer.flush()

    if args.local_rank == 0:
        Logger.tbd_writer.close()
        Logger.info('==================== Finished Training ====================')
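
Both DDP scripts call dataloader_trn.sampler.set_epoch(epoch) before each epoch; this reseeds the DistributedSampler so every epoch draws a fresh shuffle across ranks. A minimal single-process sketch of the pattern, with num_replicas/rank passed explicitly so no process group is needed (the tiny dataset is illustrative only):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8).float())
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # without this, every epoch reuses the same order
        for (batch,) in loader:
            pass  # training step would go here
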