import torch

input_path = "best_checkpoint.pth"
output_path = "best_checkpoint_ema.pth"

# Load the full training checkpoint on CPU; weights_only=False because it
# contains more than plain tensors.
state = torch.load(input_path, map_location="cpu", weights_only=False)

# The EMA weights may be stored as a wrapper object rather than a plain state dict.
ema_state = state["model_ema"]
if hasattr(ema_state, "state_dict"):
    ema_state = ema_state.state_dict()

# Save only the EMA weights as a standalone state dict.
torch.save(ema_state, output_path)
print(f"saved EMA weights to {output_path}")
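Once exported, the file is an ordinary PyTorch state dict and can be loaded into a model of the same architecture. A minimal sketch of that follow-up step; `build_model()` is a hypothetical placeholder for whatever constructor produced the original checkpoint, not something defined above:

```python
import torch

# Load the exported EMA weights (a plain name -> tensor mapping) and sanity-check them.
ema_sd = torch.load("best_checkpoint_ema.pth", map_location="cpu")
print(f"{len(ema_sd)} tensors, first key: {next(iter(ema_sd))}")

# build_model() is a hypothetical factory for the same architecture the
# checkpoint was trained with; replace it with your own constructor.
# model = build_model()
# model.load_state_dict(ema_sd)
# model.eval()
```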