| """ | |
| MIT License | |
| Copyright (c) 2019 Ildoo Kim | |
| Permission is hereby granted, free of charge, to any person obtaining a copy | |
| of this software and associated documentation files (the "Software"), to deal | |
| in the Software without restriction, including without limitation the rights | |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| copies of the Software, and to permit persons to whom the Software is | |
| furnished to do so, subject to the following conditions: | |
| The above copyright notice and this permission notice shall be included in all | |
| copies or substantial portions of the Software. | |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| SOFTWARE. | |
| """ | |
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the wrapped optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0;
            if multiplier = 1.0, the lr starts from 0 and ends at base_lr.
        total_epoch: the target learning rate is reached gradually over total_epoch epochs.
        after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            # Warm-up is over: rescale the wrapped scheduler's base lrs once,
            # then report its learning rates from here on.
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear warm-up from 0 to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear warm-up from base_lr to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is stepped at the end of an epoch, whereas other
        # schedulers are stepped at the beginning, hence the special case for 0.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # Warm-up finished: delegate to the wrapped ReduceLROnPlateau.
            self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        # ReduceLROnPlateau needs the metric value, so it takes a separate path.
        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)