| r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """ | |
| from functools import reduce | |
| from operator import add | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.models import resnet | |
| from .base.swin_transformer import SwinTransformer | |
| from model.base.transformer import MultiHeadedAttention, PositionalEncoding | |
| import copy | |
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), x.size(1), -1).contiguous()
def reshape(x, size):
    r""" Map the foreground (nonzero) coordinates of mask `x` onto a grid of shape `size`. """
    in_size = torch.tensor(x.size()).float()
    yxs = torch.stack(torch.where(x), dim=-1)  # (N, 2) foreground (y, x) coordinates
    ratio = size[0] / in_size[0]  # masks and feature maps are square here, so one ratio suffices
    return (yxs * ratio).long()
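
# A toy sketch of what `reshape` returns (illustrative only, not from the
# original code): a 4x4 mask with a single foreground pixel at (2, 2), mapped
# onto an 8x8 grid, yields the scaled coordinate (4, 4).
#
#   >>> m = torch.zeros(4, 4)
#   >>> m[2, 2] = 1
#   >>> reshape(m, (8, 8))
#   tensor([[4, 4]])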

class DCAMA(nn.Module):

    def __init__(self, backbone, pretrained_path, use_original_imgsize):
        super(DCAMA, self).__init__()

        self.backbone = backbone
        self.use_original_imgsize = use_original_imgsize

        # feature extractor initialization
        if backbone == 'resnet50':
            self.feature_extractor = resnet.resnet50()
            self.feature_extractor.load_state_dict(torch.load(pretrained_path))
            self.feat_channels = [256, 512, 1024, 2048]
            self.nlayers = [3, 4, 6, 3]
            self.feat_ids = list(range(0, 17))
            self.last_feat_size = [12, 12]
        elif backbone == 'resnet101':
            self.feature_extractor = resnet.resnet101()
            self.feature_extractor.load_state_dict(torch.load(pretrained_path))
            self.feat_channels = [256, 512, 1024, 2048]
            self.nlayers = [3, 4, 23, 3]
            self.feat_ids = list(range(0, 34))
        elif backbone == 'swin':
            self.feature_extractor = SwinTransformer(img_size=384, patch_size=4, window_size=12, embed_dim=128,
                                                     depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32])
            self.feature_extractor.load_state_dict(torch.load(pretrained_path)['model'])
            self.feat_channels = [128, 256, 512, 1024]
            self.nlayers = [2, 2, 18, 2]
        else:
            raise Exception('Unavailable backbone: %s' % backbone)
        self.feature_extractor.eval()

        # define model
        # lids[i] is the backbone stage (1-4) that produces the i-th feature map;
        # stack_ids holds the cumulative number of feature maps per stage
        self.lids = reduce(add, [[i + 1] * x for i, x in enumerate(self.nlayers)])
        self.stack_ids = torch.tensor(self.lids).bincount()[-4:].cumsum(dim=0)
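        # For concreteness (the values follow from self.nlayers above): resnet50
        # gives stack_ids = [3, 7, 13, 16], resnet101 gives [3, 7, 30, 33], and
        # swin gives [2, 4, 22, 24].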
        self.model = DCAMA_model(in_channels=self.feat_channels, stack_ids=self.stack_ids)

        ## TODO: per-layer 1x1 projection heads for the support-scoring branch
        # NOTE: the channel widths assume the resnet feat_channels [256, 512, 1024, 2048]
        self.scorer2 = nn.ModuleList()
        for layer_idx in range(len(self.nlayers)):
            layer_num = self.nlayers[layer_idx]
            for idx in range(layer_num):
                self.scorer2.append(
                    nn.Sequential(
                        nn.Conv2d(256 * 2 ** layer_idx, 256 * 2 ** layer_idx, 1, 1),
                        # nn.ReLU(),
                        # nn.InstanceNorm2d(256 * 2 ** layer_idx),
                        # nn.Conv2d(256 * 2 ** layer_idx, 256 * 2 ** layer_idx, 1, 1),
                    )
                )
        self.scorer1 = nn.Sequential(
            nn.Linear(sum(self.nlayers) - self.nlayers[0], 1)
        )

        self.cross_entropy_loss = nn.CrossEntropyLoss()
    def forward(self, query_img, support_img, support_mask, nshot, predict_score=False):
        n_support_feats = []
        with torch.no_grad():
            for k in range(nshot):
                n_support_feats.append(self.extract_feats(support_img[:, k]))
            query_feats = self.extract_feats(query_img)

        logit_mask = self.model(query_feats, n_support_feats, support_mask.clone(), nshot=nshot)

        ## TODO: support-shot scoring branch
        MAX_SHOTS = 1
        if len(n_support_feats) >= MAX_SHOTS:
            nshot = MAX_SHOTS
            n_simi = []
            for i in range(len(n_support_feats)):
                support_f = n_support_feats[i]
                simi_l = []
                simi_sum = []
                for l in range(len(query_feats)):
                    # pick the attention block / positional encoding of the matching stage
                    if l < self.stack_ids[0]:
                        continue
                    elif l < self.stack_ids[1]:
                        DCAMA_blocks = self.model.DCAMA_blocks[0]
                        pe = self.model.pe[0]
                    elif l < self.stack_ids[2]:
                        DCAMA_blocks = self.model.DCAMA_blocks[1]
                        pe = self.model.pe[1]
                    else:
                        DCAMA_blocks = self.model.DCAMA_blocks[2]
                        pe = self.model.pe[2]

                    # project the support feature, then average it over the mask foreground
                    a_support_f = support_f[l].clone()
                    coords = reshape(support_mask[0, i], a_support_f.size()[-2:])
                    b, ch, h, w = a_support_f.size()
                    a_support_f = a_support_f.view(b, ch, -1)
                    a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
                    a_support_f = a_support_f.view(b, ch, h, w)
                    a_support_f = self.scorer2[l](a_support_f)
                    a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
                    # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.

                    # project the query feature the same way
                    query_feat = query_feats[l].view(b, ch, -1)
                    query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
                    query_feat = query_feat.view(b, ch, h, w)
                    query_feat = self.scorer2[l](query_feat)

                    # per-pixel cosine similarity to the mask-averaged support prototype
                    simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
                    simi_sum.append(simi)
                    # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
                    if l == 6:
                        simi_map = simi.clone()
                    simi = simi.view(-1).mean()
                    simi_l.append(simi)
                # simi_l = self.scorer1(torch.stack(simi_l, dim=0).unsqueeze(0)).squeeze(0)[0]
                n_simi.append(torch.stack(simi_l, dim=0).mean())
            n_simi = torch.stack(n_simi, dim=0)
            args = n_simi.argsort(descending=True)[:MAX_SHOTS]
            support_mask = support_mask[:, args, :, :]
            # n_support_feats = [n_support_feats[arg] for arg in args]
            n_simis = n_simi[args].max()
        else:
            n_simis = torch.tensor(0.).float().to(query_img.device)

        return logit_mask, n_simis
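
    # Expected shapes for `forward` (a sketch; H = W = 384 with the default
    # backbones): query_img [bsz, 3, H, W]; support_img [bsz, nshot, 3, H, W];
    # support_mask [bsz, nshot, H, W] binary. Returns the [bsz, 2, H, W]
    # segmentation logits and a scalar support-query similarity score.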
    def extract_feats(self, img):
        r""" Extract input image features """
        feats = []

        if self.backbone == 'swin':
            _ = self.feature_extractor.forward_features(img)
            for feat in self.feature_extractor.feat_maps:
                bsz, hw, c = feat.size()
                h = int(hw ** 0.5)
                feat = feat.view(bsz, h, h, c).permute(0, 3, 1, 2).contiguous()
                feats.append(feat)
        elif self.backbone == 'resnet50' or self.backbone == 'resnet101':
            bottleneck_ids = reduce(add, list(map(lambda x: list(range(x)), self.nlayers)))

            # Layer 0
            feat = self.feature_extractor.conv1.forward(img)
            feat = self.feature_extractor.bn1.forward(feat)
            feat = self.feature_extractor.relu.forward(feat)
            feat = self.feature_extractor.maxpool.forward(feat)

            # Layer 1-4
            for hid, (bid, lid) in enumerate(zip(bottleneck_ids, self.lids)):
                block = getattr(self.feature_extractor, 'layer%d' % lid)[bid]
                res = feat
                feat = block.conv1.forward(feat)
                feat = block.bn1.forward(feat)
                feat = block.relu.forward(feat)
                feat = block.conv2.forward(feat)
                feat = block.bn2.forward(feat)
                feat = block.relu.forward(feat)
                feat = block.conv3.forward(feat)
                feat = block.bn3.forward(feat)

                if bid == 0:
                    res = block.downsample.forward(res)

                feat += res
                if hid + 1 in self.feat_ids:
                    feats.append(feat.clone())
                feat = block.relu.forward(feat)

        return feats
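
    # For reference (assuming a 384x384 input; the sizes follow from the backbone
    # strides): the swin backbone yields 24 maps across spatial scales
    # [96x96, 48x48, 24x24, 12x12] with channels [128, 256, 512, 1024], while
    # resnet50 yields 16 maps across the same scales with channels
    # [256, 512, 1024, 2048].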
    def predict_mask_nshot(self, batch, nshot):
        r""" n-shot inference """
        query_img = batch['query_img']
        support_imgs = batch['support_imgs']
        support_masks = batch['support_masks']

        if nshot == 1:
            with torch.no_grad():
                query_feats = self.extract_feats(query_img)
                n_support_feats = []
                for k in range(nshot):
                    support_feats = self.extract_feats(support_imgs[:, k])
                    n_support_feats.append(support_feats)

                # score the (single) support shot against the query
                n_simis = []
                simi_map = None
                for i in range(len(n_support_feats)):
                    support_f = n_support_feats[i]
                    simi_l = []
                    for l in range(len(query_feats)):
                        if l < self.stack_ids[0]:
                            continue
                        elif l < self.stack_ids[1]:
                            DCAMA_blocks = self.model.DCAMA_blocks[0]
                            pe = self.model.pe[0]
                        elif l < self.stack_ids[2]:
                            DCAMA_blocks = self.model.DCAMA_blocks[1]
                            pe = self.model.pe[1]
                        else:
                            DCAMA_blocks = self.model.DCAMA_blocks[2]
                            pe = self.model.pe[2]

                        a_support_f = support_f[l].clone()
                        coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
                        b, ch, h, w = a_support_f.size()
                        a_support_f = a_support_f.view(b, ch, -1)
                        a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
                        a_support_f = a_support_f.view(b, ch, h, w)
                        a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
                        # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.

                        query_feat = query_feats[l].view(b, ch, -1)
                        query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
                        query_feat = query_feat.view(b, ch, h, w)

                        # per-pixel cosine similarity to the mask-averaged support prototype
                        simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
                        # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
                        if l == 13:
                            simi_map = simi.clone()
                        simi = simi.view(-1).max()
                        simi_l.append(simi)
                    simi_l = torch.stack(simi_l, dim=0).mean()
                    n_simis.append(simi_l)
                n_simis = torch.stack(n_simis, dim=0)
                logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)
        else:
            with torch.no_grad():
                query_feats = self.extract_feats(query_img)
                n_support_feats = []
                for k in range(nshot):
                    support_feats = self.extract_feats(support_imgs[:, k])
                    n_support_feats.append(support_feats)

                ## TODO: retrieval V1 (kept for reference) ##
                MAX_SHOTS = 200
                '''
                if len(n_support_feats) > MAX_SHOTS:
                    nshot = MAX_SHOTS
                    n_support_query_f = []
                    n_simis = []
                    for i in range(len(n_support_feats)):
                        support_f = n_support_feats[i]
                        support_query_f = []
                        simi_l = []
                        simi_sum = []
                        for l in range(len(query_feats)):
                            if l < self.stack_ids[0]:
                                continue
                            elif l < self.stack_ids[1]:
                                DCAMA_blocks = self.model.DCAMA_blocks[0]
                                pe = self.model.pe[0]
                            elif l < self.stack_ids[2]:
                                DCAMA_blocks = self.model.DCAMA_blocks[1]
                                pe = self.model.pe[1]
                            else:
                                DCAMA_blocks = self.model.DCAMA_blocks[2]
                                pe = self.model.pe[2]
                            a_support_f = support_f[l].clone()
                            coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
                            b, ch, w, h = a_support_f.size()
                            a_support_f = a_support_f.view(b, ch, -1)
                            a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
                            a_support_f = a_support_f.view(b, ch, w, h)
                            a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1).unsqueeze(-1).unsqueeze(-1).repeat((1, 1, a_support_f.size(-2), a_support_f.size(-1)))
                            # a_support_f[:, :, coords_reverse[:, 0], coords_reverse[:, 1]] *= 0.
                            query_feat = query_feats[l].view(b, ch, -1)
                            query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
                            query_feat = query_feat.view(b, ch, w, h)
                            simi = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(a_support_f, dim=1))[0]
                            simi_sum.append(simi)
                            # simi = torch.norm(query_feats[l] - a_support_f, dim=1)[0]
                            if l == 6:
                                simi_map = simi.clone().detach().cpu().numpy()
                            simi = simi.view(-1).max()
                            simi_l.append(simi)
                        simi_l = torch.stack(simi_l, dim=0).mean()
                        n_simis.append(simi_l)
                    n_simis = torch.stack(n_simis, dim=0)
                    # nshot = max((n_simis > 0.).sum(), 1)
                    nshot = len(n_simis)
                    support_masks = support_masks[:, n_simis.argsort(descending=True)[:nshot], :, :]
                    n_support_feats = [n_support_feats[i] for i in n_simis.argsort(descending=True)[:nshot]]
                else:
                    n_simis = torch.tensor(0.).float().cuda()
                    simi_map = None
                ## TODO: retriever V2
                '''
                '''
                MAX_SHOTS = 30
                if len(n_support_feats) > MAX_SHOTS:
                    nshot = MAX_SHOTS
                    n_support_query_f = []
                    n_simis = []
                    support_f_list = []
                    n_support_feats2 = []
                    query_feats2 = []
                    for i in range(len(n_support_feats)):
                        support_f = n_support_feats[i]
                        n_support_feats2_l = []
                        query_feats2_l = []
                        for l in range(len(query_feats)):
                            if l < self.stack_ids[0]:
                                continue
                            elif l < self.stack_ids[1]:
                                DCAMA_blocks = self.model.DCAMA_blocks[0]
                                pe = self.model.pe[0]
                            elif l < self.stack_ids[2]:
                                DCAMA_blocks = self.model.DCAMA_blocks[1]
                                pe = self.model.pe[1]
                            else:
                                DCAMA_blocks = self.model.DCAMA_blocks[2]
                                pe = self.model.pe[2]
                            a_support_f = support_f[l].clone()
                            coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
                            b, ch, w, h = a_support_f.size()
                            a_support_f = a_support_f.view(b, ch, -1)
                            a_support_f = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
                            a_support_f = a_support_f.view(b, ch, w, h)
                            a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1)
                            n_support_feats2_l.append(a_support_f)
                            query_feat = query_feats[l].view(b, ch, -1)
                            query_feat = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
                            query_feat = query_feat.view(b, ch, w, h)
                            query_feats2_l.append(query_feat)
                        n_support_feats2.append(n_support_feats2_l)
                        query_feats2.append(query_feats2_l)
                    n_support_feats3 = [[] for _ in range(len(query_feats2[0]))]
                    selected = []
                    for i in range(MAX_SHOTS):
                        simi_min = -100
                        idx_min = -1
                        for idx in range(len(n_support_feats2)):
                            if idx in selected:
                                continue
                            support_feats2 = n_support_feats2[idx]
                            simi = []
                            for l in range(len(query_feats2[i])):
                                support_feats_avg = torch.stack(n_support_feats3[l] + [support_feats2[l]], dim=0).mean(0)
                                query_feat = query_feats2[i][l]
                                a_support_f = support_feats_avg.unsqueeze(-1).unsqueeze(-1).repeat(
                                    (1, 1, query_feat.size(-2), query_feat.size(-1)))
                                simi_l = ((query_feat * a_support_f).sum(1) / torch.norm(query_feat, dim=1) / torch.norm(
                                    a_support_f, dim=1))[0].view(-1).max()
                                simi.append(simi_l)
                            simi = torch.stack(simi, dim=0).mean()
                            if simi > simi_min:
                                simi_min = simi
                                idx_min = idx
                                support_feats2_argmin = n_support_feats2[idx]
                        for l2 in range(len(query_feats2[0])):
                            n_support_feats3[l2].append(n_support_feats2[idx_min][l2])
                        selected.append(idx_min)
                    n_support_feats4 = []
                    for idx in selected:
                        n_support_feats4.append(n_support_feats[idx])
                    support_masks = support_masks[:, torch.tensor(selected).long().cuda(), :, :]
                    n_support_feats = n_support_feats4
                    simi_map = None
                else:
                    n_simis = torch.tensor(0.).float().cuda()
                    simi_map = None
                '''
                ## TODO: v3 -- greedy support-shot selection
                MAX_SHOTS = 200
                if len(n_support_feats) > MAX_SHOTS:
                    nshot = MAX_SHOTS
                    n_support_feats2 = []
                    query_feats2 = []
                    for i in range(len(n_support_feats)):
                        support_f = n_support_feats[i]
                        n_support_feats2_l = []
                        query_feats2_l = []
                        for l in range(len(query_feats)):
                            if l < self.stack_ids[0]:
                                continue
                            elif l < self.stack_ids[1]:
                                DCAMA_blocks = self.model.DCAMA_blocks[0]
                                pe = self.model.pe[0]
                            elif l < self.stack_ids[2]:
                                DCAMA_blocks = self.model.DCAMA_blocks[1]
                                pe = self.model.pe[1]
                            else:
                                DCAMA_blocks = self.model.DCAMA_blocks[2]
                                pe = self.model.pe[2]

                            # mask-averaged support prototype from the projected features
                            a_support_f = support_f[l].clone()
                            coords = reshape(support_masks[0, i], a_support_f.size()[-2:])
                            b, ch, h, w = a_support_f.size()
                            a_support_f = a_support_f.view(b, ch, -1)
                            a_support_f_tmp = DCAMA_blocks.linears[0](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
                            a_support_f = a_support_f_tmp / a_support_f_tmp.norm(dim=1, keepdim=True) * DCAMA_blocks.linears[1](pe(a_support_f.permute(0, 2, 1))).permute(0, 2, 1)
                            a_support_f = a_support_f.view(b, ch, h, w)
                            a_support_f = a_support_f[:, :, coords[:, 0], coords[:, 1]].mean(-1)
                            n_support_feats2_l.append(a_support_f)

                            query_feat = query_feats[l].view(b, ch, -1)
                            query_feat_tmp = DCAMA_blocks.linears[0](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
                            query_feat = query_feat_tmp / query_feat_tmp.norm(dim=1, keepdim=True) * DCAMA_blocks.linears[1](pe(query_feat.permute(0, 2, 1))).permute(0, 2, 1)
                            query_feat = query_feat.view(b, ch, h, w)
                            query_feats2_l.append(query_feat)
                        n_support_feats2.append(n_support_feats2_l)
                        query_feats2.append(query_feats2_l)

                    # greedily add the shot that maximizes the running-average similarity
                    n_support_feats3 = [[] for _ in range(len(query_feats2[0]))]
                    selected = []
                    for i in range(MAX_SHOTS):
                        simi_best = -100
                        idx_best = -1
                        for idx in range(len(n_support_feats2)):
                            if idx in selected:
                                continue
                            support_feats2 = n_support_feats2[idx]
                            simi = []
                            for l in range(len(query_feats2[0])):  # all entries of query_feats2 are identical
                                support_feats_avg = torch.stack(n_support_feats3[l] + [support_feats2[l]], dim=0).mean(0)
                                query_feat = query_feats2[0][l]
                                a_support_f = support_feats_avg.unsqueeze(-1).unsqueeze(-1).repeat(
                                    (1, 1, query_feat.size(-2), query_feat.size(-1)))
                                simi_l = ((query_feat * a_support_f).sum(1))[0].view(-1).max()
                                simi.append(simi_l)
                            simi = torch.stack(simi, dim=0).mean()
                            if simi > simi_best:
                                simi_best = simi
                                idx_best = idx
                        for l2 in range(len(query_feats2[0])):
                            n_support_feats3[l2].append(n_support_feats2[idx_best][l2])
                        selected.append(idx_best)

                    n_support_feats = [n_support_feats[idx] for idx in selected]
                    support_masks = support_masks[:, torch.tensor(selected).long().to(support_masks.device), :, :]
                    n_simis = torch.tensor(0.).float().to(query_img.device)
                    simi_map = None
                else:
                    n_simis = torch.tensor(0.).float().to(query_img.device)
                    simi_map = None
                logit_mask = self.model(query_feats, n_support_feats, support_masks.clone(), nshot)

        if self.use_original_imgsize:
            org_qry_imsize = tuple([batch['org_query_imsize'][1].item(), batch['org_query_imsize'][0].item()])
            logit_mask = F.interpolate(logit_mask, org_qry_imsize, mode='bilinear', align_corners=True)
        else:
            logit_mask = F.interpolate(logit_mask, support_imgs[0].size()[2:], mode='bilinear', align_corners=True)

        return logit_mask.argmax(dim=1), n_simis, simi_map
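
    # The v3 retrieval above is a greedy forward-selection: at each step it adds
    # the support shot whose prototype, averaged with the already-selected ones,
    # best matches the query. A minimal standalone sketch of the same idea
    # (hypothetical helper; plain vectors instead of feature pyramids):
    #
    #   def greedy_select(query, protos, k):
    #       selected = []
    #       for _ in range(k):
    #           best = max((i for i in range(len(protos)) if i not in selected),
    #                      key=lambda i: torch.stack([protos[j] for j in selected + [i]]).mean(0) @ query)
    #           selected.append(best)
    #       return selected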
    def compute_objective(self, logit_mask, gt_mask):
        bsz = logit_mask.size(0)
        logit_mask = logit_mask.view(bsz, 2, -1)
        gt_mask = gt_mask.view(bsz, -1).long()

        return self.cross_entropy_loss(logit_mask, gt_mask)
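
    # A minimal sketch of the loss computation (dummy tensors, assumed shapes):
    #
    #   logits = torch.randn(2, 2, 96, 96)                 # [bsz, 2, H, W]
    #   gt = torch.randint(0, 2, (2, 96, 96)).float()      # [bsz, H, W] binary mask
    #   loss = model.compute_objective(logits, gt)         # pixel-wise cross-entropy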
    def train_mode(self):
        self.train()
        self.feature_extractor.eval()

class DCAMA_model(nn.Module):
    def __init__(self, in_channels, stack_ids):
        super(DCAMA_model, self).__init__()

        self.stack_ids = stack_ids

        # DCAMA blocks: one attention block + positional encoding per stage 2-4
        self.DCAMA_blocks = nn.ModuleList()
        self.pe = nn.ModuleList()
        for inch in in_channels[1:]:
            self.DCAMA_blocks.append(MultiHeadedAttention(h=8, d_model=inch, dropout=0.5))
            self.pe.append(PositionalEncoding(d_model=inch, dropout=0.5))
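        # e.g. with resnet feat_channels [256, 512, 1024, 2048] this builds three
        # attention blocks with d_model 512, 1024, and 2048; the first (1/4-scale)
        # stage is skipped since it is only used in the skip connections below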
        outch1, outch2, outch3 = 16, 64, 128

        # conv blocks
        self.conv1 = self.build_conv_block(stack_ids[3]-stack_ids[2], [outch1, outch2, outch3], [3, 3, 3], [1, 1, 1])  # 1/32
        self.conv2 = self.build_conv_block(stack_ids[2]-stack_ids[1], [outch1, outch2, outch3], [5, 3, 3], [1, 1, 1])  # 1/16
        self.conv3 = self.build_conv_block(stack_ids[1]-stack_ids[0], [outch1, outch2, outch3], [5, 5, 3], [1, 1, 1])  # 1/8

        self.conv4 = self.build_conv_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])  # 1/32 + 1/16
        self.conv5 = self.build_conv_block(outch3, [outch3, outch3, outch3], [3, 3, 3], [1, 1, 1])  # 1/16 + 1/8

        # mixer blocks
        self.mixer1 = nn.Sequential(nn.Conv2d(outch3+2*in_channels[1]+2*in_channels[0], outch3, (3, 3), padding=(1, 1), bias=True),
                                    nn.ReLU(),
                                    nn.Conv2d(outch3, outch2, (3, 3), padding=(1, 1), bias=True),
                                    nn.ReLU())

        self.mixer2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
                                    nn.ReLU(),
                                    nn.Conv2d(outch2, outch1, (3, 3), padding=(1, 1), bias=True),
                                    nn.ReLU())

        self.mixer3 = nn.Sequential(nn.Conv2d(outch1, outch1, (3, 3), padding=(1, 1), bias=True),
                                    nn.ReLU(),
                                    nn.Conv2d(outch1, 2, (3, 3), padding=(1, 1), bias=True))
    def forward(self, query_feats, support_feats, support_mask, nshot=1):
        coarse_masks = []
        for idx, query_feat in enumerate(query_feats):
            # 1/4-scale features are only used in the skip connection
            if idx < self.stack_ids[0]:
                continue

            bsz, ch, ha, wa = query_feat.size()

            # reshape the input feature and mask
            query = query_feat.view(bsz, ch, -1).permute(0, 2, 1).contiguous()
            # if nshot == 1:
            #     support_feat = support_feats[idx]
            #     mask = F.interpolate(support_mask.unsqueeze(1).float(), support_feat.size()[2:], mode='bilinear',
            #                          align_corners=True).view(support_feat.size()[0], -1)
            #     support_feat = support_feat.view(support_feat.size()[0], support_feat.size()[1], -1).permute(0, 2, 1).contiguous()
            # else:
            support_feat = torch.stack([support_feats[k][idx] for k in range(nshot)])
            support_feat = support_feat.view(-1, ch, ha * wa).permute(0, 2, 1).contiguous()
            mask = torch.stack([F.interpolate(k.unsqueeze(1).float(), (ha, wa), mode='bilinear', align_corners=True)
                                for k in support_mask])
            mask = mask.view(bsz, -1)

            # DCAMA blocks forward
            if idx < self.stack_ids[1]:
                DCAMA_blocks = self.DCAMA_blocks[0]
                pe = self.pe[0]
            elif idx < self.stack_ids[2]:
                DCAMA_blocks = self.DCAMA_blocks[1]
                pe = self.pe[1]
            else:
                DCAMA_blocks = self.DCAMA_blocks[2]
                pe = self.pe[2]
            coarse_mask = DCAMA_blocks(pe(query), pe(support_feat), mask)
            coarse_masks.append(coarse_mask.permute(0, 2, 1).contiguous().view(bsz, 1, ha, wa))

        # multi-scale conv blocks forward
        bsz, ch, ha, wa = coarse_masks[self.stack_ids[3]-1-self.stack_ids[0]].size()
        coarse_masks1 = torch.stack(coarse_masks[self.stack_ids[2]-self.stack_ids[0]:self.stack_ids[3]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
        bsz, ch, ha, wa = coarse_masks[self.stack_ids[2]-1-self.stack_ids[0]].size()
        coarse_masks2 = torch.stack(coarse_masks[self.stack_ids[1]-self.stack_ids[0]:self.stack_ids[2]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)
        bsz, ch, ha, wa = coarse_masks[self.stack_ids[1]-1-self.stack_ids[0]].size()
        coarse_masks3 = torch.stack(coarse_masks[0:self.stack_ids[1]-self.stack_ids[0]]).transpose(0, 1).contiguous().view(bsz, -1, ha, wa)

        coarse_masks1 = self.conv1(coarse_masks1)
        coarse_masks2 = self.conv2(coarse_masks2)
        coarse_masks3 = self.conv3(coarse_masks3)

        # multi-scale cascade (pixel-wise addition)
        coarse_masks1 = F.interpolate(coarse_masks1, coarse_masks2.size()[-2:], mode='bilinear', align_corners=True)
        mix = coarse_masks1 + coarse_masks2
        mix = self.conv4(mix)

        mix = F.interpolate(mix, coarse_masks3.size()[-2:], mode='bilinear', align_corners=True)
        mix = mix + coarse_masks3
        mix = self.conv5(mix)

        # skip connect 1/8 and 1/4 features (concatenation)
        # if nshot == 1:
        #     support_feat = support_feats[self.stack_ids[1] - 1]
        # else:
        support_feat = torch.stack([support_feats[k][self.stack_ids[1] - 1] for k in range(nshot)]).max(dim=0).values
        mix = torch.cat((mix, query_feats[self.stack_ids[1] - 1], support_feat), 1)

        upsample_size = (mix.size(-1) * 2,) * 2
        mix = F.interpolate(mix, upsample_size, mode='bilinear', align_corners=True)
        # if nshot == 1:
        #     support_feat = support_feats[self.stack_ids[0] - 1]
        # else:
        support_feat = torch.stack([support_feats[k][self.stack_ids[0] - 1] for k in range(nshot)]).max(dim=0).values
        mix = torch.cat((mix, query_feats[self.stack_ids[0] - 1], support_feat), 1)

        # mixer blocks forward
        out = self.mixer1(mix)
        upsample_size = (out.size(-1) * 2,) * 2
        out = F.interpolate(out, upsample_size, mode='bilinear', align_corners=True)
        out = self.mixer2(out)

        upsample_size = (out.size(-1) * 2,) * 2
        out = F.interpolate(out, upsample_size, mode='bilinear', align_corners=True)
        logit_mask = self.mixer3(out)

        return logit_mask
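
    # Rough shape trace for a 384x384 resnet50 input (for orientation; the sizes
    # follow from the backbone strides): coarse_masks1 [bsz, 3, 12, 12],
    # coarse_masks2 [bsz, 6, 24, 24], coarse_masks3 [bsz, 4, 48, 48]; after the
    # cascade, the two skip concatenations, and three 2x upsamplings the logits
    # come out at [bsz, 2, 384, 384].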
    def build_conv_block(self, in_channel, out_channels, kernel_sizes, spt_strides, group=4):
        r""" Build a conv block: (Conv2d -> GroupNorm -> ReLU) x len(out_channels) """
        assert len(out_channels) == len(kernel_sizes) == len(spt_strides)

        building_block_layers = []
        for idx, (outch, ksz, stride) in enumerate(zip(out_channels, kernel_sizes, spt_strides)):
            inch = in_channel if idx == 0 else out_channels[idx - 1]
            pad = ksz // 2

            building_block_layers.append(nn.Conv2d(in_channels=inch, out_channels=outch,
                                                   kernel_size=ksz, stride=stride, padding=pad))
            building_block_layers.append(nn.GroupNorm(group, outch))
            building_block_layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*building_block_layers)
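

# A minimal smoke-test sketch (not part of the original training/eval entry
# points). The checkpoint path and input sizes are assumptions: DCAMA expects
# ImageNet-pretrained backbone weights at `pretrained_path`.
if __name__ == '__main__':
    model = DCAMA(backbone='resnet50', pretrained_path='resnet50.pth', use_original_imgsize=False)
    model.eval()
    query = torch.randn(1, 3, 384, 384)                    # [bsz, 3, H, W]
    supports = torch.randn(1, 5, 3, 384, 384)              # [bsz, nshot, 3, H, W]
    masks = torch.randint(0, 2, (1, 5, 384, 384)).float()  # binary support masks
    with torch.no_grad():
        logits, score = model(query, supports, masks, nshot=5)
    print(logits.shape, score.item())                      # logits: [1, 2, 384, 384]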