Hi!
I am having some issues understanding why I can’t resume training correctly with this code:
optimizer = optim.AdamW(params, lr=args.learning_rate)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.num_iterations * args.lr_warmup_fraction), num_training_steps=args.num_iterations)
# custom class which only saves 2 ints
# and implements load_state_dict() and state_dict()
training_state = TrainingStateTracker(iteration=1, global_step=0)
accelerator.register_for_checkpointing(training_state)
agent, optimizer, scheduler = accelerator.prepare(agent, optimizer, scheduler)
if args.checkpoint_dir != "":
    print(f"Loading checkpoint from {args.checkpoint_dir}")
    accelerator.load_state(args.checkpoint_dir)
    print(f"Checkpoint loaded from {args.checkpoint_dir}")
    accelerator.wait_for_everyone()
# these will be either the initialized values or the values loaded from the checkpoint
start_iteration = training_state["iteration"]
global_step = training_state["global_step"]
print(f"Starting training from iteration {start_iteration} and global step {global_step}")
- I see scheduler.bin being stored when I call .save_state(), so I don’t have to explicitly register the scheduler for checkpointing, right? Still, when loading state, the first LR I get after calling .step() is not a continuation of the last one before the state was stored…
- The last print statement, print(f"Starting training from iteration {start_iteration} and global step {global_step}"), always prints iteration 1 and global_step 0. So even though I see custom_checkpoint_0.pkl being saved (since I registered the tracker for checkpointing), the two fields are not loaded correctly…
This is complicated…
Short answer:
- You don’t need to register the scheduler if you passed it to accelerator.prepare(...). Accelerate will checkpoint it as scheduler (file scheduler*), together with the model, optimizer, RNG state, and grad scaler. (huggingface.co)
- Your LR jump and your two integers not restoring come from two distinct causes:
  - the scheduler stepping or construction doesn’t match the saved run;
  - your custom object wasn’t actually reloaded, because the number/order of registered custom objects didn’t match the files in the checkpoint folder. (huggingface.co)
Below is the exact, minimal fix.
Fix the LR scheduler resume
- Choose one stepping convention and keep it the same across save and load (a loop sketch follows at the end of this section):
  - If you keep the default Accelerator(step_scheduler_with_optimizer=True), call scheduler.step() once per optimizer update; the prepared wrapper only advances the underlying schedule when the optimizer actually stepped, and it advances it num_processes times per call to account for the sharded dataloader. (huggingface.co)
  - If you want to step on your own cadence (for instance once per epoch), construct Accelerator(step_scheduler_with_optimizer=False) and call scheduler.step() yourself, always after optimizer.step(). This order avoids the classic off-by-one. (PyTorch Forums)
- Rebuild the scheduler with the same hyperparameters as when you saved it. get_linear_schedule_with_warmup bakes num_warmup_steps and num_training_steps into the schedule; if those differ, the resumed LR curve will diverge even when the state loads. (huggingface.co)
- Load after prepare. The documented sequence is: build objects → prepare(...) → load_state(...). (huggingface.co)
- Sanity check right after load:
# do NOT step here; just inspect
print("resumed_lr", scheduler.get_last_lr())
# ... later, after a real optimizer update ...
optimizer.step()
scheduler.step()  # once per optimizer update (see the stepping rules above)
print("post_step_lr", scheduler.get_last_lr())
This verifies the first LR is continuous and that you’re not double-stepping. The Accelerate scheduler wrapper is designed to step with the optimizer only when a real step happened. (huggingface.co)
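As a concrete reference, here is a minimal loop sketch under the default step_scheduler_with_optimizer=True. It is a sketch, not your exact training loop: train_dataloader and compute_loss are placeholder names that are not in your snippet, and it assumes everything used below went through accelerator.prepare(...).
# minimal sketch with placeholder names (train_dataloader, compute_loss)
for iteration in range(start_iteration, args.num_iterations + 1):
    for batch in train_dataloader:            # placeholder prepared dataloader
        optimizer.zero_grad()
        loss = compute_loss(agent, batch)      # placeholder loss computation
        accelerator.backward(loss)
        optimizer.step()
        # one scheduler call per optimizer update; the prepared wrapper only
        # advances the schedule when the optimizer actually stepped
        scheduler.step()
        global_step += 1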
Fix the custom TrainingStateTracker not restoring
The file custom_checkpoint_0.pkl can exist and still be skipped if the number of registered custom objects at load time doesn’t match the number of custom_checkpoint_*.pkl files in the folder. Accelerate logs a warning and skips loading the custom objects, which leaves your fields at their constructor defaults. Clean up any extras (e.g. *.orig copies) and register the same objects in the same order before load_state. (GitHub)
Concrete checklist:
- Register every run, before load_state:
training_state = TrainingStateTracker(iteration=1, global_step=0)
accelerator.register_for_checkpointing(training_state)
The contract requires state_dict() and an in-place load_state_dict(...); a minimal sketch follows this checklist. (huggingface.co)
- Ensure the checkpoint dir has exactly one matching file (for your single custom object): custom_checkpoint_0.pkl, and nothing like custom_checkpoint_0.pkl.orig. If the counts differ, Accelerate prints: “Found checkpoints: N / Registered objects: M — Skipping.” (GitHub)
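For reference, here is a minimal sketch of what such a tracker can look like. Your actual class may differ; the __getitem__/__setitem__ part is only there to match the training_state["iteration"] access style in your snippet.
class TrainingStateTracker:
    """Two ints, checkpointable via accelerator.register_for_checkpointing()."""

    def __init__(self, iteration=1, global_step=0):
        self.iteration = iteration
        self.global_step = global_step

    def state_dict(self):
        # called by accelerator.save_state()
        return {"iteration": self.iteration, "global_step": self.global_step}

    def load_state_dict(self, state_dict):
        # called by accelerator.load_state(); must update the object in place
        self.iteration = state_dict["iteration"]
        self.global_step = state_dict["global_step"]

    # dict-style access, matching training_state["iteration"] in the code above
    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)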
Robust resume template (drop-in)
import torch
from accelerate import Accelerator
from transformers import get_linear_schedule_with_warmup

# 1) keep this consistent across runs
accelerator = Accelerator(step_scheduler_with_optimizer=True)

# 2) rebuild exactly as before
optimizer = torch.optim.AdamW(params, lr=args.learning_rate)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(args.num_iterations * args.lr_warmup_fraction),
    num_training_steps=args.num_iterations,
)

# 3) register customs BEFORE load_state
training_state = TrainingStateTracker(iteration=1, global_step=0)
accelerator.register_for_checkpointing(training_state)

# 4) prepare, then load
agent, optimizer, scheduler = accelerator.prepare(agent, optimizer, scheduler)
if args.checkpoint_dir:
    accelerator.load_state(args.checkpoint_dir)

# 5) read the restored counters
start_iteration = training_state["iteration"]
global_step = training_state["global_step"]
Why this works:
- save_state/load_state handle the model, optimizer, RNG state, scaler, and the scheduler automatically when it went through prepare(...). The constants include SCHEDULER_NAME="scheduler" (your scheduler.bin). (huggingface.co)
- Custom objects load only if the registered count/order equals the found files. (GitHub)
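For completeness, here is a sketch of the matching save side; the save cadence and args.output_dir are assumptions, not from your code. The point is to update the tracker before calling save_state, otherwise the checkpoint stores the constructor defaults you are seeing on resume.
# inside the training loop, at whatever interval you checkpoint (illustrative only)
training_state["iteration"] = iteration
training_state["global_step"] = global_step
accelerator.wait_for_everyone()
accelerator.save_state(f"{args.output_dir}/checkpoint_{iteration}")  # args.output_dir is assumed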
Edge conditions that also break LR continuity
- Keeping step_scheduler_with_optimizer=True on multi-GPU means each scheduler.step() call advances the schedule num_processes times (twice on 2 GPUs) to compensate for the sharded dataloader; if num_training_steps isn’t sized with that in mind, or you step more than once per update, the LR drifts. Users reported this exact drift; the fix is to pick one stepping convention and size the schedule to match. (GitHub)
- Specifying a scheduler in the DeepSpeed config while also passing a real scheduler through prepare(...) is rejected; you must not specify both. Accelerate enforces this. (gemfury.com)
Similar cases for reference
- Accelerate skips custom-object restore if the checkpoint counts don’t match; the log shows “Found checkpoints … Registered objects … Skipping.” Resolution: clean the extra custom_checkpoint_*.pkl files and re-register before load (a rough diagnostic is sketched after this list). (GitHub)
- Regarding current stepping behavior and defaults: step_scheduler_with_optimizer is True by default; set it to False if you want to control stepping yourself (e.g. once per epoch). (huggingface.co)
- PyTorch call-order rule: call optimizer.step() before scheduler.step() to avoid skipping the first schedule value. (PyTorch Forums)
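If you want to check the count match up front, a rough diagnostic along these lines can run before load_state. It only approximates the library’s own check (the exact matching rule may differ across Accelerate versions), and num_registered = 1 assumes your single TrainingStateTracker.
import os
import re

# count everything that looks like a custom checkpoint, including leftovers
# such as custom_checkpoint_0.pkl.orig that can trip the count comparison
found = [f for f in os.listdir(args.checkpoint_dir) if re.search(r"custom_checkpoint_\d+", f)]
num_registered = 1  # the single TrainingStateTracker in this setup
if len(found) != num_registered:
    print(f"Custom-object mismatch: {found} on disk vs {num_registered} registered; "
          "Accelerate will skip restoring them.")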
Bottom line:
- Keep step_scheduler_with_optimizer and your call pattern aligned.
- Recreate the scheduler with the same num_warmup_steps and num_training_steps.
- Register the custom tracker every run before load_state, and ensure the checkpoint dir has a single matching custom_checkpoint_0.pkl.
- Always prepare(...) then load_state(...). (huggingface.co)