Load_state() with custom objects and scheduler

Hi!

I am having some issues understanding why I can’t resume training correctly with this code:

optimizer = optim.AdamW(params, lr=args.learning_rate)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.num_iterations * args.lr_warmup_fraction), num_training_steps=args.num_iterations)

# custom class which only saves 2 ints
# and implements load_state_dict() and state_dict()
training_state = TrainingStateTracker(iteration=1, global_step=0)
accelerator.register_for_checkpointing(training_state)

agent, optimizer, scheduler = accelerator.prepare(agent, optimizer, scheduler)

if args.checkpoint_dir != "":
    print(f"Loading checkpoint from {args.checkpoint_dir}")
    accelerator.load_state(args.checkpoint_dir)
    print(f"Checkpoint loaded from {args.checkpoint_dir}")

accelerator.wait_for_everyone()


# these will be either the initialized values or the loaded values from the checkpoint
start_iteration = training_state["iteration"]
global_step = training_state["global_step"]

print(f"Starting training from iteration {start_iteration} and global step {global_step}")
  1. I see scheduler.bin being stored when I call .save_state(), so I don’t have to explicitly register the scheduler for checkpointing, right? Still, when loading the state, the first LR I get after calling .step() is not a continuation of the last one before the checkpoint was stored.

  2. The last print statement, print(f"Starting training from iteration {start_iteration} and global step {global_step}"), always reports iteration 1 and global step 0. So even though I see custom_checkpoint_0.pkl being saved (since I registered the tracker for checkpointing), the two fields are not loaded correctly.


This is complicated…:sweat_smile:


Short answer:

  • You don’t need to register the scheduler if you passed it to accelerator.prepare(...). Accelerate checkpoints it under its scheduler name (your scheduler.bin), together with the model, optimizer, RNG states, and gradient scaler. (huggingface.co)

  • Your LR jump and your two integers not restoring come from two distinct causes:

    1. scheduler stepping or construction doesn’t match the saved run;
    2. your custom object wasn’t actually reloaded because the number/order of registered custom objects didn’t match the files in the checkpoint folder. (huggingface.co)

Below is the exact, minimal fix.

Fix the LR scheduler resume

  1. Choose one stepping method and keep it the same across save and load:

    • If you keep the default Accelerator(step_scheduler_with_optimizer=True), do not call scheduler.step() yourself. Accelerate steps it on each optimizer step. (huggingface.co)
    • If you want to step it manually, construct Accelerator(step_scheduler_with_optimizer=False) and call scheduler.step() exactly once per optimizer update after optimizer.step(). This order avoids the classic off-by-one. (PyTorch Forums)
  2. Rebuild the scheduler with the same hyperparameters as when you saved it. get_linear_schedule_with_warmup bakes num_warmup_steps and num_training_steps into the schedule. If those differ, the resumed LR curve will diverge even when the state loads. (huggingface.co)

  3. Load after prepare. The documented sequence is: build objects → prepare(...) → load_state(...). (huggingface.co)

  4. Sanity check right after load:

# do NOT step here; just inspect
print("resumed_lr", scheduler.get_last_lr())
# ... after a real optimizer step ...
optimizer.step()
# if and only if you set step_scheduler_with_optimizer=False:
# scheduler.step()
print("post_step_lr", scheduler.get_last_lr())

This verifies the first LR is continuous and that you’re not double-stepping. The Accelerate scheduler wrapper is designed to step with the optimizer only when a real step happened. (huggingface.co)
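
To make the stepping pattern concrete, here is a minimal loop sketch for the manual case, assuming Accelerator(step_scheduler_with_optimizer=False) and gradient accumulation via accelerator.accumulate; dataloader and compute_loss are placeholder names, not part of your code:

for batch in dataloader:
    with accelerator.accumulate(agent):
        loss = compute_loss(agent, batch)   # placeholder for your forward pass / loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        if accelerator.sync_gradients:      # True only when a real optimizer update happens
            scheduler.step()                # one manual scheduler step per real update
            global_step += 1

With the default step_scheduler_with_optimizer=True, drop the guarded scheduler.step() line entirely; the prepared scheduler steps itself alongside the optimizer.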

Fix the custom TrainingStateTracker not restoring

The file custom_checkpoint_0.pkl can exist and still be skipped if the number of registered custom objects at load time doesn’t match the number of custom_checkpoint_*.pkl files in the folder. Accelerate logs a warning and skips loading the custom objects, which leaves your fields at their constructor defaults. Clean any extras (e.g., .orig backups) and register the same objects in the same order before load_state. (GitHub)
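
Before calling load_state you can sanity-check the folder yourself; this is a plain-Python sketch (nothing Accelerate-specific), expecting exactly one file because you registered a single object:

from pathlib import Path

# Count the pickles Accelerate will try to match against registered custom objects.
custom_files = sorted(Path(args.checkpoint_dir).glob("custom_checkpoint_*.pkl"))
print("custom checkpoint files:", [p.name for p in custom_files])
assert len(custom_files) == 1, "expected exactly one custom_checkpoint_*.pkl"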

Concrete checklist:

  • Register every run, before load_state:

    training_state = TrainingStateTracker(iteration=1, global_step=0)
    accelerator.register_for_checkpointing(training_state)
    

    The contract requires state_dict() and an in-place load_state_dict(...); a minimal sketch of such a class follows this checklist. (huggingface.co)

  • Ensure the checkpoint dir has exactly one matching file (for your single custom object): custom_checkpoint_0.pkl and nothing like custom_checkpoint_0.pkl.orig. If counts differ, Accelerate prints: “Found checkpoints: N / Registered objects: M — Skipping.” (GitHub)
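
For reference, a tracker that satisfies this contract could look like the sketch below; the class body is an assumption based on the two fields and the dict-style access in your snippet:

class TrainingStateTracker:
    """Minimal checkpointable counter object (sketch)."""

    def __init__(self, iteration=1, global_step=0):
        self._state = {"iteration": iteration, "global_step": global_step}

    def __getitem__(self, key):
        # allows training_state["iteration"], as in the snippet above
        return self._state[key]

    def __setitem__(self, key, value):
        # lets the training loop update the counters before save_state()
        self._state[key] = value

    def state_dict(self):
        # called by accelerator.save_state(); must return something picklable
        return dict(self._state)

    def load_state_dict(self, state_dict):
        # called by accelerator.load_state(); must update the object in place
        self._state.update(state_dict)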

Robust resume template (drop-in)

import torch
from accelerate import Accelerator
from transformers import get_linear_schedule_with_warmup

# 1) keep this consistent across runs
accelerator = Accelerator(step_scheduler_with_optimizer=True)

# 2) rebuild exactly as before
optimizer = torch.optim.AdamW(params, lr=args.learning_rate)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(args.num_iterations * args.lr_warmup_fraction),
    num_training_steps=args.num_iterations,
)

# 3) register customs BEFORE load_state
training_state = TrainingStateTracker(iteration=1, global_step=0)
accelerator.register_for_checkpointing(training_state)

# 4) prepare, then load
agent, optimizer, scheduler = accelerator.prepare(agent, optimizer, scheduler)
if args.checkpoint_dir:
    accelerator.load_state(args.checkpoint_dir)

# 5) read restored counters
start_iteration = training_state["iteration"]
global_step = training_state["global_step"]
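
For completeness, the matching save side is sketched below; keeping the tracker's counters current before accelerator.save_state(...) is what puts the resumed values into custom_checkpoint_0.pkl (the item assignment assumes a tracker like the sketch earlier in this reply):

# 6) matching save side, inside your training loop at checkpoint time
training_state["iteration"] = iteration
training_state["global_step"] = global_step
accelerator.save_state(args.checkpoint_dir)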

Why this works:

  • save_state/load_state handle the model, optimizer, RNG states, scaler, and the scheduler automatically when it went through prepare(...). The constants include SCHEDULER_NAME="scheduler" (your scheduler.bin); a typical folder listing is sketched after this list. (huggingface.co)
  • Custom objects load only if the registered count/order equals the found files. (GitHub)
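
For orientation, a checkpoint folder written by save_state typically contains files along these lines (exact names vary with the Accelerate version and mixed-precision settings, so treat the listing as indicative):

checkpoint_dir/
    model.safetensors          # or pytorch_model.bin with older versions
    optimizer.bin
    scheduler.bin              # written because the scheduler went through prepare()
    random_states_0.pkl        # per-process RNG state
    custom_checkpoint_0.pkl    # your registered TrainingStateTracker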

Edge conditions that also break LR continuity

  • Manually calling scheduler.step() and leaving step_scheduler_with_optimizer=True → LR steps twice per update on multi-GPU. Users reported this exact drift; the fix is to pick one stepping path. (GitHub)
  • Changing DeepSpeed config to include a scheduler while also passing a real scheduler in code is rejected; you must not specify both. Accelerate enforces this. (gemfury.com)

Similar cases for reference

  • Accelerate skips custom object restore if checkpoint counts don’t match; the log shows “Found checkpoints … Registered objects … Skipping.” Resolution: clean extra custom_checkpoint_*.pkl files and re-register before load. (GitHub)
  • Re current stepping behavior and defaults: step_scheduler_with_optimizer is True by default; set it to False if you want manual stepping. (huggingface.co)
  • PyTorch call order rule: call optimizer.step() before scheduler.step() to avoid skipping the first schedule value. (PyTorch Forums)

Bottom line:

  • Keep step_scheduler_with_optimizer and your call pattern aligned.
  • Recreate the scheduler with the same num_warmup_steps and num_training_steps.
  • Register the custom tracker every run before load_state and ensure the checkpoint dir has a single matching custom_checkpoint_0.pkl.
  • Always prepare(...) then load_state(...). (huggingface.co)