# Kaloscope-onnx / save_ema.py
# Added in commit 486419f ("Create save_ema.py") by DraconicDragon.
# Extracts the EMA weights from a training checkpoint into a standalone file.
import argparse

import torch
# Extract the EMA weights stored in a training checkpoint and save them on
# their own. Running the file with no arguments reads "best_checkpoint.pth"
# and writes "best_checkpoint_ema.pth".

DEFAULT_INPUT_PATH = "best_checkpoint.pth"
DEFAULT_OUTPUT_PATH = "best_checkpoint_ema.pth"
EMA_KEY = "model_ema"


def extract_ema_state_dict(checkpoint):
    """Return the EMA weights stored under ``"model_ema"`` in *checkpoint*.

    The entry may already be a state dict, or an object exposing a
    ``state_dict()`` method (e.g. an ``nn.Module``); in the latter case the
    result of ``state_dict()`` is returned instead of the object itself.

    Args:
        checkpoint: the mapping loaded from a training checkpoint file.

    Returns:
        The EMA weights as a state dict (or the stored entry unchanged if it
        has no ``state_dict`` attribute).

    Raises:
        KeyError: if *checkpoint* has no ``"model_ema"`` entry. The message
            lists the keys that are present, which makes checkpoints saved
            without EMA easy to diagnose.
    """
    try:
        ema_state = checkpoint[EMA_KEY]
    except KeyError:
        available = ", ".join(map(repr, checkpoint)) or "<none>"
        raise KeyError(
            f"checkpoint has no {EMA_KEY!r} entry; available keys: {available}"
        ) from None
    if hasattr(ema_state, "state_dict"):
        # NOTE(review): if this object is a wrapper that holds the model as a
        # submodule, the returned keys carry that submodule's name as a
        # prefix -- confirm the downstream loader expects that layout.
        ema_state = ema_state.state_dict()
    return ema_state


def save_ema(input_path=DEFAULT_INPUT_PATH, output_path=DEFAULT_OUTPUT_PATH):
    """Load the checkpoint at *input_path* and save its EMA weights to *output_path*.

    Tensors are mapped to CPU on load. Returns the saved EMA state.

    Raises:
        KeyError: if the checkpoint has no ``"model_ema"`` entry.
    """
    # SECURITY: weights_only=False performs full pickle deserialization, which
    # can execute arbitrary code -- only run this on checkpoints you trust.
    # Presumably needed because the EMA entry may be a pickled module object
    # (hence the state_dict() unwrapping) rather than a plain tensor dict.
    checkpoint = torch.load(input_path, map_location="cpu", weights_only=False)
    ema_state = extract_ema_state_dict(checkpoint)
    torch.save(ema_state, output_path)
    return ema_state


def main(argv=None):
    """Command-line entry point; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description="Extract EMA weights from a training checkpoint."
    )
    parser.add_argument(
        "input_path",
        nargs="?",
        default=DEFAULT_INPUT_PATH,
        help=f"checkpoint to read (default: {DEFAULT_INPUT_PATH})",
    )
    parser.add_argument(
        "output_path",
        nargs="?",
        default=DEFAULT_OUTPUT_PATH,
        help=f"file to write the EMA weights to (default: {DEFAULT_OUTPUT_PATH})",
    )
    args = parser.parse_args(argv)
    save_ema(args.input_path, args.output_path)
    print(f"saved EMA weights to {args.output_path}")


if __name__ == "__main__":
    main()