Spaces:
Sleeping
Sleeping
| r""" Logging during training/testing """ | |
| import datetime | |
| import logging | |
| import os | |
| from tensorboardX import SummaryWriter | |
| import torch | |
class AverageMeter:
    r""" Stores loss, evaluation results

    Accumulates per-class intersection / union areas and per-batch losses, and
    reports mIoU and FB-IoU (both in percent) through ``Logger.info``.
    """

    # Class count per supported benchmark; sizes the accumulation buffers.
    _NCLASS_BY_BENCHMARK = {'pascal': 20, 'coco': 80, 'fss': 1000}

    def __init__(self, dataset, device=None):
        r""" Allocates zeroed accumulators for ``dataset``.

        Args:
            dataset: object exposing ``benchmark`` (one of 'pascal', 'coco',
                'fss') and ``class_ids`` (the class indices whose IoU is reported).
            device: device holding the buffers. ``None`` (default) selects CUDA
                when available and falls back to CPU otherwise.

        Raises:
            ValueError: if ``dataset.benchmark`` is not a supported benchmark.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.benchmark = dataset.benchmark
        if self.benchmark not in self._NCLASS_BY_BENCHMARK:
            # Fail here with a clear message instead of an AttributeError on
            # ``self.nclass`` a few lines later.
            raise ValueError('Unknown benchmark %r; expected one of %s'
                             % (self.benchmark, sorted(self._NCLASS_BY_BENCHMARK)))
        self.nclass = self._NCLASS_BY_BENCHMARK[self.benchmark]
        self.class_ids_interest = torch.tensor(dataset.class_ids).to(self.device)

        # Two rows per class: compute_iou() reports row 1 as mIoU and averages
        # both rows for FB-IoU (presumably row 0 = background, row 1 = foreground).
        self.intersection_buf = torch.zeros([2, self.nclass], dtype=torch.float32, device=self.device)
        self.union_buf = torch.zeros([2, self.nclass], dtype=torch.float32, device=self.device)
        self.ones = torch.ones_like(self.union_buf)
        self.loss_buf = []

    def update(self, inter_b, union_b, class_id, loss):
        r""" Adds one batch of areas and its loss to the running totals.

        Args:
            inter_b: intersection areas of shape ``(2, B)``; column ``i`` is
                added to class ``class_id[i]`` (``index_add_`` along dim 1).
            union_b: union areas with the same layout as ``inter_b``.
            class_id: 1-D integer index tensor of length ``B``.
            loss: scalar loss (tensor or number), or ``None`` to record 0.
        """
        index = class_id.to(self.device)
        self.intersection_buf.index_add_(1, index, inter_b.to(self.device, torch.float32))
        self.union_buf.index_add_(1, index, union_b.to(self.device, torch.float32))
        if loss is None:
            loss = torch.zeros((), device=self.device)
        # Detach so stored losses never keep an autograd graph alive, and keep
        # every entry on one device so torch.stack() in the writers cannot fail.
        self.loss_buf.append(torch.as_tensor(loss, device=self.device).detach())

    def compute_iou(self):
        r""" Returns ``(miou, fb_iou)`` in percent as 0-dim tensors.

        mIoU averages row 1 of the per-class IoU over ``class_ids_interest``.
        FB-IoU pools the areas of those classes per row and averages the two
        rows. Union terms are clamped to at least 1, so entries with no
        recorded area contribute 0 instead of NaN.
        """
        ids = self.class_ids_interest
        iou = self.intersection_buf / torch.max(self.union_buf, self.ones)
        miou = iou.index_select(1, ids)[1].mean() * 100

        inter_sum = self.intersection_buf.index_select(1, ids).sum(dim=1)
        union_sum = self.union_buf.index_select(1, ids).sum(dim=1)
        fb_iou = (inter_sum / union_sum.clamp(min=1)).mean() * 100
        return miou, fb_iou

    def _stacked_losses(self):
        r""" Returns the recorded losses as a 1-D tensor (a single 0 if none yet). """
        if not self.loss_buf:
            return torch.zeros(1, device=self.device)
        return torch.stack(self.loss_buf)

    def write_result(self, split, epoch):
        r""" Logs the summary line for ``split``: average loss, mIoU and FB-IoU. """
        iou, fb_iou = self.compute_iou()
        loss_buf = self._stacked_losses()
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        msg += 'Avg L: %6.5f ' % float(loss_buf.mean())
        msg += 'mIoU: %5.2f ' % float(iou)
        msg += 'FB-IoU: %5.2f ' % float(fb_iou)
        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch, write_batch_idx=20):
        r""" Logs running metrics every ``write_batch_idx`` batches.

        ``epoch == -1`` omits the epoch tag and the loss terms from the line.
        """
        if batch_idx % write_batch_idx != 0:
            return
        msg = '[Epoch: %02d] ' % epoch if epoch != -1 else ''
        msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
        iou, fb_iou = self.compute_iou()
        if epoch != -1:
            loss_buf = self._stacked_losses()
            msg += 'L: %6.5f ' % float(loss_buf[-1])
            msg += 'Avg L: %6.5f ' % float(loss_buf.mean())
        msg += 'mIoU: %5.2f | ' % float(iou)
        msg += 'FB-IoU: %5.2f' % float(fb_iou)
        Logger.info(msg)
class Logger:
    r""" Writes evaluation results of training/testing

    All state lives on the class itself and every public method is a
    classmethod, so the class is used directly (``Logger.initialize(...)``,
    ``Logger.info(...)``) rather than instantiated.
    """

    @classmethod
    def initialize(cls, args, training):
        r""" Creates the run directory, configures file + console logging and a
        TensorBoard writer, then logs every entry of ``args``.

        Args:
            args: namespace providing ``logpath`` and ``benchmark``, plus
                ``fold`` when training or ``load`` (checkpoint path) when testing.
            training: selects the ``train/`` vs ``test/`` directory naming.
        """
        logtime = datetime.datetime.now().strftime('_%m%d_%H%M%S')
        cls.logpath = cls._build_logpath(args, training, logtime)
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath, exist_ok=True)

        # NOTE: basicConfig is a no-op if the root logger already has handlers,
        # so only the first initialize() call in a process sets up log.txt.
        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(logging.Formatter('%(message)s'))
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        logging.info('\n:==================== Start =====================')
        for arg_key, arg_val in vars(args).items():
            logging.info('| %20s: %-24s' % (arg_key, str(arg_val)))
        logging.info(':================================================\n')

    @staticmethod
    def _build_logpath(args, training, logtime):
        r""" Returns ``<logpath>/train/fold_<fold><logtime>`` when training, else
        ``<logpath>/test/fold_<name><logtime>``. Here ``<name>`` is the name of
        the directory containing ``args.load``, cut at its first '.' (empty when
        ``args.load`` has no directory part).
        """
        if training:
            return os.path.join(args.logpath, 'train/fold_' + str(args.fold) + logtime)
        ckpt_dir = os.path.basename(os.path.dirname(args.load))
        return os.path.join(args.logpath, 'test/fold_' + ckpt_dir.split('.')[0] + logtime)

    @classmethod
    def info(cls, msg):
        r""" Writes a log message through the root logger (log.txt and console
        once ``initialize`` has run). """
        logging.info(msg)

    @classmethod
    def save_model_miou(cls, model, epoch, val_miou):
        r""" Saves ``model.state_dict()`` to ``<logpath>/model_<epoch>.pt`` and
        logs the validation mIoU it was saved with. """
        torch.save(model.state_dict(), os.path.join(cls.logpath, "model_{}.pt".format(epoch)))
        cls.info('Model saved @%d w/ val. mIoU: %5.2f.\n' % (epoch, val_miou))