Spaces:
Sleeping
Sleeping
| r""" training (validation) code """ | |
| import torch.optim as optim | |
| import torch.nn as nn | |
| import torch | |
| from model.DCAMA import DCAMA | |
| from common.logger import Logger, AverageMeter | |
| from common.evaluation import Evaluator | |
| from common.config import parse_opts | |
| from common import utils | |
| from data.dataset import FSSDataset | |
| import torch.nn.functional as F | |
# Module-level state shared with train().
# The accumulator is created on the CPU on purpose: this line runs at import
# time, i.e. before the script entry point calls torch.cuda.set_device(), so a
# `.cuda()` here would open a CUDA context on the default device in every
# process (and make the module impossible to import on a CPU-only host).
# train() moves the tensor to the current CUDA device before using it.
average_loss = torch.tensor(0.).float()
# Global step counter; not updated anywhere in the visible part of this file.
global_idx = 0
def train(epoch, model, dataloader, optimizer, training):
    r""" Run one pass over ``dataloader`` and return ``(avg_loss, miou, fb_iou)``.

    For every batch the model produces segmentation logits and a score
    prediction. The IoU is taken from index 1 of the areas returned by
    ``Evaluator.classify_prediction`` (presumably the foreground class).
    When that IoU is above 0.7 or below 0.1, the score prediction is fitted
    with BCE-with-logits against a hard 1 / 0 target and this loss replaces
    the segmentation objective for the step. The (score prediction, IoU)
    pairs collected that way are drawn as a scatter plot into ``stat.png``
    once the pass is over.

    Args:
        epoch: epoch index, forwarded to the progress / result writers.
        model: wrapped model; ``model.module`` must provide ``train_mode``,
            ``eval`` and ``compute_objective``.
        dataloader: iterable of batches; ``dataloader.dataset`` feeds the meter.
        optimizer: optimizer stepped once per batch when ``training`` is true.
        training: ``True`` to update parameters, ``False`` to only evaluate.

    Returns:
        Tuple ``(avg_loss, miou, fb_iou)`` as computed by ``AverageMeter``.

    NOTE(review): the gate below compares a tensor in an ``if``; that only
    works when the IoU holds a single element - TODO confirm batch size 1.
    """
    global average_loss, global_idx

    # Force randomness during training / freeze randomness during testing
    utils.fix_randseed(None if training else 0)
    if training:
        model.module.train_mode()
    else:
        model.module.eval()

    meter = AverageMeter(dataloader.dataset)
    # Keep the module-level accumulator on this process' current CUDA device.
    average_loss = average_loss.to("cuda:{}".format(torch.cuda.current_device()))

    score_log, iou_log = [], []  # x / y values of the final scatter plot
    bce_logits = nn.BCEWithLogitsLoss()

    for step, batch in enumerate(dataloader):
        # 1. forward pass
        batch = utils.to_cuda(batch)
        support_imgs = batch['support_imgs']
        logit_mask, score_preds = model(batch['query_img'], support_imgs, batch['support_masks'],
                                        nshot=support_imgs.size(1))
        pred_mask = logit_mask.argmax(dim=1)

        # 2. Compute loss & update model parameters
        loss = model.module.compute_objective(logit_mask, batch['query_mask'])
        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        raw_iou = area_inter[1] / area_union[1]

        # Only clearly good / clearly bad predictions supervise the score.
        # (A disabled debug dump of the images and masks used to live here; it
        # relied on cv2 / numpy, which this file does not import.)
        if raw_iou > 0.7 or raw_iou < 0.1:
            target = torch.tensor(1. if raw_iou > 0.7 else 0.).float().cuda()
            loss = bce_logits(score_preds, target)
            score_log.append(score_preds.detach().cpu().numpy())
            iou_log.append(raw_iou.detach().cpu().numpy())
            print(score_preds, raw_iou)

        # NOTE(review): the copy this was recovered from had lost its
        # indentation. The optimiser step is kept at loop level (one step per
        # batch on every rank) - confirm it was not meant to sit inside the
        # IoU gate above.
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 3. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
        meter.write_process(step, len(dataloader), epoch, write_batch_idx=50)

    # Write evaluation results
    meter.write_result('Training' if training else 'Validation', epoch)
    avg_loss = utils.mean(meter.loss_buf)
    miou, fb_iou = meter.compute_iou()

    # Score prediction vs. measured IoU for the gated samples of this pass.
    import matplotlib.pyplot as plt
    plt.scatter(score_log, iou_log, c="red", s=2, alpha=0.1)
    plt.savefig('stat.png')
    plt.close()

    return avg_loss, miou, fb_iou
# Script entry point: one process per GPU (NCCL backend). Builds the DCAMA
# model, restores a checkpoint without its "scorer" entries, re-initialises
# the second linear layer of every DCAMA block, then trains for args.nepoch
# epochs. Rank 0 owns logging and checkpoint saving.
if __name__ == '__main__':
    # Arguments parsing
    args = parse_opts()

    # ddp backend initialization
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    # Model initialization
    model = DCAMA(args.backbone, args.feature_extractor_path, False)
    device = torch.device("cuda", args.local_rank)
    model.to(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank,
                                                find_unused_parameters=True)

    # Checkpoint restore.
    # map_location pins the loaded tensors to this rank's GPU; without it each
    # process deserialises them onto the device they were saved from.
    params = model.state_dict()
    checkpoint = torch.load(args.load, map_location=device)
    # Entries whose key mentions "scorer" are dropped, so those parameters
    # keep their fresh initialisation.
    checkpoint = {k: v for k, v in checkpoint.items() if "scorer" not in k}
    # Rename checkpoint entries to the model's key names by position. A new
    # dict is built instead of renaming in place, so a new name can never
    # clobber an entry that has not been renamed yet.
    # NOTE(review): positional pairing assumes the remaining checkpoint entries
    # are in the same order as the leading entries of model.state_dict() -
    # verify against the model definition.
    state_dict = dict(zip(params.keys(), checkpoint.values()))
    model.load_state_dict(state_dict, strict=False)

    # TODO: (left open by the author) second linear layer of each block is
    # reset to weight 0 / bias 1.
    for block in model.module.model.DCAMA_blocks:
        torch.nn.init.constant_(block.linears[1].weight, 0.)
        torch.nn.init.constant_(block.linears[1].bias, 1.)

    # Helper classes (for training) initialization
    optimizer = optim.SGD([{"params": model.module.model.parameters(), "lr": args.lr,
                            "momentum": 0.9, "weight_decay": args.lr/10, "nesterov": True}])
    Evaluator.initialize()
    if args.local_rank == 0:
        Logger.initialize(args, training=True)
        Logger.info('# available GPUs: %d' % torch.cuda.device_count())

    # Dataset initialization
    FSSDataset.initialize(img_size=384, datapath=args.datapath, use_original_imgsize=False)
    dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn', args.nshot)
    if args.local_rank == 0:
        dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val', args.nshot)

    # Train
    best_val_miou = float('-inf')
    best_val_loss = float('inf')
    for epoch in range(args.nepoch):
        # Passes the epoch index to the training sampler before each pass.
        dataloader_trn.sampler.set_epoch(epoch)
        trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)

        # evaluation (currently disabled: a checkpoint is saved every epoch
        # with a placeholder score of 1.)
        if args.local_rank == 0:
            # with torch.no_grad():
            #     val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)

            # Save the best model
            # if val_miou > best_val_miou:
            #     best_val_miou = val_miou
            #     Logger.save_model_miou(model, epoch, val_miou)
            Logger.save_model_miou(model, epoch, 1.)

            # Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
            # Logger.tbd_writer.add_scalars('data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
            # Logger.tbd_writer.add_scalars('data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
            # Logger.tbd_writer.flush()

    if args.local_rank == 0:
        Logger.tbd_writer.close()
        Logger.info('==================== Finished Training ====================')