# src/optim/lr_scheduler.py
import math
from enum import Enum
from functools import partial

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


class SchedulerType(Enum):
"""Enumeration of available learning rate schedulers."""
COSINE = "cosine"
COSINE_WITH_WARMUP = "cosine_with_warmup"
COSINE_WITH_RESTARTS = "cosine_with_restarts"
WARMUP_STABLE_DECAY = "warmup_stable_decay"
POLYNOMIAL_WITH_WARMUP = "polynomial_with_warmup"
LINEAR_WITH_WARMUP = "linear_with_warmup"
CONSTANT_WITH_WARMUP = "constant_with_warmup"
    INVERSE_SQRT = "inverse_sqrt"


def _get_warmup_stable_decay_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_stable_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.001,
decay_type: str = "cosine",
):
"""
Learning rate lambda function for Warmup-Stable-Decay (WSD) schedule.
This scheduler implements three phases:
1. Warmup: Linear increase from 0 to peak learning rate
    2. Stable: Constant learning rate for the majority of training
3. Decay: Gradual decrease using cosine or linear decay
Args:
current_step: Current training step
num_warmup_steps: Number of warmup steps
num_stable_steps: Number of stable learning rate steps
num_training_steps: Total number of training steps
min_lr_ratio: Minimum learning rate as ratio of peak learning rate
decay_type: Type of decay schedule ("cosine" or "linear")
"""
if current_step < num_warmup_steps:
# Warmup phase: linear increase
return float(current_step) / float(max(1, num_warmup_steps))
elif current_step < num_warmup_steps + num_stable_steps:
# Stable phase: constant learning rate
return 1.0
else:
# Decay phase
decay_steps = num_training_steps - num_warmup_steps - num_stable_steps
        if decay_steps <= 0:
            # Degenerate config: no decay budget remains, so hold the peak rate
            return max(min_lr_ratio, 1.0)
progress = (current_step - num_warmup_steps - num_stable_steps) / decay_steps
progress = min(progress, 1.0) # Clamp to [0, 1]
if decay_type == "cosine":
# Cosine decay
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
return max(min_lr_ratio, decay_factor)
elif decay_type == "linear":
# Linear decay
decay_factor = 1.0 - progress
return max(min_lr_ratio, decay_factor)
else:
raise ValueError(f"Unknown decay_type: {decay_type}")
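

# Illustrative sketch of the factor this lambda yields (the step counts are
# assumed for the example, not taken from the module): with
# num_warmup_steps=100, num_stable_steps=800, num_training_steps=1000 and the
# default cosine decay:
#
#   step 50   -> 0.50          (halfway through linear warmup)
#   step 500  -> 1.00          (stable phase, peak learning rate)
#   step 950  -> 0.50          (halfway through the 100-step cosine decay)
#   step 1000 -> min_lr_ratio  (decay finished, clamped at the floor)

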
def get_warmup_stable_decay_schedule(
optimizer: Optimizer,
num_warmup_steps: int,
num_stable_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.01,
decay_type: str = "cosine",
last_epoch: int = -1,
):
"""
Create a Warmup-Stable-Decay learning rate schedule.
    This scheduler is particularly well-suited for foundation model training as it:
    - Provides a stable learning rate for the majority of training
    - Doesn't require pre-committing to an exact training duration
    - Allows for extended training without aggressive decay
Args:
optimizer: The optimizer for which to schedule the learning rate
num_warmup_steps: Number of steps for warmup phase
num_stable_steps: Number of steps for stable learning rate phase
num_training_steps: Total number of training steps
min_lr_ratio: Minimum learning rate as fraction of peak learning rate
decay_type: Type of decay ("cosine" or "linear")
last_epoch: The index of the last epoch when resuming training
Returns:
torch.optim.lr_scheduler.LambdaLR with the WSD schedule
"""
lr_lambda = partial(
_get_warmup_stable_decay_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_stable_steps=num_stable_steps,
num_training_steps=num_training_steps,
min_lr_ratio=min_lr_ratio,
decay_type=decay_type,
)
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
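

# Minimal usage sketch (the toy model, optimizer, and step counts are
# assumptions for illustration, not part of this module):
#
#   import torch
#
#   model = torch.nn.Linear(16, 1)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   scheduler = get_warmup_stable_decay_schedule(
#       optimizer,
#       num_warmup_steps=1_000,
#       num_stable_steps=8_000,
#       num_training_steps=10_000,
#   )
#   for _ in range(10_000):
#       ...               # forward/backward
#       optimizer.step()
#       scheduler.step()  # LambdaLR convention: once per optimizer step

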
def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
min_lr_ratio: float = 0.0,
):
"""Enhanced cosine schedule with configurable minimum learning rate."""
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(min_lr_ratio, cosine_factor)


def get_enhanced_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
min_lr_ratio: float = 0.01,
last_epoch: int = -1,
):
"""
Enhanced cosine schedule with warmup and configurable minimum learning rate.
Args:
optimizer: The optimizer for which to schedule the learning rate
num_warmup_steps: Number of steps for warmup phase
num_training_steps: Total number of training steps
num_cycles: Number of cosine cycles (0.5 = half cosine)
min_lr_ratio: Minimum learning rate as fraction of peak learning rate
last_epoch: The index of the last epoch when resuming training
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
min_lr_ratio=min_lr_ratio,
)
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
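

# Usage sketch (numbers are illustrative): with the default num_cycles=0.5 the
# factor traces half a cosine wave from 1.0 down to min_lr_ratio after warmup.
# Note the floor is a ratio of the peak rate, unlike the absolute eta_min of
# torch's CosineAnnealingLR:
#
#   scheduler = get_enhanced_cosine_schedule_with_warmup(
#       optimizer,
#       num_warmup_steps=500,
#       num_training_steps=10_000,
#       min_lr_ratio=0.05,
#   )

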
def _get_cosine_with_restarts_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: int = 1,
min_lr_ratio: float = 0.0,
):
"""Cosine schedule with hard restarts and configurable minimum learning rate."""
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
if progress >= 1.0:
return min_lr_ratio
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))
    return max(min_lr_ratio, cosine_factor)


def get_cosine_with_restarts_schedule(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: int = 4,
min_lr_ratio: float = 0.01,
last_epoch: int = -1,
):
"""
Cosine schedule with hard restarts.
Args:
optimizer: The optimizer for which to schedule the learning rate
num_warmup_steps: Number of steps for warmup phase
num_training_steps: Total number of training steps
num_cycles: Number of restart cycles
min_lr_ratio: Minimum learning rate as fraction of peak learning rate
last_epoch: The index of the last epoch when resuming training
"""
lr_lambda = partial(
_get_cosine_with_restarts_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
min_lr_ratio=min_lr_ratio,
)
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
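

# Usage sketch (numbers are illustrative): num_cycles=4 gives four cosine
# cycles after warmup; at each hard restart the factor jumps back to 1.0, and
# it is clamped at min_lr_ratio near the end of each cycle:
#
#   scheduler = get_cosine_with_restarts_schedule(
#       optimizer,
#       num_warmup_steps=500,
#       num_training_steps=10_000,
#       num_cycles=4,
#       min_lr_ratio=0.01,
#   )

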
# Scheduler registry for easy lookup
SCHEDULER_REGISTRY = {
SchedulerType.WARMUP_STABLE_DECAY: get_warmup_stable_decay_schedule,
SchedulerType.COSINE_WITH_WARMUP: get_enhanced_cosine_schedule_with_warmup,
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_restarts_schedule,
}


def get_scheduler(
scheduler_type: str | SchedulerType,
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
scheduler_kwargs: dict | None = None,
):
"""
Unified interface to create learning rate schedulers.
Args:
scheduler_type: Type of scheduler to create
optimizer: The optimizer to schedule
num_warmup_steps: Number of warmup steps
num_training_steps: Total training steps
scheduler_kwargs: Additional scheduler-specific parameters
Returns:
Configured learning rate scheduler
"""
if isinstance(scheduler_type, str):
scheduler_type = SchedulerType(scheduler_type)
if scheduler_kwargs is None:
scheduler_kwargs = {}
if scheduler_type not in SCHEDULER_REGISTRY:
raise ValueError(f"Unsupported scheduler type: {scheduler_type}")
scheduler_func = SCHEDULER_REGISTRY[scheduler_type]
return scheduler_func(
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**scheduler_kwargs,
)
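

# Dispatch sketch: string names map through SchedulerType, so the two calls
# below are equivalent; scheduler-specific arguments such as num_stable_steps
# ride in scheduler_kwargs (values here are illustrative):
#
#   sched = get_scheduler(
#       "warmup_stable_decay",
#       optimizer,
#       num_warmup_steps=1_000,
#       num_training_steps=10_000,
#       scheduler_kwargs={"num_stable_steps": 8_000, "decay_type": "linear"},
#   )
#   sched = get_scheduler(
#       SchedulerType.WARMUP_STABLE_DECAY,
#       optimizer,
#       num_warmup_steps=1_000,
#       num_training_steps=10_000,
#       scheduler_kwargs={"num_stable_steps": 8_000, "decay_type": "linear"},
#   )
#
# Only the three registry entries are dispatchable; passing any other
# SchedulerType member raises ValueError.

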
class WarmupStableDecayScheduler:
"""
Alternative implementation as a standalone scheduler class.
This provides more flexibility and better state management for
complex training scenarios with checkpointing.
"""
def __init__(
self,
optimizer: Optimizer,
num_warmup_steps: int,
num_stable_steps: int,
total_steps: int,
min_lr_ratio: float = 0.01,
decay_type: str = "cosine",
verbose: bool = False,
):
self.optimizer = optimizer
self.num_warmup_steps = num_warmup_steps
self.num_stable_steps = num_stable_steps
self.total_steps = total_steps
self.min_lr_ratio = min_lr_ratio
self.decay_type = decay_type
self.verbose = verbose
# Store initial learning rates
self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.current_step = 0

def get_lr_factor(self, step: int) -> float:
"""Calculate the learning rate multiplication factor for given step."""
if step < self.num_warmup_steps:
# Warmup phase
return step / max(1, self.num_warmup_steps)
elif step < self.num_warmup_steps + self.num_stable_steps:
# Stable phase
return 1.0
else:
# Decay phase
decay_steps = self.total_steps - self.num_warmup_steps - self.num_stable_steps
            if decay_steps <= 0:
                # Degenerate config: no decay budget remains, so hold the peak rate
                return max(self.min_lr_ratio, 1.0)
progress = (step - self.num_warmup_steps - self.num_stable_steps) / decay_steps
progress = min(progress, 1.0)
if self.decay_type == "cosine":
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
elif self.decay_type == "linear":
decay_factor = 1.0 - progress
else:
raise ValueError(f"Unknown decay_type: {self.decay_type}")
        return max(self.min_lr_ratio, decay_factor)

def step(self):
"""Update learning rates for all parameter groups."""
lr_factor = self.get_lr_factor(self.current_step)
for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
param_group["lr"] = base_lr * lr_factor
if self.verbose and self.current_step % 1000 == 0:
phase = self.get_phase()
print(f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}")
        self.current_step += 1

def get_phase(self) -> str:
"""Get current training phase."""
if self.current_step < self.num_warmup_steps:
return "warmup"
elif self.current_step < self.num_warmup_steps + self.num_stable_steps:
return "stable"
else:
return "decay"
def state_dict(self) -> dict:
"""Return scheduler state for checkpointing."""
return {
"current_step": self.current_step,
"base_lrs": self.base_lrs,
        }

def load_state_dict(self, state_dict: dict):
"""Load scheduler state from checkpoint."""
self.current_step = state_dict["current_step"]
self.base_lrs = state_dict["base_lrs"]
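

if __name__ == "__main__":
    # Smoke-test sketch (illustrative, not part of the original module): drive
    # the standalone scheduler with a toy optimizer and watch the factor move
    # through warmup -> stable -> decay. state_dict()/load_state_dict() would
    # carry current_step across a checkpoint boundary in the same way.
    import torch

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    sched = WarmupStableDecayScheduler(
        opt, num_warmup_steps=10, num_stable_steps=80, total_steps=100
    )
    for step in range(100):
        opt.step()    # no gradients in this sketch, so this is a no-op update
        sched.step()  # applies the factor for `step`, then advances the counter
        if step in (0, 5, 50, 95, 99):
            lr = opt.param_groups[0]["lr"]
            print(f"step {step:3d}  phase={sched.get_phase():<6s}  lr={lr:.2e}")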