Multiple-shots / importance_analysis.py
r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """
import torch.nn as nn
import torch
from model.DCAMA import DCAMA
from common.logger import Logger, AverageMeter
from common.vis import Visualizer
from common.evaluation import Evaluator
from common.config import parse_opts
from common import utils
from data.dataset import FSSDataset
import cv2
import numpy as np
import os
def test(model, dataloader, nshot):
    r""" Test """

    # Freeze randomness during testing for reproducibility
    utils.fix_randseed(0)
    average_meter = AverageMeter(dataloader.dataset)

    for idx, batch in enumerate(dataloader):

        # 1. forward pass
        nshot = batch['support_imgs'].size(1)
        # Importance analysis: overwrite the first support image/mask with the query
        # itself, so the query also serves as one of the supports
        batch['support_imgs'][0][0] = batch['query_img'][0]
        batch['support_masks'][0][0] = batch['query_mask'][0]
        batch = utils.to_cuda(batch)
        pred_mask, simi, simi_map = model.module.predict_mask_nshot(batch, nshot=nshot)

        assert pred_mask.size() == batch['query_mask'].size()

        # 2. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)
        # Foreground IoU of the current episode (also used by the debug dumps below)
        iou = area_inter[1] / area_union[1]
        # Debug dumps (disabled): save query/support images, the similarity map,
        # and append "simi iou" pairs to debug/stats.txt
        '''
        cv2.imwrite('debug/query.png', cv2.imread("/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}".format(batch['query_name'][0])))
        cv2.imwrite('debug/query_mask.png', (batch['query_mask'][0] * 255).detach().cpu().numpy().astype(np.uint8))
        cv2.imwrite('debug/support_{:.3}.png'.format(iou.item()), cv2.imread('/home/bkdongxianchi/MY_MOT/TWL/data/COCO2014/{}'.format(batch['support_names'][0][0])))
        cv2.imwrite('debug/support_mask_{:.3}.png'.format(iou.item()), (batch['support_masks'][0][0] * 255).detach().cpu().numpy().astype(np.uint8))
        simi_map = simi_map - simi_map.min()
        simi_map = (simi_map / simi_map.max() * 255).detach().cpu().numpy().astype(np.uint8)
        cv2.imwrite('debug/simi_map_{:.3}.png'.format(iou.item()), simi_map)
        if os.path.exists('debug/stats.txt'):
            with open('debug/stats.txt', "a") as f:
                f.write("{} {}\n".format(simi.item(), iou.item()))
        else:
            with open('debug/stats.txt', 'w') as f:
                f.write('{} {}\n'.format(simi.item(), iou.item()))
        '''
        average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
        average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)

        # Visualize predictions
        if Visualizer.visualize:
            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
                                                  batch['query_img'], batch['query_mask'],
                                                  pred_mask, batch['class_id'], idx,
                                                  iou_b=area_inter[1].float() / area_union[1].float())

    # Write evaluation results
    average_meter.write_result('Test', 0)
    miou, fb_iou = average_meter.compute_iou()

    return miou, fb_iou
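

# The (disabled) debug block inside test() appends "simi iou" pairs to debug/stats.txt.
# Below is a minimal, hypothetical sketch of how those pairs could be analysed afterwards,
# e.g. to check how well the predicted similarity score tracks the achieved foreground IoU.
# This helper is not part of the original pipeline; the path and file format are assumptions
# taken from the debug block above.
def analyze_stats(path='debug/stats.txt'):
    simis, ious = [], []
    with open(path) as f:
        for line in f:
            simi, iou = line.split()
            simis.append(float(simi))
            ious.append(float(iou))
    simis, ious = np.array(simis), np.array(ious)
    # Pearson correlation between similarity score and foreground IoU
    return np.corrcoef(simis, ious)[0, 1]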


if __name__ == '__main__':

    # Arguments parsing
    args = parse_opts()
    Logger.initialize(args, training=False)

    # Model initialization
    model = DCAMA(args.backbone, args.feature_extractor_path, args.use_original_imgsize)
    model.eval()

    # Device setup
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())
    model = nn.DataParallel(model)
    model.to(device)

    # Load trained model
    if args.load == '':
        raise Exception('Pretrained model not specified.')
    params = model.state_dict()
    state_dict = torch.load(args.load)
    if 'state_dict' in state_dict.keys():
        state_dict = state_dict['state_dict']

    # Drop any 'scorer' parameters from the checkpoint before loading
    state_dict2 = {}
    for k, v in state_dict.items():
        if 'scorer' not in k:
            state_dict2[k] = v
    state_dict = state_dict2

    # Rename the remaining checkpoint keys positionally so they match the
    # (DataParallel-wrapped) model's parameter names
    for k1, k2 in zip(list(state_dict.keys()), params.keys()):
        state_dict[k2] = state_dict.pop(k1)

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Fill in any parameters missing from the checkpoint with the model's
        # current values, then retry a strict load
        for k in params.keys():
            if k not in state_dict.keys():
                state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)

    # Helper classes (for testing) initialization
    Evaluator.initialize()
    Visualizer.initialize(args.visualize, args.vispath)

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
    dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)

    # Test
    with torch.no_grad():
        test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
    Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
    Logger.info('==================== Finished Testing ====================')
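

# Example invocation (a sketch only; the actual flag names are defined by parse_opts in
# common/config.py and are assumed here to match the attribute names used above):
#
#   python importance_analysis.py --datapath <dataset root> --benchmark coco \
#       --backbone <backbone> --feature_extractor_path <backbone weights> \
#       --fold 0 --bsz 1 --nworker 0 --nshot 5 --load <trained checkpoint>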