import math
from enum import Enum
from functools import partial

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
class SchedulerType(Enum):
    """Enumeration of available learning rate schedulers."""

    COSINE = "cosine"
    COSINE_WITH_WARMUP = "cosine_with_warmup"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    WARMUP_STABLE_DECAY = "warmup_stable_decay"
    POLYNOMIAL_WITH_WARMUP = "polynomial_with_warmup"
    LINEAR_WITH_WARMUP = "linear_with_warmup"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    INVERSE_SQRT = "inverse_sqrt"
def _get_warmup_stable_decay_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.001,
    decay_type: str = "cosine",
):
    """
    Learning rate lambda function for the Warmup-Stable-Decay (WSD) schedule.

    This scheduler implements three phases:
    1. Warmup: Linear increase from 0 to peak learning rate
    2. Stable: Constant learning rate for majority of training
    3. Decay: Gradual decrease using cosine or linear decay

    Args:
        current_step: Current training step
        num_warmup_steps: Number of warmup steps
        num_stable_steps: Number of stable learning rate steps
        num_training_steps: Total number of training steps
        min_lr_ratio: Minimum learning rate as ratio of peak learning rate
        decay_type: Type of decay schedule ("cosine" or "linear")
    """
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    elif current_step < num_warmup_steps + num_stable_steps:
        return 1.0
    else:
        decay_steps = num_training_steps - num_warmup_steps - num_stable_steps
        if decay_steps <= 0:
            return max(min_lr_ratio, 1.0)

        progress = (current_step - num_warmup_steps - num_stable_steps) / decay_steps
        progress = min(progress, 1.0)

        if decay_type == "cosine":
            decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(min_lr_ratio, decay_factor)
        elif decay_type == "linear":
            decay_factor = 1.0 - progress
            return max(min_lr_ratio, decay_factor)
        else:
            raise ValueError(f"Unknown decay_type: {decay_type}")
def get_warmup_stable_decay_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.01,
    decay_type: str = "cosine",
    last_epoch: int = -1,
):
    """
    Create a Warmup-Stable-Decay learning rate schedule.

    This scheduler is particularly well-suited for foundation model training as it:
    - Provides stable learning during the majority of training
    - Doesn't require pre-committing to an exact training duration
    - Allows for extended training without aggressive decay

    Args:
        optimizer: The optimizer for which to schedule the learning rate
        num_warmup_steps: Number of steps for the warmup phase
        num_stable_steps: Number of steps for the stable learning rate phase
        num_training_steps: Total number of training steps
        min_lr_ratio: Minimum learning rate as fraction of peak learning rate
        decay_type: Type of decay ("cosine" or "linear")
        last_epoch: The index of the last epoch when resuming training

    Returns:
        torch.optim.lr_scheduler.LambdaLR with the WSD schedule
    """
    lr_lambda = partial(
        _get_warmup_stable_decay_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_stable_steps=num_stable_steps,
        num_training_steps=num_training_steps,
        min_lr_ratio=min_lr_ratio,
        decay_type=decay_type,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
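

# Usage sketch for the WSD schedule (illustrative only: the tiny model, AdamW
# settings, and the ~10%/80%/10% warmup/stable/decay split are arbitrary
# placeholder choices, not values prescribed by this module).
def _example_wsd_usage():
    import torch

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    total_steps = 10_000
    scheduler = get_warmup_stable_decay_schedule(
        optimizer,
        num_warmup_steps=1_000,    # ~10% warmup
        num_stable_steps=8_000,    # ~80% at the peak learning rate
        num_training_steps=total_steps,
        min_lr_ratio=0.01,
        decay_type="cosine",
    )

    for _ in range(total_steps):
        # ... forward/backward pass would go here ...
        optimizer.step()
        scheduler.step()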
def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    min_lr_ratio: float = 0.0,
):
    """Enhanced cosine schedule with configurable minimum learning rate."""
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(min_lr_ratio, cosine_factor)
def get_enhanced_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    min_lr_ratio: float = 0.01,
    last_epoch: int = -1,
):
    """
    Enhanced cosine schedule with warmup and configurable minimum learning rate.

    Args:
        optimizer: The optimizer for which to schedule the learning rate
        num_warmup_steps: Number of steps for the warmup phase
        num_training_steps: Total number of training steps
        num_cycles: Number of cosine cycles (0.5 = half cosine)
        min_lr_ratio: Minimum learning rate as fraction of peak learning rate
        last_epoch: The index of the last epoch when resuming training
    """
    lr_lambda = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
        min_lr_ratio=min_lr_ratio,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
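

# Sanity-check sketch for the min_lr_ratio floor (illustrative only: the dummy
# parameter, base learning rate of 1.0, and the step counts are arbitrary). At
# the final step the cosine factor reaches 0, so the learning rate is clamped
# at base_lr * min_lr_ratio.
def _example_cosine_floor_check():
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=1.0)
    scheduler = get_enhanced_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=10, num_training_steps=100, min_lr_ratio=0.05
    )

    for _ in range(100):
        optimizer.step()
        scheduler.step()

    # With a base lr of 1.0 the final lr should be ~0.05.
    return optimizer.param_groups[0]["lr"]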
def _get_cosine_with_restarts_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 1,
    min_lr_ratio: float = 0.0,
):
    """Cosine schedule with hard restarts and configurable minimum learning rate."""
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    if progress >= 1.0:
        return min_lr_ratio

    cosine_factor = 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))
    return max(min_lr_ratio, cosine_factor)
def get_cosine_with_restarts_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 4,
    min_lr_ratio: float = 0.01,
    last_epoch: int = -1,
):
    """
    Cosine schedule with hard restarts.

    Args:
        optimizer: The optimizer for which to schedule the learning rate
        num_warmup_steps: Number of steps for the warmup phase
        num_training_steps: Total number of training steps
        num_cycles: Number of restart cycles
        min_lr_ratio: Minimum learning rate as fraction of peak learning rate
        last_epoch: The index of the last epoch when resuming training
    """
    lr_lambda = partial(
        _get_cosine_with_restarts_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
        min_lr_ratio=min_lr_ratio,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
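

# Restart-profile sketch (illustrative only: zero warmup, a 100-step horizon,
# and 4 cycles are arbitrary values chosen to make the restarts easy to see).
# The multiplier decays toward 0 within each 25-step cycle and snaps back to
# 1.0 at steps 25, 50, and 75.
def _example_restart_profile():
    lr_lambda = partial(
        _get_cosine_with_restarts_lr_lambda,
        num_warmup_steps=0,
        num_training_steps=100,
        num_cycles=4,
        min_lr_ratio=0.0,
    )
    return [round(lr_lambda(step), 3) for step in (0, 12, 24, 25, 49, 50)]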
SCHEDULER_REGISTRY = {
    SchedulerType.WARMUP_STABLE_DECAY: get_warmup_stable_decay_schedule,
    SchedulerType.COSINE_WITH_WARMUP: get_enhanced_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_restarts_schedule,
}
def get_scheduler(
    scheduler_type: str | SchedulerType,
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    scheduler_kwargs: dict | None = None,
):
    """
    Unified interface to create learning rate schedulers.

    Only scheduler types registered in SCHEDULER_REGISTRY are supported;
    passing any other SchedulerType raises a ValueError.

    Args:
        scheduler_type: Type of scheduler to create
        optimizer: The optimizer to schedule
        num_warmup_steps: Number of warmup steps
        num_training_steps: Total training steps
        scheduler_kwargs: Additional scheduler-specific parameters

    Returns:
        Configured learning rate scheduler
    """
    if isinstance(scheduler_type, str):
        scheduler_type = SchedulerType(scheduler_type)

    if scheduler_kwargs is None:
        scheduler_kwargs = {}

    if scheduler_type not in SCHEDULER_REGISTRY:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    scheduler_func = SCHEDULER_REGISTRY[scheduler_type]
    return scheduler_func(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        **scheduler_kwargs,
    )
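

# Usage sketch for the unified interface (illustrative only: the model, base
# learning rate, and step counts are arbitrary placeholders). Options specific
# to one scheduler, such as num_stable_steps or decay_type, are forwarded via
# scheduler_kwargs.
def _example_get_scheduler_usage():
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    return get_scheduler(
        "warmup_stable_decay",
        optimizer,
        num_warmup_steps=500,
        num_training_steps=5_000,
        scheduler_kwargs={"num_stable_steps": 4_000, "decay_type": "linear"},
    )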
class WarmupStableDecayScheduler:
    """
    Alternative implementation as a standalone scheduler class.

    This provides more flexibility and better state management for
    complex training scenarios with checkpointing.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        num_warmup_steps: int,
        num_stable_steps: int,
        total_steps: int,
        min_lr_ratio: float = 0.01,
        decay_type: str = "cosine",
        verbose: bool = False,
    ):
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.num_stable_steps = num_stable_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        self.decay_type = decay_type
        self.verbose = verbose

        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.current_step = 0

    def get_lr_factor(self, step: int) -> float:
        """Calculate the learning rate multiplication factor for the given step."""
        if step < self.num_warmup_steps:
            return step / max(1, self.num_warmup_steps)
        elif step < self.num_warmup_steps + self.num_stable_steps:
            return 1.0
        else:
            decay_steps = self.total_steps - self.num_warmup_steps - self.num_stable_steps
            if decay_steps <= 0:
                return max(self.min_lr_ratio, 1.0)

            progress = (step - self.num_warmup_steps - self.num_stable_steps) / decay_steps
            progress = min(progress, 1.0)

            if self.decay_type == "cosine":
                decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
            elif self.decay_type == "linear":
                decay_factor = 1.0 - progress
            else:
                raise ValueError(f"Unknown decay_type: {self.decay_type}")

            return max(self.min_lr_ratio, decay_factor)

    def step(self):
        """Update learning rates for all parameter groups."""
        lr_factor = self.get_lr_factor(self.current_step)

        for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
            param_group["lr"] = base_lr * lr_factor

        if self.verbose and self.current_step % 1000 == 0:
            phase = self.get_phase()
            print(f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}")

        self.current_step += 1

    def get_phase(self) -> str:
        """Get the current training phase."""
        if self.current_step < self.num_warmup_steps:
            return "warmup"
        elif self.current_step < self.num_warmup_steps + self.num_stable_steps:
            return "stable"
        else:
            return "decay"

    def state_dict(self) -> dict:
        """Return scheduler state for checkpointing."""
        return {
            "current_step": self.current_step,
            "base_lrs": self.base_lrs,
        }

    def load_state_dict(self, state_dict: dict):
        """Load scheduler state from a checkpoint."""
        self.current_step = state_dict["current_step"]
        self.base_lrs = state_dict["base_lrs"]
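

# Checkpointing sketch for the standalone class (illustrative only: the model,
# step counts, and the in-memory "checkpoint" dict are placeholders for a real
# save/load path). The scheduler's state_dict is saved and restored alongside
# the model/optimizer state, so training resumes in the correct phase.
def _example_wsd_class_checkpointing():
    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = WarmupStableDecayScheduler(
        optimizer,
        num_warmup_steps=100,
        num_stable_steps=800,
        total_steps=1_000,
    )

    for _ in range(250):
        optimizer.step()
        scheduler.step()

    checkpoint = {"scheduler": scheduler.state_dict()}

    # Resuming later: a freshly constructed scheduler picks up at step 250,
    # which is inside the "stable" phase for this configuration.
    resumed = WarmupStableDecayScheduler(
        optimizer,
        num_warmup_steps=100,
        num_stable_steps=800,
        total_steps=1_000,
    )
    resumed.load_state_dict(checkpoint["scheduler"])
    return resumed.get_phase()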