Spaces:
Sleeping
Sleeping
| r""" Dense Cross-Query-and-Support Attention Weighted Mask Aggregation for Few-Shot Segmentation """ | |
| import torch.nn as nn | |
| import torch | |
| from model.DCAMA import DCAMA | |
| from common.logger import Logger, AverageMeter | |
| from common.vis import Visualizer | |
| from common.evaluation import Evaluator | |
| from common.config import parse_opts | |
| from common import utils | |
| from data.dataset import FSSDataset | |
| import cv2 | |
| import numpy as np | |
| import os | |
def test(model, dataloader, nshot):
    r""" Evaluate ``model`` on every batch of ``dataloader`` and return the IoU metrics.

    Args:
        model: Network exposing ``predict_mask_nshot(batch, nshot=...)``, either directly
            or through a wrapper's ``.module`` attribute (e.g. ``nn.DataParallel``).
        dataloader: Iterable of batch dicts; its ``dataset`` attribute is handed to
            ``AverageMeter``.
        nshot: Kept for interface compatibility only. The shot count that is actually
            used is read from each batch via ``batch['support_imgs'].size(1)``.

    Returns:
        The ``(miou, fb_iou)`` pair produced by ``AverageMeter.compute_iou()``.
    """
    # Freeze randomness during testing for reproducibility
    utils.fix_randseed(0)
    average_meter = AverageMeter(dataloader.dataset)

    # Work with both a wrapped model (which exposes `.module`) and a bare one.
    predictor = getattr(model, 'module', model)

    for idx, batch in enumerate(dataloader):
        # 1. forward pass
        # The number of shots is taken from the batch itself, not from the argument.
        batch_nshot = batch['support_imgs'].size(1)

        # NOTE(review): the first support image/mask of the first batch element is
        # overwritten with that element's own query image/mask, so the query
        # ground truth is visible to the model through the support set. This looks
        # like a deliberate probing experiment rather than a standard evaluation
        # protocol -- confirm before reporting numbers produced by this script.
        batch['support_imgs'][0][0] = batch['query_img'][0]
        batch['support_masks'][0][0] = batch['query_mask'][0]

        batch = utils.to_cuda(batch)
        # The two extra outputs are not used by the evaluation below.
        pred_mask, _simi, _simi_map = predictor.predict_mask_nshot(batch, nshot=batch_nshot)
        assert pred_mask.size() == batch['query_mask'].size()

        # 2. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask.clone(), batch)

        average_meter.update(area_inter, area_union, batch['class_id'], loss=None)
        average_meter.write_process(idx, len(dataloader), epoch=-1, write_batch_idx=1)

        # Visualize predictions
        if Visualizer.visualize:
            Visualizer.visualize_prediction_batch(batch['support_imgs'], batch['support_masks'],
                                                  batch['query_img'], batch['query_mask'],
                                                  pred_mask, batch['class_id'], idx,
                                                  iou_b=area_inter[1].float() / area_union[1].float())

    # Write evaluation results
    average_meter.write_result('Test', 0)
    miou, fb_iou = average_meter.compute_iou()

    return miou, fb_iou
if __name__ == '__main__':
    # Script entry point: build the model, load a checkpoint, run the test loop
    # and log the resulting mIoU / FB-IoU.

    # Arguments parsing
    args = parse_opts()
    Logger.initialize(args, training=False)

    # Model initialization
    model = DCAMA(args.backbone, args.feature_extractor_path, args.use_original_imgsize)
    model.eval()

    # Device setup
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())
    model = nn.DataParallel(model)
    model.to(device)

    # Load trained model
    if not args.load:
        raise Exception('Pretrained model not specified.')
    params = model.state_dict()
    # map_location keeps loading functional when the checkpoint was saved on a
    # device that is not available on this machine.
    checkpoint = torch.load(args.load, map_location=device)
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    # Entries whose name contains 'scorer' are not loaded.
    checkpoint = {k: v for k, v in checkpoint.items() if 'scorer' not in k}

    # Rename checkpoint entries to the model's parameter names purely by position
    # (i-th checkpoint entry -> i-th model key). A fresh dict is built so that a
    # target name which happens to equal a later source name cannot clobber an
    # entry that has not been renamed yet.
    # NOTE(review): positional matching assumes both key sequences are in the
    # same order -- confirm for the checkpoints used with this script.
    ckpt_keys = list(checkpoint)
    n_paired = min(len(ckpt_keys), len(params))
    state_dict = {new: checkpoint[old] for old, new in zip(ckpt_keys, params)}
    # Checkpoint entries beyond the model's key count keep their original names.
    for old in ckpt_keys[n_paired:]:
        state_dict.setdefault(old, checkpoint[old])

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Strict loading failed: fill every key the checkpoint lacks with the
        # model's current value and retry. A second failure propagates to the caller.
        missing = [k for k in params if k not in state_dict]
        Logger.info('Checkpoint lacks %d model key(s); keeping the model\'s current values for them.'
                    % len(missing))
        for k in missing:
            state_dict[k] = params[k]
        model.load_state_dict(state_dict, strict=True)

    # Helper classes (for testing) initialization
    Evaluator.initialize()
    Visualizer.initialize(args.visualize, args.vispath)

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=args.use_original_imgsize)
    dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)

    # Test
    with torch.no_grad():
        test_miou, test_fb_iou = test(model, dataloader_test, args.nshot)
    Logger.info('Fold %d mIoU: %5.2f \t FB-IoU: %5.2f' % (args.fold, test_miou.item(), test_fb_iou.item()))
    Logger.info('==================== Finished Testing ====================')