Kernels

Commit ae32572 (unverified) · 1 parent: fa059da
Committed by wyldecat and github-actions[bot]

Support mHC (#15)

* log target shape
* update toml
* use distributed muon for small param
* Add built binary [skip-build]

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. build.toml +24 -14
  2. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +5 -0
  3. build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/_ops.py +3 -3
  4. build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} +2 -2
  5. build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/distributed/utils.py +3 -2
  6. build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/matmul_transpose_triton.py +0 -0
  7. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +1 -0
  8. build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/muon.py +88 -60
  9. build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py +26 -0
  10. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +5 -0
  11. build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/_ops.py +3 -3
  12. build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} +2 -2
  13. build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/distributed/utils.py +3 -2
  14. build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/matmul_transpose_triton.py +0 -0
  15. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +1 -0
  16. build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/muon.py +88 -60
  17. build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py +26 -0
  18. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +5 -0
  19. build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/_ops.py +3 -3
  20. build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} +2 -2
  21. build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/distributed/utils.py +3 -2
  22. build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/matmul_transpose_triton.py +0 -0
  23. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +1 -0
  24. build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/muon.py +88 -60
  25. build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py +26 -0
  26. build/torch210-cxx11-rocm70-x86_64-linux/__init__.py +5 -0
  27. build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/_ops.py +3 -3
  28. build/{torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} +2 -2
  29. build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/distributed/utils.py +3 -2
  30. build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/matmul_transpose_triton.py +0 -0
  31. build/torch210-cxx11-rocm70-x86_64-linux/metadata.json +1 -0
  32. build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/muon.py +88 -60
  33. build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py +26 -0
  34. build/torch210-cxx11-rocm71-x86_64-linux/__init__.py +5 -0
  35. build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +9 -0
  36. build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +3 -0
  37. build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py +175 -0
  38. build/{torch28-cxx11-rocm64-x86_64-linux/optimizer β†’ torch210-cxx11-rocm71-x86_64-linux}/matmul_transpose_triton.py +0 -0
  39. build/torch210-cxx11-rocm71-x86_64-linux/metadata.json +1 -0
  40. build/torch210-cxx11-rocm71-x86_64-linux/muon.py +1268 -0
  41. build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py +26 -0
  42. build/torch28-cxx11-cu126-x86_64-linux/__init__.py +5 -0
  43. build/torch28-cxx11-cu126-x86_64-linux/_ops.py +9 -0
  44. build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so +3 -0
  45. build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py +175 -0
  46. build/{torch29-cxx11-cu126-x86_64-linux/optimizer β†’ torch28-cxx11-cu126-x86_64-linux}/matmul_transpose_triton.py +0 -0
  47. build/torch28-cxx11-cu126-x86_64-linux/metadata.json +1 -0
  48. build/torch28-cxx11-cu126-x86_64-linux/muon.py +1268 -0
  49. build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py +25 -4
  50. build/torch28-cxx11-cu128-x86_64-linux/__init__.py +5 -0
build.toml CHANGED
@@ -1,23 +1,33 @@
 [general]
 name = "optimizer"
-universal = false
-
-[torch]
-src = [
-    "torch-ext/torch_binding.cpp",
-    "torch-ext/torch_binding.h",
-]
-
-[kernel.activation]
-backend = "rocm"
-src = [
-    "optimizer/dummy.cu",
-]
-depends = [ "torch" ]
-
-[kernel.activation_cuda]
-backend = "cuda"
-src = [
-    "optimizer/dummy.cu",
-]
-depends = [ "torch" ]
+backends = [
+    "cuda",
+    "rocm",
+]
+
+[torch]
+src = [
+    "torch-ext/torch_binding.cpp",
+    "torch-ext/torch_binding.h",
+]
+
+[kernel.optimizer]
+backend = "cuda"
+depends = ["torch"]
+src = ["optimizer/dummy.cu"]
+
+[kernel.optimizer_rocm]
+backend = "rocm"
+rocm-archs = [
+    "gfx906",
+    "gfx908",
+    "gfx90a",
+    "gfx940",
+    "gfx941",
+    "gfx942",
+    "gfx1030",
+    "gfx1100",
+    "gfx1101",
+]
+depends = ["torch"]
+src = ["optimizer/dummy.cu"]
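For reference, a minimal sketch (not part of the commit) that parses the new build.toml and lists which kernel targets are declared for each backend. It assumes Python 3.11+ for tomllib and a local build.toml matching the diff above; nothing here reflects how the actual kernel-builder consumes this file.

# Hypothetical helper, not part of the commit: summarize the new build.toml.
import tomllib
from collections import defaultdict

with open("build.toml", "rb") as f:  # assumes the file shown in the diff above
    cfg = tomllib.load(f)

by_backend = defaultdict(list)
for name, kernel in cfg.get("kernel", {}).items():
    by_backend[kernel["backend"]].append(name)

print("general backends:", cfg["general"]["backends"])
for backend, kernels in sorted(by_backend.items()):
    print(f"{backend}: {', '.join(kernels)}")

# Expected output with the config above:
#   general backends: ['cuda', 'rocm']
#   cuda: optimizer
#   rocm: optimizer_rocm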
build/torch210-cxx11-cu126-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .muon import Muon
+
+__all__ = [
+    "Muon",
+]
build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/_ops.py RENAMED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_23d68bb_dirty
-ops = torch.ops._optimizer_23d68bb_dirty
+from . import _optimizer_06a260a_dirty
+ops = torch.ops._optimizer_06a260a_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_23d68bb_dirty::{op_name}"
+    return f"_optimizer_06a260a_dirty::{op_name}"
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:03c3bbbbc5c4ceb5cebfe3a2e411f155bebb390f1921c14d59fcf791dd556da1
-size 1983488
+oid sha256:5384da54f22f488e0646e09915b821b3235cb404b163a570aa377967f853e3cf
+size 1940944
build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/distributed/utils.py RENAMED
@@ -50,7 +50,7 @@ def get_slices_of_dtensor(
         raise NotImplementedError(
             f"Dimension size {dim_size} is not divisible "
             f"by number of ranks {num_ranks} for shard "
-            f"placement on dim {dim}.")
+            f"placement on dim {dim}. (shape: {target.shape})")

     shard_size = dim_size // num_ranks

@@ -64,7 +64,8 @@ def get_slices_of_dtensor(
     return tuple(slices)


-_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict()
+_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
+                                                  ProcessGroup]] = dict()


 def construct_shard_mesh(
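The only functional change here is that the divisibility error now also reports the target tensor's shape (the "log target shape" item in the commit message). Below is a standalone, simplified sketch of the check that message guards; it is illustrative only and not the actual get_slices_of_dtensor implementation.

# Simplified illustration only; the real check lives in get_slices_of_dtensor.
def shard_slice(dim_size: int, num_ranks: int, rank: int, shape=None) -> slice:
    if dim_size % num_ranks != 0:
        raise NotImplementedError(
            f"Dimension size {dim_size} is not divisible "
            f"by number of ranks {num_ranks} for shard "
            f"placement. (shape: {shape})")
    shard_size = dim_size // num_ranks
    return slice(rank * shard_size, (rank + 1) * shard_size)

print(shard_slice(8, 4, 1))           # slice(2, 4, None)
# shard_slice(10, 4, 1, (10, 16))     # raises, now reporting the full shape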
build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/matmul_transpose_triton.py RENAMED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+{"python-depends":[]}
build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu126-x86_64-linux}/muon.py RENAMED
@@ -583,6 +583,7 @@ class Muon(torch.optim.Optimizer):
             Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
         use_distributed_muon: Use distributed muon by Liu et al. (2024).
             For testing purpose only.
+        small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
     """

     def __init__(self,
@@ -604,7 +605,8 @@ class Muon(torch.optim.Optimizer):
                  },
                  warmup_step=5,
                  chunk_size=-1,
-                 use_distributed_muon=False):
+                 use_distributed_muon=False,
+                 small_param_numel_threshold=65536):
         defaults = dict(
             lr=lr,
             weight_decay=weight_decay,
@@ -637,6 +639,7 @@ class Muon(torch.optim.Optimizer):
         self.warmup_step = warmup_step
         self.chunk_size = chunk_size
         self.use_distributed_muon = use_distributed_muon
+        self.small_param_numel_threshold = small_param_numel_threshold

     def _calc_flops(self, G, steps):
         assert len(G.shape) == 2
@@ -745,16 +748,7 @@ class Muon(torch.optim.Optimizer):
             g = g.view(g.size(0), -1)
             assert g is not None

-            # calc update
-            state = self.state[p]
-            if "momentum_buffer" not in state:
-                state["momentum_buffer"] = torch.zeros_like(g)
-            buf = state["momentum_buffer"]
-            buf.mul_(momentum).add_(g)
-            if group["nesterov"]:
-                g = g.add(buf, alpha=momentum)
-            else:
-                g = buf
+            g = self._update_g(p, g, group, momentum)

             u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                              steps=group["ns_steps"])
@@ -780,14 +774,6 @@ class Muon(torch.optim.Optimizer):
         qk_logits: list[torch.Tensor | DTensor] | None,
     ):
         """ Implementation of Distributed Muon by Liu et al. """
-        if qk_logits is not None:
-            raise NotImplementedError("QK clipping is not supported yet")
-
-        if isinstance(params[0], DTensor):
-            shard_mesh, _, shard_placements = construct_shard_mesh(
-                placements=params[0].placements,
-                mesh=params[0].device_mesh,
-            )

         for n, p in zip(names, params):
             g = p.grad
@@ -797,39 +783,44 @@ class Muon(torch.optim.Optimizer):
             g = g.view(g.size(0), -1)
             assert g is not None

-            # calc update
-            state = self.state[p]
-            if "momentum_buffer" not in state:
-                state["momentum_buffer"] = torch.zeros_like(g)
-            buf = state["momentum_buffer"]
-            buf.mul_(momentum).add_(g)
-            if group["nesterov"]:
-                g = g.add(buf, alpha=momentum)
-            else:
-                g = buf
-
-            # Gather G
-            if isinstance(p.data, DTensor):
-                g = g.full_tensor()
-            u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
-                                             steps=group["ns_steps"])
-
-            if isinstance(p.data, DTensor):
-                slices = get_slices_of_dtensor(
-                    target=p,
-                    local_rank=dist.get_rank(),
-                    shard_mesh=shard_mesh,
-                    shard_placements=shard_placements,
-                )
-                u_shard = u[slices]
-                u = DTensor.from_local(
-                    u_shard,
-                    device_mesh=p.device_mesh,
-                    placements=p.placements,
-                )
-
-            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
+            g = self._update_g(p, g, group, momentum)
+
+            # Gather G
+            if isinstance(p.data, DTensor):
+                g_full = g.full_tensor()
+                p_full = p.data.full_tensor()
+            else:
+                g_full = g
+                p_full = p
+
+            u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
+                                                  steps=group["ns_steps"])
+
+            adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
+            Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
+
+            qk_clip_state = self.get_qk_clip_info(n, qk_logits)
+
+            scales_full = self._compute_scales(
+                p_full, qk_clip_state) if qk_clip_state is not None else None
+
+            if scales_full is not None:
+                Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
+
+            if isinstance(p.data, DTensor):
+                ndims = len(p.device_mesh.mesh.shape)
+                p_replicate = DTensor.from_local(
+                    p_full,
+                    device_mesh=p.device_mesh,
+                    placements=[Replicate() for _ in range(ndims)],
+                )
+
+                p_sharded = p_replicate.redistribute(
+                    device_mesh=p.device_mesh,
+                    placements=p.placements,
+                )
+
+                p.copy_(p_sharded)

     def _update_g(self, p, g, group, momentum):
         # calc update
@@ -843,10 +834,14 @@ class Muon(torch.optim.Optimizer):

     @staticmethod
     def _update_p(p, u, lr, adjusted_lr, weight_decay):
-        # apply weight decay
-        p.data.mul_(1 - lr * weight_decay)
-        # apply update
-        p.data.add_(u, alpha=-adjusted_lr)
+        if isinstance(p, torch.nn.Parameter):
+            # apply weight decay
+            p.data.mul_(1 - lr * weight_decay)
+            # apply update
+            p.data.add_(u, alpha=-adjusted_lr)
+        else:
+            p.mul_(1 - lr * weight_decay)
+            p.add_(u, alpha=-adjusted_lr)

     def get_qk_clip_info(self, n, qk_logits):
         if self.clip_config is None:
@@ -903,8 +898,12 @@ class Muon(torch.optim.Optimizer):

     @staticmethod
     def _qk_clip(p, scales, head_dim):
-        W = p.data.view(-1, head_dim, p.data.shape[1])
-        W.mul_(scales.view(-1, 1, 1))
+        if isinstance(p, torch.nn.Parameter):
+            W = p.data.view(-1, head_dim, p.data.shape[1])
+            W.mul_(scales.view(-1, 1, 1))
+        else:
+            W = p.view(-1, head_dim, p.shape[1])
+            W.mul_(scales.view(-1, 1, 1))

     def parallel(self, names, params, group, lr, weight_decay, momentum,
                  qk_logits):
@@ -1070,10 +1069,14 @@ class Muon(torch.optim.Optimizer):
         names = group["names"]

         param_dtensors = []
-        param_tensors = []
         name_dtensors = []
+
+        param_tensors = []
         name_tensors = []

+        param_dtensors_small = []
+        name_dtensors_small = []
+
         if self.use_distributed_muon:
             self.distributed_muon(names=names,
                                   params=params,
@@ -1084,6 +1087,8 @@ class Muon(torch.optim.Optimizer):
                                   qk_logits=qk_logits)
             return

+        # For simplicity, we use distributed Muon for small parameters
+        # whose number of elements is below a threshold.
         for n, p in zip(names, params):
             if p is None or p.grad is None:
                 continue
@@ -1093,6 +1098,9 @@ class Muon(torch.optim.Optimizer):
                    for placement in p.placements):
                 param_tensors.append(p)
                 name_tensors.append(n)
+            elif p.data.numel() <= self.small_param_numel_threshold:
+                param_dtensors_small.append(p)
+                name_dtensors_small.append(n)
             else:
                 param_dtensors.append(p)
                 name_dtensors.append(n)
@@ -1103,29 +1111,48 @@ class Muon(torch.optim.Optimizer):
             raise TypeError(f"Unsupported parameter type: {type(p.data)}")

         logger.debug(
-            f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
-        )
-
-        if len(param_dtensors) > 0:
-            if not dist.is_initialized():
-                raise RuntimeError(
-                    "Parallel Muon requires torch.distributed to be initialized."
-                )
-
-            # To support different placements, we group parameters by placements
-            # and run parallel Muon on each group.
-
-            placement_to_params = defaultdict(lambda: ([], []))
-            # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
-
-            assert len(name_dtensors) == len(param_dtensors)
-            for n, p in zip(name_dtensors, param_dtensors):
-                placement_to_params[tuple([p.placements,
-                                           p.device_mesh])][0].append(n)
-                placement_to_params[tuple([p.placements,
-                                           p.device_mesh])][1].append(p)
-
-            for _, (names, params) in placement_to_params.items():
+            f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
+            f"{len(param_dtensors_small)} Small DTensors")
+
+        def group_dtensors(dtensors, names):
+            # To support different placements, we group parameters by placements
+            # and run parallel Muon on each group.
+
+            placement_to_params = defaultdict(lambda: ([], []))
+            # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
+
+            assert len(dtensors) == len(names)
+            for p, n in zip(dtensors, names):
+                placement_to_params[tuple([p.placements,
+                                           p.device_mesh])][0].append(n)
+                placement_to_params[tuple([p.placements,
+                                           p.device_mesh])][1].append(p)
+            return placement_to_params
+
+        if len(param_dtensors_small) > 0:
+            if not dist.is_initialized():
+                raise RuntimeError(
+                    "Parallel Muon requires torch.distributed to be initialized."
+                )
+
+            self.distributed_muon(
+                params=param_dtensors_small,
+                names=name_dtensors_small,
+                group=group,
+                lr=lr,
+                weight_decay=weight_decay,
+                momentum=momentum,
+                qk_logits=qk_logits,
+            )
+
+        if len(param_dtensors) > 0:
+            if not dist.is_initialized():
+                raise RuntimeError(
+                    "Parallel Muon requires torch.distributed to be initialized."
+                )
+
+            dtensor_group = group_dtensors(param_dtensors, name_dtensors)
+            for _, (names, params) in dtensor_group.items():
                 self.parallel(
                     names,
                     params,
@@ -1215,6 +1242,7 @@ class Muon(torch.optim.Optimizer):
         for params in placement_to_params.values():
             self._step_adamw_params(params, group)

+    @torch.no_grad
     def step(self, closure=None, qk_logits=None):
         """Perform a single optimization step.
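The behavioural change above in plain terms: DTensor parameters whose numel() is at or below small_param_numel_threshold are now routed to the distributed Muon fallback instead of parallel Muon, matching the "use distributed muon for small param" item in the commit message. A minimal, self-contained sketch of that routing decision follows; it is simplified from the grouping loop in Muon.step, the 65536 default matches the new constructor argument, and everything else (function and variable names) is illustrative only.

# Illustrative only; the real logic lives in Muon.step above and operates on
# torch DTensors. "is_dtensor" and "numel" stand in for the p.placements
# checks and p.data.numel() used there.
SMALL_PARAM_NUMEL_THRESHOLD = 65536  # default added by this commit

def route_param(is_dtensor: bool, numel: int) -> str:
    if not is_dtensor:
        return "plain-tensor path"
    if numel <= SMALL_PARAM_NUMEL_THRESHOLD:
        return "distributed Muon fallback (small DTensor)"
    return "parallel Muon (grouped by placement and device mesh)"

print(route_param(is_dtensor=True, numel=4096))      # distributed Muon fallback
print(route_param(is_dtensor=True, numel=1 << 20))   # parallel Muon
print(route_param(is_dtensor=False, numel=1 << 20))  # plain-tensor path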
build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,26 @@
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path.
+    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+    module_name = path_hash
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+    module = importlib.util.module_from_spec(spec)
+    if module is None:
+        raise ImportError(f"Cannot load module {module_name} from spec")
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)  # type: ignore
+    return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
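The comment in the shim above explains why the module name is derived from a hash of the path. Below is a self-contained demonstration of the same pattern (not part of the commit), loading a throwaway file so that the sys.modules key never collides with a real package name.

# Demonstration of the _import_from_path pattern shown above, applied to a
# temporary file instead of the kernel variant's parent __init__.py.
import ctypes
import importlib.util
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    mod_path = Path(d) / "example_init.py"
    mod_path.write_text("VALUE = 42\n")

    # The hex-encoded hash of the absolute path becomes the module name.
    name = "{:x}".format(ctypes.c_size_t(hash(mod_path.absolute())).value)
    spec = importlib.util.spec_from_file_location(name, mod_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)

    assert module.VALUE == 42
    assert "example_init" not in sys.modules  # keyed by path hash, not filename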
build/torch210-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .muon import Muon
+
+__all__ = [
+    "Muon",
+]
build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/_ops.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/_ops.py diff above (operator namespace bumped from _optimizer_23d68bb_dirty to _optimizer_06a260a_dirty).
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:35708a107d9ac807fa3e63bbacfc6234fd7622a689a79eae3e43fce11f85d3da
-size 1924376
+oid sha256:976df6a1ec3ec4c462dea18477b56dfb75bcff76f504d55b592ce417931597c0
+size 2004144
build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/distributed/utils.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py diff above (error message now reports target.shape; _ranks_to_dist_cache annotation reflowed).
build/{torch28-cxx11-cu128-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/matmul_transpose_triton.py RENAMED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+{"python-depends":[]}
build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu128-x86_64-linux}/muon.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/muon.py diff above (momentum update factored into _update_g, distributed Muon rewrite with full-tensor update and QK clipping, small_param_numel_threshold routing, @torch.no_grad on step).
build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py ADDED
Same content as the build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py file added above (_import_from_path shim).
build/torch210-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .muon import Muon
+
+__all__ = [
+    "Muon",
+]
build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/_ops.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/_ops.py diff above (operator namespace bumped from _optimizer_23d68bb_dirty to _optimizer_06a260a_dirty).
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1cbcd3df518412314d547a86b947998802e488e8aec0f22bf8b59fbc2d1c91e8
-size 1983488
+oid sha256:330aaa6cb247ba3b5df7a13ced6ef7eff3e5d7a72a0b88f674f948aeaed66ee2
+size 2004728
build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/distributed/utils.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py diff above (error message now reports target.shape; _ranks_to_dist_cache annotation reflowed).
build/{torch28-cxx11-cu129-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/matmul_transpose_triton.py RENAMED
File without changes
build/torch210-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+{"python-depends":[]}
build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-cu130-x86_64-linux}/muon.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/muon.py diff above.
build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py ADDED
Same content as the build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py file added above (_import_from_path shim).
build/torch210-cxx11-rocm70-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .muon import Muon
+
+__all__ = [
+    "Muon",
+]
build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/_ops.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/_ops.py diff above (operator namespace bumped from _optimizer_23d68bb_dirty to _optimizer_06a260a_dirty).
build/{torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so β†’ torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8a2999010ee158e13e3ef247e877dfab073b5bde7babefe2b2b5273b760c7ddf
-size 1852152
+oid sha256:3562c68e8ee85fc5b268e079150ffff69d52860092d59e44fb9b3c4526c5d497
+size 1866400
build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/distributed/utils.py RENAMED
Same changes as the build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py diff above (error message now reports target.shape; _ranks_to_dist_cache annotation reflowed).
build/{torch28-cxx11-rocm63-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/matmul_transpose_triton.py RENAMED
File without changes
build/torch210-cxx11-rocm70-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+{"python-depends":[]}
build/{torch28-cxx11-cu126-x86_64-linux/optimizer β†’ torch210-cxx11-rocm70-x86_64-linux}/muon.py RENAMED
@@ -583,6 +583,7 @@ class Muon(torch.optim.Optimizer):
583
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
  For testing purpose only.
 
586
  """
587
 
588
  def __init__(self,
@@ -604,7 +605,8 @@ class Muon(torch.optim.Optimizer):
604
  },
605
  warmup_step=5,
606
  chunk_size=-1,
607
- use_distributed_muon=False):
 
608
  defaults = dict(
609
  lr=lr,
610
  weight_decay=weight_decay,
@@ -637,6 +639,7 @@ class Muon(torch.optim.Optimizer):
637
  self.warmup_step = warmup_step
638
  self.chunk_size = chunk_size
639
  self.use_distributed_muon = use_distributed_muon
 
640
 
641
  def _calc_flops(self, G, steps):
642
  assert len(G.shape) == 2
@@ -745,16 +748,7 @@ class Muon(torch.optim.Optimizer):
745
  g = g.view(g.size(0), -1)
746
  assert g is not None
747
 
748
- # calc update
749
- state = self.state[p]
750
- if "momentum_buffer" not in state:
751
- state["momentum_buffer"] = torch.zeros_like(g)
752
- buf = state["momentum_buffer"]
753
- buf.mul_(momentum).add_(g)
754
- if group["nesterov"]:
755
- g = g.add(buf, alpha=momentum)
756
- else:
757
- g = buf
758
 
759
  u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
760
  steps=group["ns_steps"])
@@ -780,14 +774,6 @@ class Muon(torch.optim.Optimizer):
780
  qk_logits: list[torch.Tensor | DTensor] | None,
781
  ):
782
  """ Implementation of Distributed Muon by Liu et al. """
783
- if qk_logits is not None:
784
- raise NotImplementedError("QK clipping is not supported yet")
785
-
786
- if isinstance(params[0], DTensor):
787
- shard_mesh, _, shard_placements = construct_shard_mesh(
788
- placements=params[0].placements,
789
- mesh=params[0].device_mesh,
790
- )
791
 
792
  for n, p in zip(names, params):
793
  g = p.grad
@@ -797,39 +783,44 @@ class Muon(torch.optim.Optimizer):
797
  g = g.view(g.size(0), -1)
798
  assert g is not None
799
 
800
- # calc update
801
- state = self.state[p]
802
- if "momentum_buffer" not in state:
803
- state["momentum_buffer"] = torch.zeros_like(g)
804
- buf = state["momentum_buffer"]
805
- buf.mul_(momentum).add_(g)
806
- if group["nesterov"]:
807
- g = g.add(buf, alpha=momentum)
808
- else:
809
- g = buf
810
 
811
  # Gather G
812
  if isinstance(p.data, DTensor):
813
- g = g.full_tensor()
814
- u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
815
- steps=group["ns_steps"])
816
 
817
  if isinstance(p.data, DTensor):
818
- slices = get_slices_of_dtensor(
819
- target=p,
820
- local_rank=dist.get_rank(),
821
- shard_mesh=shard_mesh,
822
- shard_placements=shard_placements,
823
  )
824
- u_shard = u[slices]
825
- u = DTensor.from_local(
826
- u_shard,
827
  device_mesh=p.device_mesh,
828
  placements=p.placements,
829
  )
830
 
831
- adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
832
- Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
833
 
834
  def _update_g(self, p, g, group, momentum):
835
  # calc update
@@ -843,10 +834,14 @@ class Muon(torch.optim.Optimizer):
843
 
844
  @staticmethod
845
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
846
- # apply weight decay
847
- p.data.mul_(1 - lr * weight_decay)
848
- # apply update
849
- p.data.add_(u, alpha=-adjusted_lr)
 
 
 
 
850
 
851
  def get_qk_clip_info(self, n, qk_logits):
852
  if self.clip_config is None:
@@ -903,8 +898,12 @@ class Muon(torch.optim.Optimizer):
903
 
904
  @staticmethod
905
  def _qk_clip(p, scales, head_dim):
906
- W = p.data.view(-1, head_dim, p.data.shape[1])
907
- W.mul_(scales.view(-1, 1, 1))
 
 
 
 
908
 
909
  def parallel(self, names, params, group, lr, weight_decay, momentum,
910
  qk_logits):
@@ -1070,10 +1069,14 @@ class Muon(torch.optim.Optimizer):
1070
  names = group["names"]
1071
 
1072
  param_dtensors = []
1073
- param_tensors = []
1074
  name_dtensors = []
 
 
1075
  name_tensors = []
1076
 
 
 
 
1077
  if self.use_distributed_muon:
1078
  self.distributed_muon(names=names,
1079
  params=params,
@@ -1084,6 +1087,8 @@ class Muon(torch.optim.Optimizer):
1084
  qk_logits=qk_logits)
1085
  return
1086
 
 
 
1087
  for n, p in zip(names, params):
1088
  if p is None or p.grad is None:
1089
  continue
@@ -1093,6 +1098,9 @@ class Muon(torch.optim.Optimizer):
1093
  for placement in p.placements):
1094
  param_tensors.append(p)
1095
  name_tensors.append(n)
 
 
 
1096
  else:
1097
  param_dtensors.append(p)
1098
  name_dtensors.append(n)
@@ -1103,29 +1111,48 @@ class Muon(torch.optim.Optimizer):
1103
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1104
 
1105
  logger.debug(
1106
- f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors"
1107
- )
1108
-
1109
- if len(param_dtensors) > 0:
1110
- if not dist.is_initialized():
1111
- raise RuntimeError(
1112
- "Parallel Muon requires torch.distributed to be initialized."
1113
- )
1114
 
 
1115
  # To support different placements, we group parameters by placements
1116
  # and run parallel Muon on each group.
1117
 
1118
  placement_to_params = defaultdict(lambda: ([], []))
1119
  # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1120
 
1121
- assert len(name_dtensors) == len(param_dtensors)
1122
- for n, p in zip(name_dtensors, param_dtensors):
1123
  placement_to_params[tuple([p.placements,
1124
  p.device_mesh])][0].append(n)
1125
  placement_to_params[tuple([p.placements,
1126
  p.device_mesh])][1].append(p)
1127
 
1128
- for _, (names, params) in placement_to_params.items():
 
1129
  self.parallel(
1130
  names,
1131
  params,
@@ -1215,6 +1242,7 @@ class Muon(torch.optim.Optimizer):
1215
  for params in placement_to_params.values():
1216
  self._step_adamw_params(params, group)
1217
 
 
1218
  def step(self, closure=None, qk_logits=None):
1219
  """Perform a single optimization step.
1220
 
 
583
  Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
  use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
  For testing purposes only.
586
+ small_param_numel_threshold: Numel threshold at or below which a parameter is treated as small and falls back to distributed Muon.
587
  """
588
 
589
  def __init__(self,
 
605
  },
606
  warmup_step=5,
607
  chunk_size=-1,
608
+ use_distributed_muon=False,
609
+ small_param_numel_threshold=65536):
610
  defaults = dict(
611
  lr=lr,
612
  weight_decay=weight_decay,
 
639
  self.warmup_step = warmup_step
640
  self.chunk_size = chunk_size
641
  self.use_distributed_muon = use_distributed_muon
642
+ self.small_param_numel_threshold = small_param_numel_threshold
643
 
644
  def _calc_flops(self, G, steps):
645
  assert len(G.shape) == 2
 
748
  g = g.view(g.size(0), -1)
749
  assert g is not None
750
 
751
+ g = self._update_g(p, g, group, momentum)
752
 
753
  u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
  steps=group["ns_steps"])
 
774
  qk_logits: list[torch.Tensor | DTensor] | None,
775
  ):
776
  """ Implementation of Distributed Muon by Liu et al. """
777
 
778
  for n, p in zip(names, params):
779
  g = p.grad
 
783
  g = g.view(g.size(0), -1)
784
  assert g is not None
785
 
786
+ g = self._update_g(p, g, group, momentum)
787
 
788
  # Gather G
789
  if isinstance(p.data, DTensor):
790
+ g_full = g.full_tensor()
791
+ p_full = p.data.full_tensor()
792
+ else:
793
+ g_full = g
794
+ p_full = p
795
+
796
+ u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
+ steps=group["ns_steps"])
798
+
799
+ adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
+ Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
+
802
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
+
804
+ scales_full = self._compute_scales(
805
+ p_full, qk_clip_state) if qk_clip_state is not None else None
806
+
807
+ if scales_full is not None:
808
+ Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
 
810
  if isinstance(p.data, DTensor):
811
+ ndims = len(p.device_mesh.mesh.shape)
812
+ p_replicate = DTensor.from_local(
813
+ p_full,
814
+ device_mesh=p.device_mesh,
815
+ placements=[Replicate() for _ in range(ndims)],
816
  )
817
+
818
+ p_sharded = p_replicate.redistribute(
 
819
  device_mesh=p.device_mesh,
820
  placements=p.placements,
821
  )
822
 
823
+ p.copy_(p_sharded)
 
824
 
825
  def _update_g(self, p, g, group, momentum):
826
  # calc update
 
834
 
835
  @staticmethod
836
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
+ if isinstance(p, torch.nn.Parameter):
838
+ # apply weight decay
839
+ p.data.mul_(1 - lr * weight_decay)
840
+ # apply update
841
+ p.data.add_(u, alpha=-adjusted_lr)
842
+ else:
843
+ p.mul_(1 - lr * weight_decay)
844
+ p.add_(u, alpha=-adjusted_lr)
845
 
846
  def get_qk_clip_info(self, n, qk_logits):
847
  if self.clip_config is None:
 
898
 
899
  @staticmethod
900
  def _qk_clip(p, scales, head_dim):
901
+ if isinstance(p, torch.nn.Parameter):
902
+ W = p.data.view(-1, head_dim, p.data.shape[1])
903
+ W.mul_(scales.view(-1, 1, 1))
904
+ else:
905
+ W = p.view(-1, head_dim, p.shape[1])
906
+ W.mul_(scales.view(-1, 1, 1))
907
 
908
  def parallel(self, names, params, group, lr, weight_decay, momentum,
909
  qk_logits):
 
1069
  names = group["names"]
1070
 
1071
  param_dtensors = []
 
1072
  name_dtensors = []
1073
+
1074
+ param_tensors = []
1075
  name_tensors = []
1076
 
1077
+ param_dtensors_small = []
1078
+ name_dtensors_small = []
1079
+
1080
  if self.use_distributed_muon:
1081
  self.distributed_muon(names=names,
1082
  params=params,
 
1087
  qk_logits=qk_logits)
1088
  return
1089
 
1090
+ # For simplicity, we use distributed Muon for small parameters
1091
+ # whose number of elements is below a threshold.
1092
  for n, p in zip(names, params):
1093
  if p is None or p.grad is None:
1094
  continue
 
1098
  for placement in p.placements):
1099
  param_tensors.append(p)
1100
  name_tensors.append(n)
1101
+ elif p.data.numel() <= self.small_param_numel_threshold:
1102
+ param_dtensors_small.append(p)
1103
+ name_dtensors_small.append(n)
1104
  else:
1105
  param_dtensors.append(p)
1106
  name_dtensors.append(n)
 
1111
  raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1112
 
1113
  logger.debug(
1114
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
1115
+ f"{len(param_dtensors_small)} Small DTensors")
1116
 
1117
+ def group_dtensors(dtensors, names):
1118
  # To support different placements, we group parameters by placements
1119
  # and run parallel Muon on each group.
1120
 
1121
  placement_to_params = defaultdict(lambda: ([], []))
1122
  # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
 
1124
+ assert len(dtensors) == len(names)
1125
+ for p, n in zip(dtensors, names):
1126
  placement_to_params[tuple([p.placements,
1127
  p.device_mesh])][0].append(n)
1128
  placement_to_params[tuple([p.placements,
1129
  p.device_mesh])][1].append(p)
1130
+ return placement_to_params
1131
+
1132
+ if len(param_dtensors_small) > 0:
1133
+ if not dist.is_initialized():
1134
+ raise RuntimeError(
1135
+ "Parallel Muon requires torch.distributed to be initialized."
1136
+ )
1137
+
1138
+ self.distributed_muon(
1139
+ params=param_dtensors_small,
1140
+ names=name_dtensors_small,
1141
+ group=group,
1142
+ lr=lr,
1143
+ weight_decay=weight_decay,
1144
+ momentum=momentum,
1145
+ qk_logits=qk_logits,
1146
+ )
1147
+
1148
+ if len(param_dtensors) > 0:
1149
+ if not dist.is_initialized():
1150
+ raise RuntimeError(
1151
+ "Parallel Muon requires torch.distributed to be initialized."
1152
+ )
1153
 
1154
+ dtensor_group = group_dtensors(param_dtensors, name_dtensors)
1155
+ for _, (names, params) in dtensor_group.items():
1156
  self.parallel(
1157
  names,
1158
  params,
 
1242
  for params in placement_to_params.values():
1243
  self._step_adamw_params(params, group)
1244
 
1245
+ @torch.no_grad
1246
  def step(self, closure=None, qk_logits=None):
1247
  """Perform a single optimization step.
1248
 
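Taken together, these muon.py changes reuse the momentum handling via _update_g, let _update_p and _qk_clip accept plain tensors as well as Parameters, and route small sharded parameters through distributed Muon. A hedged construction sketch (assuming the package is importable as optimizer and model already exists; values are illustrative):

    from optimizer.muon import Muon, get_default_muon_param_groups

    param_groups = get_default_muon_param_groups(model)
    optim = Muon(
        param_groups,
        lr=0.02,                            # the docstring suggests 0.02 as a good default
        small_param_numel_threshold=65536,  # params with <= this many elements fall back to distributed Muon
        use_distributed_muon=False,         # True forces distributed Muon everywhere (testing only)
    )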
build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-rocm71-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch210-cxx11-rocm71-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _optimizer_06a260a_dirty
3
+ ops = torch.ops._optimizer_06a260a_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_06a260a_dirty::{op_name}"
build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d804ba4d3ed9716c80e9819ba16a2bef300fb23fa4c456c550f4a96167a2eb00
3
+ size 1866112
build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slice of local tensor for a given rank from a tensor.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ rank (int): The local rank of the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}. (shape: {target.shape})")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
68
+ ProcessGroup]] = dict()
69
+
70
+
71
+ def construct_shard_mesh(
72
+ placements: tuple[Placement],
73
+ mesh: DeviceMesh,
74
+ ) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
75
+ """
76
+ Construct Shard Mesh and Placements for unsharding.
77
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
78
+ """
79
+ my_rank = dist.get_rank()
80
+
81
+ assert mesh.mesh.device.type == 'cpu'
82
+
83
+ # Copy mesh to avoid modifying the original mesh
84
+ mesh = mesh.mesh.clone()
85
+
86
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
87
+
88
+ # For Shard, strided shard comes after regular shard on the same dim
89
+ # to preserve left-to-right order of replicate-to-shard.
90
+ # This is because a strided shard uses its stride to represent
91
+ # more fine-grained sharding on the same dim.
92
+ # Please check the URL below for _StridedShard.
93
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
94
+
95
+ def placement_sort_key(
96
+ placement_with_index: tuple[float, Placement]
97
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
98
+ index, placement = placement_with_index
99
+ is_replicate = placement.is_replicate()
100
+ is_shard = placement.is_shard()
101
+ is_partial = placement.is_partial()
102
+
103
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
104
+ assert not is_partial, "Partial placement is not supported."
105
+
106
+ if is_replicate:
107
+ return (-1.0, 0, index)
108
+ elif is_shard:
109
+ if isinstance(placement, _StridedShard):
110
+ return (placement.dim, 1 / placement.split_factor, index)
111
+ return (placement.dim, 0, index)
112
+ else:
113
+ raise TypeError(f"Unknown placement type: {type(placement)}")
114
+
115
+ placements_with_index: list[tuple[int,
116
+ Placement]] = list(enumerate(placements))
117
+ placements_with_index = sorted(placements_with_index,
118
+ key=placement_sort_key)
119
+
120
+ sorted_indices, sorted_placements = zip(*placements_with_index)
121
+
122
+ # 2. Permute mesh according to sorted placements.
123
+ sorted_mesh = mesh.permute(sorted_indices)
124
+
125
+ # 3. Collect list of shard meshes by removing replicate dims
126
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
127
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
128
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
129
+
130
+ # merge replicate dims
131
+ # shard_meshes becomes a list of shard meshes whose length equals the replicate degree
132
+ if num_replicates > 0:
133
+ sorted_mesh = sorted_mesh.flatten(
134
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
135
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
136
+ else:
137
+ shard_meshes = [sorted_mesh]
138
+ shard_placements = sorted_placements[num_replicates:]
139
+
140
+ # assume all shard placements are different
141
+ assert len(shard_placements) == len(set(shard_placements))
142
+
143
+ # 4. Construct ProcessGroups
144
+ # Caution: all groups should be created in the same order in all processes,
145
+ # even though each process only needs its own group.
146
+
147
+ # To use tensor as dict key, convert it to tuple
148
+ def tensor_to_tuple(t):
149
+ if isinstance(t, torch.Tensor):
150
+ t = t.tolist()
151
+ if isinstance(t, list):
152
+ return tuple(tensor_to_tuple(x) for x in t)
153
+ return t
154
+
155
+ my_shard_mesh_as_tuple = None
156
+ for shard_mesh in shard_meshes:
157
+ assert isinstance(shard_mesh, torch.Tensor)
158
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
159
+
160
+ if (my_rank == shard_mesh).any().item():
161
+ assert my_shard_mesh_as_tuple is None
162
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
163
+
164
+ # update global cache
165
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
166
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
167
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
168
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
169
+ shard_process_group,
170
+ )
171
+
172
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
173
+ my_shard_mesh_as_tuple]
174
+
175
+ return my_shard_mesh, my_shard_process_group, shard_placements
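A small sketch of the ordering that placement_sort_key enforces, using plain placement objects (constructing Replicate/Shard does not require an initialized process group; the _StridedShard tie-break is omitted because that class is private):

    from torch.distributed.tensor import Replicate, Shard

    placements = [Shard(1), Replicate(), Shard(0)]
    key = lambda ip: (-1.0, 0, ip[0]) if ip[1].is_replicate() else (ip[1].dim, 0, ip[0])
    ordered = [p for _, p in sorted(enumerate(placements), key=key)]
    assert ordered == [Replicate(), Shard(0), Shard(1)]  # Replicate first, then Shard by dim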
build/{torch28-cxx11-rocm64-x86_64-linux/optimizer β†’ torch210-cxx11-rocm71-x86_64-linux}/matmul_transpose_triton.py RENAMED
File without changes
build/torch210-cxx11-rocm71-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
1
+ {"python-depends":[]}
build/torch210-cxx11-rocm71-x86_64-linux/muon.py ADDED
@@ -0,0 +1,1268 @@
1
+ import logging
2
+ import math
3
+ import types
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass
6
+ from typing import Any, cast
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
+
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
+ from .matmul_transpose_triton import matmul_transpose_assign
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
+
23
+
24
+ # This code snippet is a modified version adapted from the following GitHub repositories:
25
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
26
+ # Muon's Newton–Schulz iteration causes high variance in singular values
27
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
28
+ @torch.no_grad()
29
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
30
+ def _zeropower_via_newtonschulz5(G, steps):
31
+ """
32
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
33
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
34
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
35
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
36
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
37
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
38
+ performance at all relative to UV^T, where USV^T = G is the SVD.
39
+ """
40
+ assert len(G.shape) == 2
41
+ assert G.dtype == COMM_DTYPE
42
+ X = G # no manual typecast
43
+
44
+ if G.size(0) > G.size(1):
45
+ X = X.T
46
+ # Ensure spectral norm is at most 1
47
+ X = X / (X.norm() + 1e-7)
48
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
49
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
50
+ # Perform the NS iterations
51
+ for a, b, c in [
52
+ (4.0848, -6.8946, 2.9270),
53
+ (3.9505, -6.3029, 2.6377),
54
+ (3.7418, -5.5913, 2.3037),
55
+ (2.8769, -3.1427, 1.2046),
56
+ (2.8366, -3.0525, 1.2012),
57
+ ]:
58
+ matmul_transpose_assign(X, buf1)
59
+ matmul_transpose_assign(buf1, buf2)
60
+ buf1.mul_(b).add_(buf2, alpha=c)
61
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
62
+
63
+ if G.size(0) > G.size(1):
64
+ X = X.T
65
+ return X
66
+
67
+
68
+ @dataclass
69
+ class _muon_state:
70
+ # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
+ gathered_grad: torch.Tensor | None = None
78
+ scattered_u: DTensor | None = None
79
+ computed_u: torch.Tensor | None = None
80
+ gather_event: torch.cuda.Event | None = None
81
+ compute_event: torch.cuda.Event | None = None
82
+ scatter_event: torch.cuda.Event | None = None
83
+
84
+
85
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
+
105
+
106
+ @torch.no_grad()
107
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
108
+ """
109
+ Pre-allocate gathered_grad buffer on compute_stream
110
+ before launching all2all gather
111
+ """
112
+ with torch.cuda.stream(compute_stream):
113
+ for p in params:
114
+ state = param_to_state[id(p)]
115
+ if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
117
+ dtype=COMM_DTYPE,
118
+ device="cuda")
119
+ else:
120
+ state.gathered_grad = None
121
+
122
+ alloc_event = torch.cuda.Event()
123
+ alloc_event.record(compute_stream)
124
+ return alloc_event
125
+
126
+
127
+ @torch.no_grad()
128
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
129
+ alloc_event):
130
+ """
131
+ All2all gathers shards so each owner rank reconstructs its full gradient
132
+ """
133
+ with torch.cuda.stream(comm_stream):
134
+ process_group = param_to_state[id(params[0])].process_group
135
+ num_ranks = dist.get_world_size(group=process_group)
136
+
137
+ # Construct sending buffers
138
+ per_dst = [[] for _ in range(num_ranks)]
139
+ send_counts = [0] * num_ranks
140
+
141
+ for p in params:
142
+ state = param_to_state[id(p)]
143
+ dst = state.worker_rank
144
+ assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
+ g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
+ assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
+ send_counts[dst] += shard_elems
151
+
152
+ assert any(
153
+ len(v) > 0 for v in per_dst
154
+ ), "At least one destination rank must receive a sharded tensor"
155
+ # list[list[Tensor]] -> list[Tensor]
156
+ per_dst = [t for dst in per_dst for t in dst]
157
+
158
+ send_buf = torch.cat(per_dst, dim=0)
159
+
160
+ owned_params = [
161
+ p for p in params if param_to_state[id(p)].worker_rank == rank
162
+ ]
163
+
164
+ # Compute receive sizes and allocate receiving buffers
165
+ recv_counts = [0] * num_ranks
166
+
167
+ for src in range(num_ranks):
168
+ total = 0
169
+ for p in owned_params:
170
+ state = param_to_state[id(p)]
171
+ assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
+ recv_counts[src] = total
174
+
175
+ recv_total = sum(recv_counts)
176
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
+
178
+ #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
+ dist.all_to_all_single(
185
+ recv_buf,
186
+ send_buf,
187
+ output_split_sizes=recv_counts,
188
+ input_split_sizes=send_counts,
189
+ group=process_group,
190
+ )
191
+
192
+ # Reconstructs gathered grad from the received buffer
193
+ #
194
+ # recv_buf (num ranks = 3)
195
+ #
196
+ # From rank 0 From rank 1 From rank 2
197
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
198
+ #
199
+ # Outer loop:
200
+ # rank 0 -> rank 1 -> rank2
201
+ #
202
+ # Inner loop:
203
+ # p1_n -> p2_n -> p3_n
204
+
205
+ comm_stream.wait_event(alloc_event)
206
+
207
+ off = 0
208
+ for src in range(num_ranks):
209
+ if recv_counts[src] == 0:
210
+ continue
211
+
212
+ block = recv_counts[src]
213
+ inner_off = 0
214
+ for p in owned_params:
215
+ state = param_to_state[id(p)]
216
+ assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
+ assert n > 0
228
+
229
+ sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
231
+ dst.copy_(sg)
232
+
233
+ inner_off += n
234
+ off += block
235
+
236
+ for p in params:
237
+ state = param_to_state[id(p)]
238
+ if state.worker_rank == rank:
239
+ state.gather_event = torch.cuda.Event()
240
+ state.gather_event.record(comm_stream)
241
+ else:
242
+ state.gathered_grad = None
243
+ state.gather_event = None
244
+ if none_grad:
245
+ p.grad = None
246
+
247
+
248
+ @torch.no_grad()
249
+ def _compute_u(p, state, steps, rank, compute_stream):
250
+ """
251
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
252
+ """
253
+ with torch.cuda.stream(compute_stream):
254
+ if rank == state.worker_rank:
255
+ if state.gather_event is None:
256
+ raise RuntimeError("Gather event must be set before compute.")
257
+ compute_stream.wait_event(state.gather_event)
258
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
259
+ state.gathered_grad = None
260
+ state.computed_u = u
261
+ state.compute_event = torch.cuda.Event()
262
+ state.compute_event.record()
263
+ else:
264
+ state.computed_u = None
265
+ state.compute_event = None
266
+
267
+
268
+ @torch.no_grad()
269
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
270
+ """
271
+ Pre-allocate scattered_u buffer on compute_stream
272
+ before launching all2all scatter
273
+ """
274
+ with torch.cuda.stream(compute_stream):
275
+ for p in params:
276
+ state = param_to_state[id(p)]
277
+ state.scattered_u = torch.empty_like(p.to_local(),
278
+ dtype=COMM_DTYPE)
279
+
280
+ alloc_event = torch.cuda.Event()
281
+ alloc_event.record(compute_stream)
282
+ return alloc_event
283
+
284
+
285
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
286
+ """
287
+ All2all scatters the computed updates back to all ranks as shards
288
+ """
289
+ with torch.cuda.stream(comm_stream):
290
+ process_group = param_to_state[id(params[0])].process_group
291
+ num_ranks = dist.get_world_size(group=process_group)
292
+ owned_params = [
293
+ p for p in params if param_to_state[id(p)].worker_rank == rank
294
+ ]
295
+
296
+ # Construct sending buffer
297
+ per_dst = [[] for _ in range(num_ranks)]
298
+ send_counts = [0] * num_ranks
299
+
300
+ if owned_params:
301
+ for p in owned_params:
302
+ state = param_to_state[id(p)]
303
+ if state.compute_event is None:
304
+ raise RuntimeError(
305
+ "Compute event must be set before scatter.")
306
+ comm_stream.wait_event(state.compute_event)
307
+ state.gathered_grad = None
308
+
309
+ assert state.computed_u is not None
310
+
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
+
313
+ offset = 0
314
+ for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
+ assert n > 0
323
+
324
+ per_dst[dst].append(su)
325
+ send_counts[dst] += n
326
+ offset += n
327
+
328
+ assert offset == u_full.numel()
329
+
330
+ lengths = [len(v) for v in per_dst]
331
+ if all(l > 0 for l in lengths):
332
+ assert all(
333
+ l == lengths[0] for l in lengths
334
+ ), "All destination ranks must have the same number of sharded tensors"
335
+ # list[list[Tensor]] -> list[Tensor]
336
+ per_dst = [t for dst in per_dst for t in dst]
337
+ send_buf = torch.cat(per_dst, dim=0)
338
+ else:
339
+ # all_to_all requires participation from all ranks
340
+ # Even non-owner ranks must join the collective call
341
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
342
+
343
+ # Compute receive sizes and allocate receiving buffers
344
+ recv_counts = [0] * num_ranks
345
+
346
+ for src in range(num_ranks):
347
+ total = 0
348
+ for p in params:
349
+ state = param_to_state[id(p)]
350
+ if state.worker_rank != src:
351
+ continue
352
+ total += numel_for_rank(p, rank, state)
353
+ recv_counts[src] = total
354
+
355
+ recv_total = sum(recv_counts)
356
+ assert recv_total > 0
357
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
358
+
359
+ #All2All
360
+ dist.all_to_all_single(
361
+ recv_buf,
362
+ send_buf,
363
+ output_split_sizes=recv_counts,
364
+ input_split_sizes=send_counts,
365
+ group=process_group,
366
+ )
367
+
368
+ # Copy to pre-allocated scattered_u buffer from the received buffer
369
+ #
370
+ # recv_buf (num ranks = 3, local_rank = 0)
371
+ #
372
+ # From rank 0 From rank 1 From rank 2
373
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
374
+ #
375
+ # Outer loop:
376
+ # rank 0 -> rank 1 -> rank2
377
+ #
378
+ # Inner loop:
379
+ # src(0) : p1_0 -> p2_0 -> p3_0
380
+ # src(1) : p4_0
381
+ # src(2) : p5_0 -> p6_0
382
+
383
+ comm_stream.wait_event(alloc_event)
384
+
385
+ off = 0
386
+ for src in range(num_ranks):
387
+ block = recv_counts[src]
388
+ if block == 0:
389
+ continue
390
+
391
+ inner_off = 0
392
+ for p in params:
393
+ state = param_to_state[id(p)]
394
+ if state.worker_rank != src:
395
+ continue
396
+ n = numel_for_rank(p, rank, state)
397
+ assert n > 0
398
+
399
+ flat_local = recv_buf.narrow(0, off + inner_off,
400
+ n).view_as(p.to_local())
401
+ state.scattered_u.copy_(flat_local)
402
+
403
+ state.scatter_event = torch.cuda.Event()
404
+ state.scatter_event.record(comm_stream)
405
+ inner_off += n
406
+
407
+ assert inner_off == block
408
+ off += block
409
+
410
+
411
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
412
+ compute_stream):
413
+ """
414
+ Update sharded parameter p with the scattered_u.
415
+ Only worker_rank frees computed_u.
416
+ """
417
+ with torch.cuda.stream(compute_stream):
418
+ if state.scatter_event is None:
419
+ raise RuntimeError("Scatter event must be set before update")
420
+ compute_stream.wait_event(state.scatter_event)
421
+ u_dtensor = DTensor.from_local(
422
+ state.scattered_u,
423
+ placements=p.placements,
424
+ device_mesh=p.device_mesh,
425
+ )
426
+
427
+ state.scattered_u = u_dtensor
428
+
429
+ if rank == state.worker_rank:
430
+ # Free computed_u
431
+ state.computed_u = None
432
+
433
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
434
+ state.scattered_u = None
435
+ u_dtensor = None
436
+
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
+ if scales_full is not None:
441
+ # Have to slice scales_full among dim 0
442
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
+ scales_local = DTensor.from_local(
455
+ scales_local,
456
+ placements=p.placements,
457
+ device_mesh=p.device_mesh,
458
+ )
459
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
460
+
461
+
462
+ def default_is_muon(name, x):
463
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
464
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
465
+
466
+
467
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
468
+ muon_params, muon_names = [], []
469
+ non_muon_params = []
470
+
471
+ for n, p in model.named_parameters():
472
+ if not p.requires_grad:
473
+ continue
474
+ if is_muon_func(n, p):
475
+ muon_params.append(p)
476
+ muon_names.append(n)
477
+ else:
478
+ non_muon_params.append(p)
479
+
480
+ return [
481
+ {
482
+ "params": muon_params,
483
+ "names": muon_names,
484
+ "use_muon": True,
485
+ },
486
+ {
487
+ "params": non_muon_params,
488
+ "use_muon": False,
489
+ },
490
+ ]
491
+
492
+
493
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
494
+ """
495
+ Parse a parameter name to check if it is a query/key projection layer
496
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
497
+
498
+ Returns:
499
+ (kind, layer_idx) or (None, -1) if not matched.
500
+
501
+ Example:
502
+ 'model.3.attn.wq.weight' -> ('wq', 3)
503
+ 'model.5.attn.wk.weight' -> ('wk', 5)
504
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
505
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
506
+ 'model.4.attn.v_proj.weight' -> (None, -1)
507
+ """
508
+ parts = name.split('.')
509
+ if len(parts) < 3:
510
+ return None, -1
511
+
512
+ kind = parts[-2]
513
+
514
+ layer_idx = -1
515
+ for part in reversed(parts):
516
+ if part.isdigit():
517
+ layer_idx = int(part)
518
+ break
519
+
520
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
521
+ return kind, layer_idx
522
+
523
+ return None, -1
524
+
525
+
526
+ @dataclass
527
+ class QKClipInfo:
528
+ """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
+ head_dim: int # from config
532
+ threshold: float # from config
533
+ logit: torch.Tensor | None
534
+
535
+
536
+ class Muon(torch.optim.Optimizer):
537
+ """
538
+ Muon - MomentUm Orthogonalized by Newton-schulz
539
+
540
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
541
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
542
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
543
+ the advantage that it can be stably run in bfloat16 on the GPU.
544
+
545
+ Some warnings:
546
+ - We believe this optimizer is unlikely to work well for training with small batch size.
547
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
548
+
549
+ Arguments:
550
+ model: The model to be optimized by Muon.
551
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
552
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
553
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
554
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
555
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
556
+ weight_decay: The weight decay for Muon and AdamW.
557
+ Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
558
+ adamw_lr: The learning rate for the internal AdamW.
559
+ adamw_betas: The betas for the internal AdamW.
560
+ adamw_eps: The epsilon for the internal AdamW.
561
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
562
+ debug: Whether to print debug information.
563
+ clip_config : Configuration for QK clipping. Expected keys:
564
+ - "q_indices" (list[int]): Indices of query heads to consider.
565
+ - "k_indices" (list[int]): Indices of key heads to consider.
566
+ - "head_dim" (int): Dimensionality of each attention head.
567
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
568
+ this value will be scaled down.
569
+ Default is:
570
+ {
571
+ "q_indices": [],
572
+ "k_indices": [],
573
+ "head_dim": 128,
574
+ "threshold": 100
575
+ }
576
+ warmup_step : How many all2all gather, compute operations are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication.
580
+ Parallel muon only.
581
+ chunk_size : Batch size of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purposes only.
586
+ small_param_numel_threshold: Numel threshold at or below which a parameter is treated as small and falls back to distributed Muon.
587
+ """
588
+
589
+ def __init__(self,
590
+ params,
591
+ lr=1e-3,
592
+ momentum=0.95,
593
+ nesterov=True,
594
+ ns_steps=5,
595
+ weight_decay=0.1,
596
+ adamw_betas=(0.9, 0.95),
597
+ adamw_eps=1e-8,
598
+ none_grad=True,
599
+ debug=False,
600
+ clip_config={
601
+ "q_indices": [],
602
+ "k_indices": [],
603
+ "head_dim": 128,
604
+ "threshold": 100
605
+ },
606
+ warmup_step=5,
607
+ chunk_size=-1,
608
+ use_distributed_muon=False,
609
+ small_param_numel_threshold=65536):
610
+ defaults = dict(
611
+ lr=lr,
612
+ weight_decay=weight_decay,
613
+ momentum=momentum,
614
+ nesterov=nesterov,
615
+ ns_steps=ns_steps,
616
+ adamw_betas=adamw_betas,
617
+ adamw_eps=adamw_eps,
618
+ none_grad=none_grad,
619
+ use_muon=True,
620
+ )
621
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
622
+ instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
623
+
624
+ if isinstance(params, types.GeneratorType):
625
+ raise ValueError(error_message.format(idx=0) + instruction_code)
626
+ for _idx, param_group in enumerate(params):
627
+ if param_group.get("use_muon", None) is None:
628
+ raise ValueError(
629
+ error_message.format(idx=_idx) + instruction_code)
630
+
631
+ super().__init__(params, defaults)
632
+
633
+ self.rank = None
634
+
635
+ self.comm_stream = torch.cuda.Stream()
636
+ self.compute_stream = torch.cuda.Stream()
637
+ self.debug = debug
638
+ self.clip_config = clip_config
639
+ self.warmup_step = warmup_step
640
+ self.chunk_size = chunk_size
641
+ self.use_distributed_muon = use_distributed_muon
642
+ self.small_param_numel_threshold = small_param_numel_threshold
643
+
644
+ def _calc_flops(self, G, steps):
645
+ assert len(G.shape) == 2
646
+ M, N = G.shape
647
+ if M > N:
648
+ M, N = N, M
649
+
650
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
651
+
652
+ def adjust_lr_for_muon(self, lr, param_shape):
653
+ A, B = param_shape[:2]
654
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
655
+ # as described in the paper
656
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
657
+ adjusted_lr = lr * adjusted_ratio
658
+ return adjusted_lr
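    # Editorial sketch, not part of this file: a worked instance of the scaling above with
    # illustrative shapes. For a 4096 x 1024 weight and lr = 0.02,
    # adjusted_ratio = 0.2 * sqrt(4096) = 12.8, so adjusted_lr = 0.02 * 12.8 = 0.256.
    example_adjusted_lr = 0.02 * 0.2 * math.sqrt(max((4096, 1024)))  # == 0.256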
659
+
660
+ def set_rank_once(self, rank):
661
+ if self.rank is None:
662
+ self.rank = rank
663
+ else:
664
+ assert self.rank == rank
665
+
666
+ def get_shard_mesh(self, p):
667
+ """
668
+ Get the shard mesh for a parameter p on the given rank.
669
+ """
670
+ assert isinstance(
671
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
672
+
673
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
674
+ p.placements, p.device_mesh)
675
+
676
+ # set rank with the local rank in the shard process group
677
+ self.set_rank_once(dist.get_rank(group=shard_pg))
678
+
679
+ return shard_mesh, shard_pg, shard_placements
680
+
681
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
682
+ param_to_state = {}
683
+ param_to_flops = {}
684
+
685
+ total_flops = 0
686
+ for p in params:
687
+ g = p.grad
688
+ if g is None:
689
+ continue
690
+ assert g.ndim == 2, "Muon only supports 2D parameters."
691
+
692
+ flops = self._calc_flops(g, group["ns_steps"])
693
+ param_to_flops[id(p)] = flops
694
+ total_flops += flops
695
+
696
+ if self.debug:
697
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
+ flush=True)
699
+
700
+ paired = list(zip(names, params))
701
+
702
+ paired_sorted = sorted(paired,
703
+ key=lambda x: param_to_flops[id(x[1])],
704
+ reverse=True)
705
+
706
+ names_sorted, params_sorted = zip(*paired_sorted)
707
+ ordered_names = list(names_sorted)
708
+ ordered_params = list(params_sorted)
709
+
710
+ round_robin = 0
711
+ mesh = ordered_params[0].device_mesh
712
+ placements = ordered_params[0].placements
713
+
714
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
715
+ ordered_params[0])
716
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
717
+ num_ranks = dist.get_world_size(group=shard_pg)
718
+
719
+ for n, p in zip(ordered_names, ordered_params):
720
+ if mesh != p.device_mesh:
721
+ raise ValueError("All parameters must be on the same mesh.")
722
+ if placements != p.placements:
723
+ raise ValueError("All parameters must have same placements.")
724
+
725
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
+
729
+ param_to_state[id(p)] = _muon_state(
730
+ worker_rank=worker_rank,
731
+ process_group=shard_pg,
732
+ shard_mesh=shard_mesh,
733
+ shard_placements=shard_placements,
734
+ name=n,
735
+ qk_clip_state=qk_clip_state,
736
+ )
737
+
738
+ return param_to_state, ordered_params
739
+
740
+ def base(self, names, params, group, lr, weight_decay, momentum,
741
+ qk_logits):
742
+ # generate weight updates in distributed fashion
743
+ for n, p in zip(names, params):
744
+ g = p.grad
745
+ if g is None:
746
+ continue
747
+ if g.ndim > 2:
748
+ g = g.view(g.size(0), -1)
749
+ assert g is not None
750
+
751
+ g = self._update_g(p, g, group, momentum)
752
+
753
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
+ steps=group["ns_steps"])
755
+
756
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
+
759
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
+
761
+ scales_full = self._compute_scales(
762
+ p, qk_clip_state) if qk_clip_state is not None else None
763
+ if scales_full is not None:
764
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
+
766
+ def distributed_muon(
767
+ self,
768
+ names: list[str],
769
+ params: list[torch.nn.Parameter],
770
+ group: dict[str, Any],
771
+ lr: float,
772
+ weight_decay: float,
773
+ momentum: float,
774
+ qk_logits: list[torch.Tensor | DTensor] | None,
775
+ ):
776
+ """ Implementation of Distributed Muon by Liu et al. """
777
+
778
+ for n, p in zip(names, params):
779
+ g = p.grad
780
+ if g is None:
781
+ continue
782
+ if g.ndim > 2:
783
+ g = g.view(g.size(0), -1)
784
+ assert g is not None
785
+
786
+ g = self._update_g(p, g, group, momentum)
787
+
788
+ # Gather G
789
+ if isinstance(p.data, DTensor):
790
+ g_full = g.full_tensor()
791
+ p_full = p.data.full_tensor()
792
+ else:
793
+ g_full = g
794
+ p_full = p
795
+
796
+ u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
+ steps=group["ns_steps"])
798
+
799
+ adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
+ Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
+
802
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
+
804
+ scales_full = self._compute_scales(
805
+ p_full, qk_clip_state) if qk_clip_state is not None else None
806
+
807
+ if scales_full is not None:
808
+ Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
+
810
+ if isinstance(p.data, DTensor):
811
+ ndims = len(p.device_mesh.mesh.shape)
812
+ p_replicate = DTensor.from_local(
813
+ p_full,
814
+ device_mesh=p.device_mesh,
815
+ placements=[Replicate() for _ in range(ndims)],
816
+ )
817
+
818
+ p_sharded = p_replicate.redistribute(
819
+ device_mesh=p.device_mesh,
820
+ placements=p.placements,
821
+ )
822
+
823
+ p.copy_(p_sharded)
824
+
825
+ def _update_g(self, p, g, group, momentum):
826
+ # calc update
827
+ state = self.state[p]
828
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
+ torch.add(g, buf, alpha=momentum, out=buf)
830
+ if group["nesterov"]:
831
+ g.add_(buf, alpha=momentum)
832
+ return g
833
+ return buf
834
+
835
+ @staticmethod
836
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
+ if isinstance(p, torch.nn.Parameter):
838
+ # apply weight decay
839
+ p.data.mul_(1 - lr * weight_decay)
840
+ # apply update
841
+ p.data.add_(u, alpha=-adjusted_lr)
842
+ else:
843
+ p.mul_(1 - lr * weight_decay)
844
+ p.add_(u, alpha=-adjusted_lr)
845
+
846
+ def get_qk_clip_info(self, n, qk_logits):
847
+ if self.clip_config is None:
848
+ return None
849
+
850
+ head_dim = self.clip_config.get('head_dim')
851
+ threshold = self.clip_config.get('threshold')
852
+ kind, layer_idx = parse_qk_layer(n)
853
+
854
+ logit, indices = None, []
855
+ if qk_logits is not None and kind is not None:
856
+ logit = qk_logits[layer_idx]
857
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
858
+ indices = self.clip_config.get(indices_key, []) or []
859
+
860
+ if isinstance(logit, DTensor):
861
+ # In TP settings, qk_logits may be DTensor
862
+ # We convert it to full tensor here for simplicity
863
+ logit = logit.full_tensor()
864
+
865
+ return QKClipInfo(
866
+ kind=kind,
867
+ indices=indices,
868
+ head_dim=head_dim,
869
+ threshold=threshold,
870
+ logit=logit,
871
+ )
872
+
873
+ @staticmethod
874
+ def _compute_scales(p, qk_clip_state):
875
+ kind = qk_clip_state.kind
876
+ indices = qk_clip_state.indices
877
+ head_dim = qk_clip_state.head_dim
878
+ threshold = qk_clip_state.threshold
879
+ logit = qk_clip_state.logit
880
+
881
+ H_global = p.shape[0] // head_dim
882
+ scales_full = torch.ones(H_global, device=p.data.device)
883
+ scaling = 0
884
+
885
+ for logit_idx, head_idx in enumerate(indices):
886
+ v_ele = float(logit[logit_idx])
887
+ if v_ele > threshold:
888
+ new_scale = math.sqrt(threshold / v_ele)
889
+ if new_scale < scales_full[head_idx]:
890
+ scales_full[head_idx] = new_scale
891
+ logger.info(
892
+ f"[{kind}] Head {head_idx} exceeded threshold "
893
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
+ )
895
+ scaling += 1
896
+
897
+ return scales_full if scaling > 0 else None
898
+
899
+ @staticmethod
900
+ def _qk_clip(p, scales, head_dim):
901
+ if isinstance(p, torch.nn.Parameter):
902
+ W = p.data.view(-1, head_dim, p.data.shape[1])
903
+ W.mul_(scales.view(-1, 1, 1))
904
+ else:
905
+ W = p.view(-1, head_dim, p.shape[1])
906
+ W.mul_(scales.view(-1, 1, 1))
907
+
908
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
909
+ qk_logits):
910
+ """
911
+ Perform a parallel optimization step using Muon.
912
+ """
913
+
914
+ for p in params:
915
+ g = p.grad
916
+ if g is None:
917
+ continue
918
+ if g.ndim > 2:
919
+ g = g.view(g.size(0), -1)
920
+
921
+ # Update g in the local rank
922
+ g = self._update_g(
923
+ p,
924
+ g,
925
+ group,
926
+ momentum=momentum,
927
+ )
928
+ p.grad = g
929
+
930
+ param_to_state, ordered_params = self.init_state_and_assign_params(
931
+ names, params, group, qk_logits)
932
+
933
+ assert self.rank is not None
934
+
935
+ def enqueue_all2all_gather(start_idx, chunk_size):
936
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
937
+ if target_params:
938
+ alloc_event = _alloc_gathered_grad(target_params,
939
+ param_to_state, self.rank,
940
+ self.compute_stream)
941
+ _all2all_gather(target_params, param_to_state, self.rank,
942
+ self.comm_stream, group["none_grad"],
943
+ alloc_event)
944
+
945
+ def enqueue_computes(start_idx, chunk_size):
946
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
947
+ state = param_to_state[id(p)]
948
+ _compute_u(p, state, group["ns_steps"], self.rank,
949
+ self.compute_stream)
950
+
951
+ def enqueue_all2all_scatter(start_idx, chunk_size):
952
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
953
+ if target_params:
954
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
+ self.rank,
956
+ self.compute_stream)
957
+ _all2all_scatter(target_params, param_to_state, self.rank,
958
+ self.comm_stream, alloc_event)
959
+
960
+ def enqueue_update_param(start_idx, chunk_size):
961
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
962
+ state = param_to_state[id(p)]
963
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
965
+ self.rank, self.compute_stream)
966
+
967
+ if self.chunk_size == -1:
968
+ shard_ranks = dist.get_world_size(param_to_state[id(
969
+ params[0])].process_group)
970
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
+ elif self.chunk_size > 0:
972
+ chunk_size = self.chunk_size
973
+ else:
974
+ raise ValueError("chunk_size must be -1 or a positive integer.")
975
+
976
+ # Wait grad update
977
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
978
+
979
+ warmup_step = self.warmup_step
980
+ for i in range(0, warmup_step):
981
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
982
+ enqueue_computes(i * chunk_size, chunk_size)
983
+
984
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
+ enqueue_all2all_scatter(i, chunk_size)
986
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
+ enqueue_update_param(i, chunk_size)
988
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
+
990
+ # Wait the last update_param to finish
991
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
992
+
993
+ @staticmethod
994
+ def _fused_adamw(
995
+ params: list[torch.Tensor],
996
+ grads: list[torch.Tensor],
997
+ exp_avgs: list[torch.Tensor],
998
+ exp_avg_sqs: list[torch.Tensor],
999
+ max_exp_avg_sqs: list[torch.Tensor],
1000
+ state_steps: list[torch.Tensor],
1001
+ amsgrad: bool,
1002
+ beta1: float,
1003
+ beta2: float,
1004
+ lr: float | torch.Tensor,
1005
+ weight_decay: float,
1006
+ eps: float,
1007
+ maximize: bool,
1008
+ ) -> None:
1009
+ if not params:
1010
+ return
1011
+
1012
+ # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
1013
+ # treating it as a scalar.
1014
+ lr_dict: DeviceDict | None = ({
1015
+ lr.device: lr
1016
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
+ None)
1018
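+ # Group tensors by device and dtype so the fused kernel below is launched once per group.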
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
+ [
1020
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
+ state_steps
1022
+ ] # type: ignore[list-item]
1023
+ )
1024
+ for (device, _), (
1025
+ (
1026
+ device_params_,
1027
+ device_grads_,
1028
+ device_exp_avgs_,
1029
+ device_exp_avg_sqs_,
1030
+ device_max_exp_avg_sqs,
1031
+ device_state_steps_,
1032
+ ),
1033
+ _,
1034
+ ) in grouped_tensors.items():
1035
+ device_params = cast(list[torch.Tensor], device_params_)
1036
+ device_grads = cast(list[torch.Tensor], device_grads_)
1037
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
+
1041
+ if lr_dict is not None and device not in lr_dict:
1042
+ lr_dict[device] = lr.to(
1043
+ device=device,
1044
+ non_blocking=True) # type: ignore[union-attr]
1045
+ lr = lr_dict[device]
1046
+ torch._foreach_add_(device_state_steps, 1)
1047
+ func = torch._fused_adamw_
1048
+ func(
1049
+ device_params,
1050
+ device_grads,
1051
+ device_exp_avgs,
1052
+ device_exp_avg_sqs,
1053
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
+ device_state_steps,
1055
+ amsgrad=amsgrad,
1056
+ lr=lr, # type: ignore[arg-type]
1057
+ beta1=beta1,
1058
+ beta2=beta2,
1059
+ weight_decay=weight_decay,
1060
+ eps=eps,
1061
+ maximize=maximize,
1062
+ )
1063
+
1064
+ def _step_muon(self, group, qk_logits=None):
1065
+ params = group["params"]
1066
+ lr = group["lr"]
1067
+ weight_decay = group["weight_decay"]
1068
+ momentum = group["momentum"]
1069
+ names = group["names"]
1070
+
1071
+ param_dtensors = []
1072
+ name_dtensors = []
1073
+
1074
+ param_tensors = []
1075
+ name_tensors = []
1076
+
1077
+ param_dtensors_small = []
1078
+ name_dtensors_small = []
1079
+
1080
+ if self.use_distributed_muon:
1081
+ self.distributed_muon(names=names,
1082
+ params=params,
1083
+ group=group,
1084
+ lr=lr,
1085
+ weight_decay=weight_decay,
1086
+ momentum=momentum,
1087
+ qk_logits=qk_logits)
1088
+ return
1089
+
1090
+ # For simplicity, we use distributed Muon for small parameters
1091
+ # whose number of elements is below a threshold.
1092
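+ # Fully replicated DTensors and plain tensors take the single-device `base` path;
+ # sharded DTensors above the threshold are handled by `parallel` (parallel Muon).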
+ for n, p in zip(names, params):
1093
+ if p is None or p.grad is None:
1094
+ continue
1095
+ if isinstance(p.data, DTensor):
1096
+ if all(
1097
+ isinstance(placement, Replicate)
1098
+ for placement in p.placements):
1099
+ param_tensors.append(p)
1100
+ name_tensors.append(n)
1101
+ elif p.data.numel() <= self.small_param_numel_threshold:
1102
+ param_dtensors_small.append(p)
1103
+ name_dtensors_small.append(n)
1104
+ else:
1105
+ param_dtensors.append(p)
1106
+ name_dtensors.append(n)
1107
+ elif isinstance(p.data, torch.Tensor):
1108
+ param_tensors.append(p)
1109
+ name_tensors.append(n)
1110
+ else:
1111
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1112
+
1113
+ logger.debug(
1114
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
1115
+ f"{len(param_dtensors_small)} Small DTensors")
1116
+
1117
+ def group_dtensors(dtensors, names):
1118
+ # To support different placements, we group parameters by placements
1119
+ # and run parallel Muon on each group.
1120
+
1121
+ placement_to_params = defaultdict(lambda: ([], []))
1122
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
+
1124
+ assert len(dtensors) == len(names)
1125
+ for p, n in zip(dtensors, names):
1126
+ placement_to_params[tuple([p.placements,
1127
+ p.device_mesh])][0].append(n)
1128
+ placement_to_params[tuple([p.placements,
1129
+ p.device_mesh])][1].append(p)
1130
+ return placement_to_params
1131
+
1132
+ if len(param_dtensors_small) > 0:
1133
+ if not dist.is_initialized():
1134
+ raise RuntimeError(
1135
+ "Parallel Muon requires torch.distributed to be initialized."
1136
+ )
1137
+
1138
+ self.distributed_muon(
1139
+ params=param_dtensors_small,
1140
+ names=name_dtensors_small,
1141
+ group=group,
1142
+ lr=lr,
1143
+ weight_decay=weight_decay,
1144
+ momentum=momentum,
1145
+ qk_logits=qk_logits,
1146
+ )
1147
+
1148
+ if len(param_dtensors) > 0:
1149
+ if not dist.is_initialized():
1150
+ raise RuntimeError(
1151
+ "Parallel Muon requires torch.distributed to be initialized."
1152
+ )
1153
+
1154
+ dtensor_group = group_dtensors(param_dtensors, name_dtensors)
1155
+ for _, (names, params) in dtensor_group.items():
1156
+ self.parallel(
1157
+ names,
1158
+ params,
1159
+ group,
1160
+ lr=lr,
1161
+ weight_decay=weight_decay,
1162
+ momentum=momentum,
1163
+ qk_logits=qk_logits,
1164
+ )
1165
+
1166
+ if len(param_tensors) > 0:
1167
+ self.base(
1168
+ name_tensors,
1169
+ param_tensors,
1170
+ group,
1171
+ lr=lr,
1172
+ weight_decay=weight_decay,
1173
+ momentum=momentum,
1174
+ qk_logits=qk_logits,
1175
+ )
1176
+
1177
+ def _step_adamw_params(self, params, group):
1178
+ params_with_grads = []
1179
+ grads = []
1180
+ moment1 = []
1181
+ moment2 = []
1182
+ max_exp_avg_sqs = []
1183
+ state_steps = []
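+ # max_exp_avg_sqs stays empty: amsgrad is always False in the _fused_adamw call below.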
1184
+ lr = group["lr"]
1185
+ beta1, beta2 = group["adamw_betas"]
1186
+ eps = group["adamw_eps"]
1187
+ weight_decay = group["weight_decay"]
1188
+
1189
+ for p in params:
1190
+ g = p.grad
1191
+ if g is None:
1192
+ continue
1193
+ state = self.state[p]
1194
+ params_with_grads.append(p)
1195
+ grads.append(g)
1196
+ if "step" not in state:
1197
+ state["step"] = (torch.zeros((),
1198
+ dtype=torch.float32,
1199
+ device=p.device))
1200
+ state["moment1"] = torch.zeros_like(g)
1201
+ state["moment2"] = torch.zeros_like(g)
1202
+ moment1.append(state["moment1"])
1203
+ moment2.append(state["moment2"])
1204
+ if not isinstance(state["step"], torch.Tensor):
1205
+ step_tensor = torch.tensor(state["step"],
1206
+ dtype=torch.float32,
1207
+ device=p.device)
1208
+ else:
1209
+ step_tensor = state["step"]
1210
+ state_steps.append(step_tensor)
1211
+
1212
+ self._fused_adamw(
1213
+ params_with_grads,
1214
+ grads,
1215
+ moment1,
1216
+ moment2,
1217
+ max_exp_avg_sqs,
1218
+ state_steps,
1219
+ amsgrad=False,
1220
+ beta1=beta1,
1221
+ beta2=beta2,
1222
+ lr=lr,
1223
+ weight_decay=weight_decay,
1224
+ eps=eps,
1225
+ maximize=False,
1226
+ )
1227
+
1228
+ def _step_adamw(self, group):
1229
+ params = group["params"]
1230
+
1231
+ # group params by their type and placement
1232
+ placement_to_params: dict[tuple[Placement | type,
1233
+ DeviceMesh | None], list] = defaultdict(list)
1234
+ for p in params:
1235
+ match p:
1236
+ case DTensor():
1237
+ placement_to_params[tuple([p.placements,
1238
+ p.device_mesh])].append(p)
1239
+ case torch.Tensor():
1240
+ placement_to_params[tuple([torch.Tensor, None])].append(p)
1241
+
1242
+ for params in placement_to_params.values():
1243
+ self._step_adamw_params(params, group)
1244
+
1245
+ @torch.no_grad
1246
+ def step(self, closure=None, qk_logits=None):
1247
+ """Perform a single optimization step.
1248
+
1249
+ Args:
1250
+ closure (Callable, optional): A closure that reevaluates the model
1251
+ and returns the loss.
1252
+ qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
1253
+ to 1D tensors of shape (num_heads,), representing the maximum
1254
+ QK logits across all tokens, computed as
1255
+ (1 / sqrt(head_dim)) * (Q @ K^T).
1256
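+
+ Example (illustrative sketch; `model`, `inputs`, and `loss_fn` are placeholders):
+ params = get_default_muon_param_groups(model)
+ optim = Muon(params, lr=0.02)
+ loss_fn(model(inputs)).backward()
+ optim.step()
+ optim.zero_grad()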
+ """
1257
+ loss = None
1258
+ if closure is not None:
1259
+ with torch.enable_grad():
1260
+ loss = closure()
1261
+
1262
+ for group in self.param_groups:
1263
+ if group["use_muon"]:
1264
+ self._step_muon(group, qk_logits=qk_logits)
1265
+ else:
1266
+ self._step_adamw(group)
1267
+
1268
+ return loss
build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is: after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch28-cxx11-cu126-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch28-cxx11-cu126-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _optimizer_06a260a_dirty
3
+ ops = torch.ops._optimizer_06a260a_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_06a260a_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:222315672693e6d4544b1eee4772dc7be744b3794cfd6ff370a6f46d782386a1
3
+ size 1936664
build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from torch.distributed import ProcessGroup
4
+ from torch.distributed.device_mesh import DeviceMesh
5
+ from torch.distributed.tensor import DTensor
6
+ from torch.distributed.tensor.placement_types import (Placement, Shard,
7
+ _StridedShard)
8
+
9
+
10
+ def get_slices_of_dtensor(
11
+ target: DTensor | torch.Tensor,
12
+ local_rank: int,
13
+ shard_mesh: DeviceMesh,
14
+ shard_placements: tuple[Placement],
15
+ ) -> tuple[slice]:
16
+ """
17
+ Get the slice of local tensor for a given rank from a tensor.
18
+ Args:
19
+ target (DTensor | torch.Tensor): The target tensor.
20
+ rank (int): The local rank of the shard group.
21
+ shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks.
22
+ shard_placements (tuple[Placement]): The shard placements.
23
+ """
24
+
25
+ slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()]
26
+
27
+ # find the global rank of the local rank in the shard mesh
28
+ rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank]
29
+
30
+ rank_coords = (shard_mesh.mesh == rank).nonzero()
31
+
32
+ assert len(rank_coords) == 1
33
+ rank_coords = tuple(rank_coords[0].tolist())
34
+
35
+ assert len(rank_coords) == len(shard_placements)
36
+
37
+ # Caution: Assuming replicate-to-shard of the shard mesh goes with
38
+ # left-to-right sharding. This is ensured by the sorting logic of
39
+ # construct_shard_mesh function.
40
+ for i, (rank_coord,
41
+ placement) in enumerate(zip(rank_coords, shard_placements)):
42
+ assert isinstance(placement, Shard)
43
+
44
+ num_ranks = shard_mesh.mesh.shape[i]
45
+
46
+ dim = placement.dim
47
+ dim_size = (slices[dim].stop - slices[dim].start)
48
+
49
+ if dim_size % num_ranks != 0:
50
+ raise NotImplementedError(
51
+ f"Dimension size {dim_size} is not divisible "
52
+ f"by number of ranks {num_ranks} for shard "
53
+ f"placement on dim {dim}. (shape: {target.shape})")
54
+
55
+ shard_size = dim_size // num_ranks
56
+
57
+ start = slices[dim].start + rank_coord * shard_size
58
+ end = start + shard_size
59
+
60
+ assert start < end <= slices[dim].stop
61
+
62
+ slices[dim] = slice(start, end)
63
+
64
+ return tuple(slices)
65
+
66
+
67
+ _ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh,
68
+ ProcessGroup]] = dict()
69
+
70
+
71
+ def construct_shard_mesh(
72
+ placements: tuple[Placement],
73
+ mesh: DeviceMesh,
74
+ ) -> (DeviceMesh, ProcessGroup, tuple[Placement]):
75
+ """
76
+ Construct Shard Mesh and Placements for unsharding.
77
+ It removes Replicate placements and constructs a new Mesh and ProcessGroup.
78
+ """
79
+ my_rank = dist.get_rank()
80
+
81
+ assert mesh.mesh.device.type == 'cpu'
82
+
83
+ # Copy mesh to avoid modifying the original mesh
84
+ mesh = mesh.mesh.clone()
85
+
86
+ # 1. Sort placements. Replicate first, then Shard by dim ascending.
87
+
88
+ # For Shard, strided shard comes after regular shard on the same dim
89
+ # to preserve left-to-right order of replicate-to-shard.
90
+ # This is because that strided shard is using stride to represent
91
+ # more fine-grained sharding on the same dim.
92
+ # Please check the URL below for _StridedShard.
93
+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366
94
+
95
+ def placement_sort_key(
96
+ placement_with_index: tuple[float, Placement]
97
+ ) -> tuple[int, float, int]: # (dim, split factor, original index)
98
+ index, placement = placement_with_index
99
+ is_replicate = placement.is_replicate()
100
+ is_shard = placement.is_shard()
101
+ is_partial = placement.is_partial()
102
+
103
+ assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}"
104
+ assert not is_partial, "Partial placement is not supported."
105
+
106
+ if is_replicate:
107
+ return (-1.0, 0, index)
108
+ elif is_shard:
109
+ if isinstance(placement, _StridedShard):
110
+ return (placement.dim, 1 / placement.split_factor, index)
111
+ return (placement.dim, 0, index)
112
+ else:
113
+ raise TypeError(f"Unknown placement type: {type(placement)}")
114
+
115
+ placements_with_index: list[tuple[int,
116
+ Placement]] = list(enumerate(placements))
117
+ placements_with_index = sorted(placements_with_index,
118
+ key=placement_sort_key)
119
+
120
+ sorted_indices, sorted_placements = zip(*placements_with_index)
121
+
122
+ # 2. Permute mesh according to sorted placements.
123
+ sorted_mesh = mesh.permute(sorted_indices)
124
+
125
+ # 3. Collect list of shard meshes by removing replicate dims
126
+ # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)]
127
+ # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4)
128
+ num_replicates = sum(1 for p in sorted_placements if p.is_replicate())
129
+
130
+ # merge replicate dims
131
+ # shard_meshes became a list of shard meshes with a length of replicate degree
132
+ if num_replicates > 0:
133
+ sorted_mesh = sorted_mesh.flatten(
134
+ 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh
135
+ shard_meshes = list(torch.unbind(sorted_mesh, dim=0))
136
+ else:
137
+ shard_meshes = [sorted_mesh]
138
+ shard_placements = sorted_placements[num_replicates:]
139
+
140
+ # assume all shard placements are different
141
+ assert len(shard_placements) == len(set(shard_placements))
142
+
143
+ # 4. Construct ProcessGroups
144
+ # Caution: all groups should be created in the same order in all processes,
145
+ # even though each process only needs its own group.
146
+
147
+ # To use tensor as dict key, convert it to tuple
148
+ def tensor_to_tuple(t):
149
+ if isinstance(t, torch.Tensor):
150
+ t = t.tolist()
151
+ if isinstance(t, list):
152
+ return tuple(tensor_to_tuple(x) for x in t)
153
+ return t
154
+
155
+ my_shard_mesh_as_tuple = None
156
+ for shard_mesh in shard_meshes:
157
+ assert isinstance(shard_mesh, torch.Tensor)
158
+ shard_mesh_as_tuple = tensor_to_tuple(shard_mesh)
159
+
160
+ if (my_rank == shard_mesh).any().item():
161
+ assert my_shard_mesh_as_tuple is None
162
+ my_shard_mesh_as_tuple = shard_mesh_as_tuple
163
+
164
+ # update global cache
165
+ if shard_mesh_as_tuple not in _ranks_to_dist_cache:
166
+ shard_process_group = dist.new_group(shard_mesh.flatten().tolist())
167
+ _ranks_to_dist_cache[shard_mesh_as_tuple] = (
168
+ DeviceMesh(device_type="cuda", mesh=shard_mesh),
169
+ shard_process_group,
170
+ )
171
+
172
+ my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[
173
+ my_shard_mesh_as_tuple]
174
+
175
+ return my_shard_mesh, my_shard_process_group, shard_placements
build/{torch29-cxx11-cu126-x86_64-linux/optimizer β†’ torch28-cxx11-cu126-x86_64-linux}/matmul_transpose_triton.py RENAMED
File without changes
build/torch28-cxx11-cu126-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
1
+ {"python-depends":[]}
build/torch28-cxx11-cu126-x86_64-linux/muon.py ADDED
@@ -0,0 +1,1268 @@
1
+ import logging
2
+ import math
3
+ import types
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass
6
+ from typing import Any, cast
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from torch.distributed import ProcessGroup
11
+ from torch.distributed.device_mesh import DeviceMesh
12
+ from torch.distributed.tensor import DTensor, Replicate
13
+ from torch.distributed.tensor.placement_types import Placement
14
+
15
+ from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
16
+ from .matmul_transpose_triton import matmul_transpose_assign
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ COMM_DTYPE = torch.bfloat16
21
+ DEFAULT_CHUNK_SIZE_RATIO = 4
22
+
23
+
24
+ # This code snippet is a modified version adapted from the following GitHub repositories:
25
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
26
+ # Muon's Newton–Schulz iteration causes high variance in singular values
27
+ # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
28
+ @torch.no_grad()
29
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
30
+ def _zeropower_via_newtonschulz5(G, steps):
31
+ """
32
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
33
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
34
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
35
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
36
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
37
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
38
+ performance at all relative to UV^T, where USV^T = G is the SVD.
39
+ """
40
+ assert len(G.shape) == 2
41
+ assert G.dtype == COMM_DTYPE
42
+ X = G # no manual typecast
43
+
44
+ if G.size(0) > G.size(1):
45
+ X = X.T
46
+ # Ensure spectral norm is at most 1
47
+ X = X / (X.norm() + 1e-7)
48
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
49
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
50
+ # Perform the NS iterations
51
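+ # Each iteration computes X <- a*X + (b*A + c*A^2) @ X with A = X @ X^T
+ # (matmul_transpose_assign writes X @ X^T into buf1, then buf1 @ buf1^T into buf2).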
+ for a, b, c in [
52
+ (4.0848, -6.8946, 2.9270),
53
+ (3.9505, -6.3029, 2.6377),
54
+ (3.7418, -5.5913, 2.3037),
55
+ (2.8769, -3.1427, 1.2046),
56
+ (2.8366, -3.0525, 1.2012),
57
+ ]:
58
+ matmul_transpose_assign(X, buf1)
59
+ matmul_transpose_assign(buf1, buf2)
60
+ buf1.mul_(b).add_(buf2, alpha=c)
61
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
62
+
63
+ if G.size(0) > G.size(1):
64
+ X = X.T
65
+ return X
66
+
67
+
68
+ @dataclass
69
+ class _muon_state:
70
+ # TODO: use Optional
71
+ worker_rank: int
72
+ process_group: ProcessGroup
73
+ shard_mesh: DeviceMesh
74
+ shard_placements: tuple[Placement, ...]
75
+ name: str
76
+ qk_clip_state: torch.Tensor | None = None
77
+ gathered_grad: torch.Tensor | None = None
78
+ scattered_u: DTensor | None = None
79
+ computed_u: torch.Tensor | None = None
80
+ gather_event: torch.cuda.Event | None = None
81
+ compute_event: torch.cuda.Event | None = None
82
+ scatter_event: torch.cuda.Event | None = None
83
+
84
+
85
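+ # Returns the number of elements of `param` owned by `local_rank` under the given
+ # shard mesh and placements, e.g. a (1024, 512) parameter sharded on dim 0 across
+ # 4 ranks contributes 256 * 512 elements per rank.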
+ def numel_for_rank(
86
+ param: DTensor,
87
+ local_rank: int,
88
+ state: _muon_state,
89
+ ) -> int:
90
+ slices = get_slices_of_dtensor(
91
+ param,
92
+ local_rank,
93
+ state.shard_mesh,
94
+ state.shard_placements,
95
+ )
96
+
97
+ numel = 1
98
+ for s, dim in zip(slices, param.shape):
99
+ start, stop, step = s.indices(dim)
100
+ length = max(0, (stop - start + (step - 1)) // step)
101
+ numel *= length
102
+
103
+ return numel
104
+
105
+
106
+ @torch.no_grad()
107
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
108
+ """
109
+ Pre-allocate gathered_grad buffer on compute_stream
110
+ before launching all2all gather
111
+ """
112
+ with torch.cuda.stream(compute_stream):
113
+ for p in params:
114
+ state = param_to_state[id(p)]
115
+ if rank == state.worker_rank:
116
+ state.gathered_grad = torch.empty(p.shape,
117
+ dtype=COMM_DTYPE,
118
+ device="cuda")
119
+ else:
120
+ state.gathered_grad = None
121
+
122
+ alloc_event = torch.cuda.Event()
123
+ alloc_event.record(compute_stream)
124
+ return alloc_event
125
+
126
+
127
+ @torch.no_grad()
128
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
129
+ alloc_event):
130
+ """
131
+ All2all gathers shards so each owner rank reconstructs its full gradient
132
+ """
133
+ with torch.cuda.stream(comm_stream):
134
+ process_group = param_to_state[id(params[0])].process_group
135
+ num_ranks = dist.get_world_size(group=process_group)
136
+
137
+ # Construct sending buffers
138
+ per_dst = [[] for _ in range(num_ranks)]
139
+ send_counts = [0] * num_ranks
140
+
141
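+ # per_dst[d] collects this rank's local gradient shards whose owner (worker_rank) is d;
+ # send_counts[d] accumulates their total number of elements.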
+ for p in params:
142
+ state = param_to_state[id(p)]
143
+ dst = state.worker_rank
144
+ assert dst < num_ranks
145
+ shard_elems = numel_for_rank(p, rank, state)
146
+ g = p.grad
147
+ g = g.to_local().to(COMM_DTYPE).contiguous()
148
+ assert g.numel() == shard_elems
149
+ per_dst[dst].append(g.view(-1))
150
+ send_counts[dst] += shard_elems
151
+
152
+ assert any(
153
+ len(v) > 0 for v in per_dst
154
+ ), "At least one destination rank must receive a sharded tensor"
155
+ # list[list[Tensor]] -> list[Tensor]
156
+ per_dst = [t for dst in per_dst for t in dst]
157
+
158
+ send_buf = torch.cat(per_dst, dim=0)
159
+
160
+ owned_params = [
161
+ p for p in params if param_to_state[id(p)].worker_rank == rank
162
+ ]
163
+
164
+ # Compute receive sizes and allocate receiving buffers
165
+ recv_counts = [0] * num_ranks
166
+
167
+ for src in range(num_ranks):
168
+ total = 0
169
+ for p in owned_params:
170
+ state = param_to_state[id(p)]
171
+ assert state.worker_rank == rank
172
+ total += numel_for_rank(p, src, state)
173
+ recv_counts[src] = total
174
+
175
+ recv_total = sum(recv_counts)
176
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
177
+
178
+ #All2All
179
+ logger.debug(f"send_buf size: {send_buf.numel()}, "
180
+ f"recv_buf size: {recv_buf.numel()}, "
181
+ f"recv_counts: {recv_counts}, "
182
+ f"send_counts: {send_counts}, "
183
+ f"process_group: {str(process_group)}")
184
+ dist.all_to_all_single(
185
+ recv_buf,
186
+ send_buf,
187
+ output_split_sizes=recv_counts,
188
+ input_split_sizes=send_counts,
189
+ group=process_group,
190
+ )
191
+
192
+ # Reconstructs gathered grad from the received buffer
193
+ #
194
+ # recv_buf (num ranks = 3)
195
+ #
196
+ # From rank 0 From rank 1 From rank 2
197
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
198
+ #
199
+ # Outer loop:
200
+ # rank 0 -> rank 1 -> rank2
201
+ #
202
+ # Inner loop:
203
+ # p1_n -> p2_n -> p3_n
204
+
205
+ comm_stream.wait_event(alloc_event)
206
+
207
+ off = 0
208
+ for src in range(num_ranks):
209
+ if recv_counts[src] == 0:
210
+ continue
211
+
212
+ block = recv_counts[src]
213
+ inner_off = 0
214
+ for p in owned_params:
215
+ state = param_to_state[id(p)]
216
+ assert state.worker_rank == rank
217
+
218
+ # get the slice of the full dtensor corresponding to rank src.
219
+ slices = get_slices_of_dtensor(state.gathered_grad, src,
220
+ state.shard_mesh,
221
+ state.shard_placements)
222
+
223
+ dst = state.gathered_grad[slices]
224
+ assert dst._base is state.gathered_grad
225
+
226
+ n = dst.numel()
227
+ assert n > 0
228
+
229
+ sg = recv_buf.narrow(0, off + inner_off, n)
230
+ sg = sg.reshape_as(dst)
231
+ dst.copy_(sg)
232
+
233
+ inner_off += n
234
+ off += block
235
+
236
+ for p in params:
237
+ state = param_to_state[id(p)]
238
+ if state.worker_rank == rank:
239
+ state.gather_event = torch.cuda.Event()
240
+ state.gather_event.record(comm_stream)
241
+ else:
242
+ state.gathered_grad = None
243
+ state.gather_event = None
244
+ if none_grad:
245
+ p.grad = None
246
+
247
+
248
+ @torch.no_grad()
249
+ def _compute_u(p, state, steps, rank, compute_stream):
250
+ """
251
+ On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
252
+ """
253
+ with torch.cuda.stream(compute_stream):
254
+ if rank == state.worker_rank:
255
+ if state.gather_event is None:
256
+ raise RuntimeError("Gather event must be set before compute.")
257
+ compute_stream.wait_event(state.gather_event)
258
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
259
+ state.gathered_grad = None
260
+ state.computed_u = u
261
+ state.compute_event = torch.cuda.Event()
262
+ state.compute_event.record()
263
+ else:
264
+ state.computed_u = None
265
+ state.compute_event = None
266
+
267
+
268
+ @torch.no_grad()
269
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
270
+ """
271
+ Pre-allocate scattered_u buffer on compute_stream
272
+ before launching all2all scatter
273
+ """
274
+ with torch.cuda.stream(compute_stream):
275
+ for p in params:
276
+ state = param_to_state[id(p)]
277
+ state.scattered_u = torch.empty_like(p.to_local(),
278
+ dtype=COMM_DTYPE)
279
+
280
+ alloc_event = torch.cuda.Event()
281
+ alloc_event.record(compute_stream)
282
+ return alloc_event
283
+
284
+
285
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
286
+ """
287
+ All2all scatters the computed updates (u) from owner ranks back to all ranks
288
+ """
289
+ with torch.cuda.stream(comm_stream):
290
+ process_group = param_to_state[id(params[0])].process_group
291
+ num_ranks = dist.get_world_size(group=process_group)
292
+ owned_params = [
293
+ p for p in params if param_to_state[id(p)].worker_rank == rank
294
+ ]
295
+
296
+ # Construct sending buffer
297
+ per_dst = [[] for _ in range(num_ranks)]
298
+ send_counts = [0] * num_ranks
299
+
300
+ if owned_params:
301
+ for p in owned_params:
302
+ state = param_to_state[id(p)]
303
+ if state.compute_event is None:
304
+ raise RuntimeError(
305
+ "Compute event must be set before scatter.")
306
+ comm_stream.wait_event(state.compute_event)
307
+ state.gathered_grad = None
308
+
309
+ assert state.computed_u is not None
310
+
311
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous()
312
+
313
+ offset = 0
314
+ for dst in range(num_ranks):
315
+ # get the slice of the full tensor corresponding to rank dst.
316
+ slices = get_slices_of_dtensor(u_full, dst,
317
+ state.shard_mesh,
318
+ state.shard_placements)
319
+ su = u_full[slices].flatten()
320
+
321
+ n = su.numel()
322
+ assert n > 0
323
+
324
+ per_dst[dst].append(su)
325
+ send_counts[dst] += n
326
+ offset += n
327
+
328
+ assert offset == u_full.numel()
329
+
330
+ lengths = [len(v) for v in per_dst]
331
+ if all(l > 0 for l in lengths):
332
+ assert all(
333
+ l == lengths[0] for l in lengths
334
+ ), "All destination ranks must have the same number of sharded tensors"
335
+ # list[list[Tensor]] -> list[Tensor]
336
+ per_dst = [t for dst in per_dst for t in dst]
337
+ send_buf = torch.cat(per_dst, dim=0)
338
+ else:
339
+ # all_to_all requires participation from all ranks
340
+ # Even non-owner ranks must join the collective call
341
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
342
+
343
+ # Compute receive sizes and allocate receiving buffers
344
+ recv_counts = [0] * num_ranks
345
+
346
+ for src in range(num_ranks):
347
+ total = 0
348
+ for p in params:
349
+ state = param_to_state[id(p)]
350
+ if state.worker_rank != src:
351
+ continue
352
+ total += numel_for_rank(p, rank, state)
353
+ recv_counts[src] = total
354
+
355
+ recv_total = sum(recv_counts)
356
+ assert recv_total > 0
357
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
358
+
359
+ #All2All
360
+ dist.all_to_all_single(
361
+ recv_buf,
362
+ send_buf,
363
+ output_split_sizes=recv_counts,
364
+ input_split_sizes=send_counts,
365
+ group=process_group,
366
+ )
367
+
368
+ # Copy to pre-allocated scattered_u buffer from the received buffer
369
+ #
370
+ # recv_buf (num ranks = 3, local_rank = 0)
371
+ #
372
+ # From rank 0 From rank 1 From rank 2
373
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
374
+ #
375
+ # Outer loop:
376
+ # rank 0 -> rank 1 -> rank2
377
+ #
378
+ # Inner loop:
379
+ # src(0) : p1_0 -> p2_0 -> p3_0
380
+ # src(1) : p4_0
381
+ # src(2) : p5_0 -> p6_0
382
+
383
+ comm_stream.wait_event(alloc_event)
384
+
385
+ off = 0
386
+ for src in range(num_ranks):
387
+ block = recv_counts[src]
388
+ if block == 0:
389
+ continue
390
+
391
+ inner_off = 0
392
+ for p in params:
393
+ state = param_to_state[id(p)]
394
+ if state.worker_rank != src:
395
+ continue
396
+ n = numel_for_rank(p, rank, state)
397
+ assert n > 0
398
+
399
+ flat_local = recv_buf.narrow(0, off + inner_off,
400
+ n).view_as(p.to_local())
401
+ state.scattered_u.copy_(flat_local)
402
+
403
+ state.scatter_event = torch.cuda.Event()
404
+ state.scatter_event.record(comm_stream)
405
+ inner_off += n
406
+
407
+ assert inner_off == block
408
+ off += block
409
+
410
+
411
+ def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
412
+ compute_stream):
413
+ """
414
+ Update sharded parameter p with the scattered_u.
415
+ Only worker_rank frees computed_u.
416
+ """
417
+ with torch.cuda.stream(compute_stream):
418
+ if state.scatter_event is None:
419
+ raise RuntimeError("Scatter event must be set before update")
420
+ compute_stream.wait_event(state.scatter_event)
421
+ u_dtensor = DTensor.from_local(
422
+ state.scattered_u,
423
+ placements=p.placements,
424
+ device_mesh=p.device_mesh,
425
+ )
426
+
427
+ state.scattered_u = u_dtensor
428
+
429
+ if rank == state.worker_rank:
430
+ # Free computed_u
431
+ state.computed_u = None
432
+
433
+ Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
434
+ state.scattered_u = None
435
+ u_dtensor = None
436
+
437
+ scales_full = Muon._compute_scales(
438
+ p,
439
+ state.qk_clip_state) if state.qk_clip_state is not None else None
440
+ if scales_full is not None:
441
+ # Have to slice scales_full along dim 0
442
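+ # scales_full has one entry per attention head; `ratio` below equals the number of
+ # weight rows per head and maps the local row slice onto head indices.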
+ weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
443
+ state.shard_placements)
444
+ ratio = p.shape[0] // scales_full.shape[0]
445
+ scales_slice = slice(
446
+ None if weight_slices[0].start is None else
447
+ weight_slices[0].start // ratio,
448
+ None if weight_slices[0].stop is None else
449
+ weight_slices[0].stop // ratio,
450
+ None,
451
+ )
452
+
453
+ scales_local = scales_full[scales_slice]
454
+ scales_local = DTensor.from_local(
455
+ scales_local,
456
+ placements=p.placements,
457
+ device_mesh=p.device_mesh,
458
+ )
459
+ Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)
460
+
461
+
462
+ def default_is_muon(name, x):
463
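+ # Muon only handles matrices (ndim >= 2); embedding and output-head weights are
+ # skipped here and end up in the AdamW parameter group instead.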
+ skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
464
+ return x.ndim >= 2 and not any(key in name for key in skip_keys)
465
+
466
+
467
+ def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
468
+ muon_params, muon_names = [], []
469
+ non_muon_params = []
470
+
471
+ for n, p in model.named_parameters():
472
+ if not p.requires_grad:
473
+ continue
474
+ if is_muon_func(n, p):
475
+ muon_params.append(p)
476
+ muon_names.append(n)
477
+ else:
478
+ non_muon_params.append(p)
479
+
480
+ return [
481
+ {
482
+ "params": muon_params,
483
+ "names": muon_names,
484
+ "use_muon": True,
485
+ },
486
+ {
487
+ "params": non_muon_params,
488
+ "use_muon": False,
489
+ },
490
+ ]
491
+
492
+
493
+ def parse_qk_layer(name: str) -> tuple[str | None, int]:
494
+ """
495
+ Parse a parameter name to check if it is a query/key projection layer
496
+ ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index).
497
+
498
+ Returns:
499
+ (kind, layer_idx) or (None, -1) if not matched.
500
+
501
+ Example:
502
+ 'model.3.attn.wq.weight' -> ('wq', 3)
503
+ 'model.5.attn.wk.weight' -> ('wk', 5)
504
+ 'model.2.attn.q_proj.weight' -> ('q_proj', 2)
505
+ 'model.7.attn.k_proj.weight' -> ('k_proj', 7)
506
+ 'model.4.attn.v_proj.weight' -> (None, -1)
507
+ """
508
+ parts = name.split('.')
509
+ if len(parts) < 3:
510
+ return None, -1
511
+
512
+ kind = parts[-2]
513
+
514
+ layer_idx = -1
515
+ for part in reversed(parts):
516
+ if part.isdigit():
517
+ layer_idx = int(part)
518
+ break
519
+
520
+ if kind in ('wq', 'wk', 'q_proj', 'k_proj'):
521
+ return kind, layer_idx
522
+
523
+ return None, -1
524
+
525
+
526
+ @dataclass
527
+ class QKClipInfo:
528
+ """Per-parameter dynamic info computed from config + runtime logits."""
529
+ kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None
530
+ indices: list[int] # which heads to consider for clipping
531
+ head_dim: int # from config
532
+ threshold: float # from config
533
+ logit: torch.Tensor | None
534
+
535
+
536
+ class Muon(torch.optim.Optimizer):
537
+ """
538
+ Muon - MomentUm Orthogonalized by Newton-schulz
539
+
540
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
541
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
542
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
543
+ the advantage that it can be stably run in bfloat16 on the GPU.
544
+
545
+ Some warnings:
546
+ - We believe this optimizer is unlikely to work well for training with small batch size.
547
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
548
+
549
+ Arguments:
550
+ model: The model to be optimized by Muon.
551
+ is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon.
552
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
553
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
554
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
555
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
556
+ weight_decay: The weight decay for Muon and AdamW.
557
+ Parameters that are {0, 1}-D, or detected as being the embed or lm_head, will be optimized by AdamW as well.
558
+ adamw_lr: The learning rate for the internal AdamW.
559
+ adamw_betas: The betas for the internal AdamW.
560
+ adamw_eps: The epsilon for the internal AdamW.
561
+ none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
562
+ debug: Whether to print debug information.
563
+ clip_config : Configuration for QK clipping. Expected keys:
564
+ - "q_indices" (list[int]): Indices of query heads to consider.
565
+ - "k_indices" (list[int]): Indices of key heads to consider.
566
+ - "head_dim" (int): Dimensionality of each attention head.
567
+ - "threshold" (float): Threshold value; heads whose QK logits exceed
568
+ this value will be scaled down.
569
+ Default is:
570
+ {
571
+ "q_indices": [],
572
+ "k_indices": [],
573
+ "head_dim": 128,
574
+ "threshold": 100
575
+ }
576
+ warmup_step : How many all2all-gather and compute chunks are launched in advance
577
+ before the corresponding all2all scatter steps begin.
578
+ A higher warmup_step increases memory usage but can improve
579
+ performance by overlapping communication.
580
+ Parallel muon only.
581
+ chunk_size : Batch size of parameters to process in each
582
+ all2all gather/compute/scatter step.
583
+ Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
584
+ use_distributed_muon: Use distributed muon by Liu et al. (2024).
585
+ For testing purposes only.
586
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon
587
+ """
588
+
589
+ def __init__(self,
590
+ params,
591
+ lr=1e-3,
592
+ momentum=0.95,
593
+ nesterov=True,
594
+ ns_steps=5,
595
+ weight_decay=0.1,
596
+ adamw_betas=(0.9, 0.95),
597
+ adamw_eps=1e-8,
598
+ none_grad=True,
599
+ debug=False,
600
+ clip_config={
601
+ "q_indices": [],
602
+ "k_indices": [],
603
+ "head_dim": 128,
604
+ "threshold": 100
605
+ },
606
+ warmup_step=5,
607
+ chunk_size=-1,
608
+ use_distributed_muon=False,
609
+ small_param_numel_threshold=65536):
610
+ defaults = dict(
611
+ lr=lr,
612
+ weight_decay=weight_decay,
613
+ momentum=momentum,
614
+ nesterov=nesterov,
615
+ ns_steps=ns_steps,
616
+ adamw_betas=adamw_betas,
617
+ adamw_eps=adamw_eps,
618
+ none_grad=none_grad,
619
+ use_muon=True,
620
+ )
621
+ error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
622
+ instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"
623
+
624
+ if isinstance(params, types.GeneratorType):
625
+ raise ValueError(error_message.format(idx=0) + instruction_code)
626
+ for _idx, param_group in enumerate(params):
627
+ if param_group.get("use_muon", None) is None:
628
+ raise ValueError(
629
+ error_message.format(idx=_idx) + instruction_code)
630
+
631
+ super().__init__(params, defaults)
632
+
633
+ self.rank = None
634
+
635
+ self.comm_stream = torch.cuda.Stream()
636
+ self.compute_stream = torch.cuda.Stream()
637
+ self.debug = debug
638
+ self.clip_config = clip_config
639
+ self.warmup_step = warmup_step
640
+ self.chunk_size = chunk_size
641
+ self.use_distributed_muon = use_distributed_muon
642
+ self.small_param_numel_threshold = small_param_numel_threshold
643
+
644
+ def _calc_flops(self, G, steps):
645
+ assert len(G.shape) == 2
646
+ M, N = G.shape
647
+ if M > N:
648
+ M, N = N, M
649
+
650
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
651
+
652
+ def adjust_lr_for_muon(self, lr, param_shape):
653
+ A, B = param_shape[:2]
654
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
655
+ # as described in the paper
656
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
657
+ adjusted_lr = lr * adjusted_ratio
658
+ return adjusted_lr
659
+
660
+ def set_rank_once(self, rank):
661
+ if self.rank is None:
662
+ self.rank = rank
663
+ else:
664
+ assert self.rank == rank
665
+
666
+ def get_shard_mesh(self, p):
667
+ """
668
+ Get the shard mesh for a parameter p on the given rank.
669
+ """
670
+ assert isinstance(
671
+ p, DTensor), "Parallel Muon only supports DTensor parameters."
672
+
673
+ shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
674
+ p.placements, p.device_mesh)
675
+
676
+ # set rank with the local rank in the shard process group
677
+ self.set_rank_once(dist.get_rank(group=shard_pg))
678
+
679
+ return shard_mesh, shard_pg, shard_placements
680
+
681
+ def init_state_and_assign_params(self, names, params, group, qk_logits):
682
+ param_to_state = {}
683
+ param_to_flops = {}
684
+
685
+ total_flops = 0
686
+ for p in params:
687
+ g = p.grad
688
+ if g is None:
689
+ continue
690
+ assert g.ndim == 2, "Muon only supports 2D parameters."
691
+
692
+ flops = self._calc_flops(g, group["ns_steps"])
693
+ param_to_flops[id(p)] = flops
694
+ total_flops += flops
695
+
696
+ if self.debug:
697
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs",
698
+ flush=True)
699
+
700
+ paired = list(zip(names, params))
701
+
702
+ paired_sorted = sorted(paired,
703
+ key=lambda x: param_to_flops[id(x[1])],
704
+ reverse=True)
705
+
706
+ names_sorted, params_sorted = zip(*paired_sorted)
707
+ ordered_names = list(names_sorted)
708
+ ordered_params = list(params_sorted)
709
+
710
+ round_robin = 0
711
+ mesh = ordered_params[0].device_mesh
712
+ placements = ordered_params[0].placements
713
+
714
+ shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
715
+ ordered_params[0])
716
+ shard_mesh_flattened = shard_mesh.mesh.flatten()
717
+ num_ranks = dist.get_world_size(group=shard_pg)
718
+
719
+ for n, p in zip(ordered_names, ordered_params):
720
+ if mesh != p.device_mesh:
721
+ raise ValueError("All parameters must be on the same mesh.")
722
+ if placements != p.placements:
723
+ raise ValueError("All parameters must have same placements.")
724
+
725
+ worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
726
+ round_robin = (round_robin + 1) % len(shard_mesh_flattened)
727
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
728
+
729
+ param_to_state[id(p)] = _muon_state(
730
+ worker_rank=worker_rank,
731
+ process_group=shard_pg,
732
+ shard_mesh=shard_mesh,
733
+ shard_placements=shard_placements,
734
+ name=n,
735
+ qk_clip_state=qk_clip_state,
736
+ )
737
+
738
+ return param_to_state, ordered_params
739
+
740
+ def base(self, names, params, group, lr, weight_decay, momentum,
741
+ qk_logits):
742
+ # generate weight updates in distributed fashion
743
+ for n, p in zip(names, params):
744
+ g = p.grad
745
+ if g is None:
746
+ continue
747
+ if g.ndim > 2:
748
+ g = g.view(g.size(0), -1)
749
+ assert g is not None
750
+
751
+ g = self._update_g(p, g, group, momentum)
752
+
753
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
754
+ steps=group["ns_steps"])
755
+
756
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
757
+ Muon._update_p(p, u, lr, adjusted_lr, weight_decay)
758
+
759
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
760
+
761
+ scales_full = self._compute_scales(
762
+ p, qk_clip_state) if qk_clip_state is not None else None
763
+ if scales_full is not None:
764
+ Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)
765
+
766
+ def distributed_muon(
767
+ self,
768
+ names: list[str],
769
+ params: list[torch.nn.Parameter],
770
+ group: dict[str, Any],
771
+ lr: float,
772
+ weight_decay: float,
773
+ momentum: float,
774
+ qk_logits: list[torch.Tensor | DTensor] | None,
775
+ ):
776
+ """ Implementation of Distributed Muon by Liu et al. """
777
+
778
+ for n, p in zip(names, params):
779
+ g = p.grad
780
+ if g is None:
781
+ continue
782
+ if g.ndim > 2:
783
+ g = g.view(g.size(0), -1)
784
+ assert g is not None
785
+
786
+ g = self._update_g(p, g, group, momentum)
787
+
788
+ # Gather G
789
+ if isinstance(p.data, DTensor):
790
+ g_full = g.full_tensor()
791
+ p_full = p.data.full_tensor()
792
+ else:
793
+ g_full = g
794
+ p_full = p
795
+
796
+ u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
797
+ steps=group["ns_steps"])
798
+
799
+ adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
800
+ Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)
801
+
802
+ qk_clip_state = self.get_qk_clip_info(n, qk_logits)
803
+
804
+ scales_full = self._compute_scales(
805
+ p_full, qk_clip_state) if qk_clip_state is not None else None
806
+
807
+ if scales_full is not None:
808
+ Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)
809
+
810
+ if isinstance(p.data, DTensor):
811
+ ndims = len(p.device_mesh.mesh.shape)
812
+ p_replicate = DTensor.from_local(
813
+ p_full,
814
+ device_mesh=p.device_mesh,
815
+ placements=[Replicate() for _ in range(ndims)],
816
+ )
817
+
818
+ p_sharded = p_replicate.redistribute(
819
+ device_mesh=p.device_mesh,
820
+ placements=p.placements,
821
+ )
822
+
823
+ p.copy_(p_sharded)
824
+
825
+ def _update_g(self, p, g, group, momentum):
826
+ # Momentum update: buf <- g + momentum * buf. With Nesterov, the returned
+ # update is g + momentum * buf; otherwise the momentum buffer itself is used.
827
+ state = self.state[p]
828
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
829
+ torch.add(g, buf, alpha=momentum, out=buf)
830
+ if group["nesterov"]:
831
+ g.add_(buf, alpha=momentum)
832
+ return g
833
+ return buf
834
+
835
+ @staticmethod
836
+ def _update_p(p, u, lr, adjusted_lr, weight_decay):
837
+ if isinstance(p, torch.nn.Parameter):
838
+ # apply weight decay
839
+ p.data.mul_(1 - lr * weight_decay)
840
+ # apply update
841
+ p.data.add_(u, alpha=-adjusted_lr)
842
+ else:
843
+ p.mul_(1 - lr * weight_decay)
844
+ p.add_(u, alpha=-adjusted_lr)
845
+
846
+ def get_qk_clip_info(self, n, qk_logits):
847
+ if self.clip_config is None:
848
+ return None
849
+
850
+ head_dim = self.clip_config.get('head_dim')
851
+ threshold = self.clip_config.get('threshold')
852
+ kind, layer_idx = parse_qk_layer(n)
853
+
854
+ logit, indices = None, []
855
+ if qk_logits is not None and kind is not None:
856
+ logit = qk_logits[layer_idx]
857
+ indices_key = 'q_indices' if 'q' in kind else 'k_indices'
858
+ indices = self.clip_config.get(indices_key, []) or []
859
+
860
+ if isinstance(logit, DTensor):
861
+ # In TP settings, qk_logits may be DTensor
862
+ # We convert it to full tensor here for simplicity
863
+ logit = logit.full_tensor()
864
+
865
+ return QKClipInfo(
866
+ kind=kind,
867
+ indices=indices,
868
+ head_dim=head_dim,
869
+ threshold=threshold,
870
+ logit=logit,
871
+ )
872
+
873
+ @staticmethod
874
+ def _compute_scales(p, qk_clip_state):
875
+ kind = qk_clip_state.kind
876
+ indices = qk_clip_state.indices
877
+ head_dim = qk_clip_state.head_dim
878
+ threshold = qk_clip_state.threshold
879
+ logit = qk_clip_state.logit
880
+
881
+ H_global = p.shape[0] // head_dim
882
+ scales_full = torch.ones(H_global, device=p.data.device)
883
+ scaling = 0
884
+
885
+ for logit_idx, head_idx in enumerate(indices):
886
+ v_ele = float(logit[logit_idx])
887
+ if v_ele > threshold:
888
+ new_scale = math.sqrt(threshold / v_ele)
889
+ if new_scale < scales_full[head_idx]:
890
+ scales_full[head_idx] = new_scale
891
+ logger.info(
892
+ f"[{kind}] Head {head_idx} exceeded threshold "
893
+ f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
894
+ )
895
+ scaling += 1
896
+
897
+ return scales_full if scaling > 0 else None
898
+
899
+ @staticmethod
900
+ def _qk_clip(p, scales, head_dim):
901
+ if isinstance(p, torch.nn.Parameter):
902
+ W = p.data.view(-1, head_dim, p.data.shape[1])
903
+ W.mul_(scales.view(-1, 1, 1))
904
+ else:
905
+ W = p.view(-1, head_dim, p.shape[1])
906
+ W.mul_(scales.view(-1, 1, 1))
907
+
908
+ def parallel(self, names, params, group, lr, weight_decay, momentum,
909
+ qk_logits):
910
+ """
911
+ Perform a parallel optimization step using Muon.
912
+ """
913
+
914
+ for p in params:
915
+ g = p.grad
916
+ if g is None:
917
+ continue
918
+ if g.ndim > 2:
919
+ g = g.view(g.size(0), -1)
920
+
921
+ # Update g in the local rank
922
+ g = self._update_g(
923
+ p,
924
+ g,
925
+ group,
926
+ momentum=momentum,
927
+ )
928
+ p.grad = g
929
+
930
+ param_to_state, ordered_params = self.init_state_and_assign_params(
931
+ names, params, group, qk_logits)
932
+
933
+ assert self.rank is not None
934
+
935
+ def enqueue_all2all_gather(start_idx, chunk_size):
936
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
937
+ if target_params:
938
+ alloc_event = _alloc_gathered_grad(target_params,
939
+ param_to_state, self.rank,
940
+ self.compute_stream)
941
+ _all2all_gather(target_params, param_to_state, self.rank,
942
+ self.comm_stream, group["none_grad"],
943
+ alloc_event)
944
+
945
+ def enqueue_computes(start_idx, chunk_size):
946
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
947
+ state = param_to_state[id(p)]
948
+ _compute_u(p, state, group["ns_steps"], self.rank,
949
+ self.compute_stream)
950
+
951
+ def enqueue_all2all_scatter(start_idx, chunk_size):
952
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
953
+ if target_params:
954
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
955
+ self.rank,
956
+ self.compute_stream)
957
+ _all2all_scatter(target_params, param_to_state, self.rank,
958
+ self.comm_stream, alloc_event)
959
+
960
+ def enqueue_update_param(start_idx, chunk_size):
961
+ for p in ordered_params[start_idx:start_idx + chunk_size]:
962
+ state = param_to_state[id(p)]
963
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
964
+ _update_param(p, state, lr, adjusted_lr, weight_decay,
965
+ self.rank, self.compute_stream)
966
+
967
+ if self.chunk_size == -1:
968
+ shard_ranks = dist.get_world_size(param_to_state[id(
969
+ params[0])].process_group)
970
+ chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
971
+ elif self.chunk_size > 0:
972
+ chunk_size = self.chunk_size
973
+ else:
974
+ raise ValueError("chunk_size must be -1 or a positive integer.")
975
+
976
+ # Make comm_stream wait for the gradient (momentum) updates queued on the current stream
977
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
978
+
979
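+ # Software pipeline over parameter chunks: the warmup loop below enqueues
+ # all2all gathers and Newton-Schulz computes for the first `warmup_step` chunks,
+ # then the steady-state loop interleaves scatter/update for chunk i with
+ # gather/compute for chunk i + warmup_step, overlapping communication on
+ # comm_stream with computation on compute_stream.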
+ warmup_step = self.warmup_step
980
+ for i in range(0, warmup_step):
981
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
982
+ enqueue_computes(i * chunk_size, chunk_size)
983
+
984
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
985
+ enqueue_all2all_scatter(i, chunk_size)
986
+ enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size)
987
+ enqueue_update_param(i, chunk_size)
988
+ enqueue_computes(i + warmup_step * chunk_size, chunk_size)
989
+
990
+ # Wait for the last _update_param on compute_stream to finish
991
+ torch.cuda.current_stream().wait_stream(self.compute_stream)
992
+
993
+ @staticmethod
994
+ def _fused_adamw(
995
+ params: list[torch.Tensor],
996
+ grads: list[torch.Tensor],
997
+ exp_avgs: list[torch.Tensor],
998
+ exp_avg_sqs: list[torch.Tensor],
999
+ max_exp_avg_sqs: list[torch.Tensor],
1000
+ state_steps: list[torch.Tensor],
1001
+ amsgrad: bool,
1002
+ beta1: float,
1003
+ beta2: float,
1004
+ lr: float | torch.Tensor,
1005
+ weight_decay: float,
1006
+ eps: float,
1007
+ maximize: bool,
1008
+ ) -> None:
1009
+ if not params:
1010
+ return
1011
+
1012
+ # We only shuffle the lr around when it is a Tensor on a non-CPU device; otherwise, we prefer
1013
+ # treating it as a scalar.
1014
+ lr_dict: dict[torch.device, torch.Tensor] | None = ({
1015
+ lr.device: lr
1016
+ } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
1017
+ None)
1018
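+ # Group tensors by device and dtype so the fused kernel below is launched once per group.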
+ grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
1019
+ [
1020
+ params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
1021
+ state_steps
1022
+ ] # type: ignore[list-item]
1023
+ )
1024
+ for (device, _), (
1025
+ (
1026
+ device_params_,
1027
+ device_grads_,
1028
+ device_exp_avgs_,
1029
+ device_exp_avg_sqs_,
1030
+ device_max_exp_avg_sqs,
1031
+ device_state_steps_,
1032
+ ),
1033
+ _,
1034
+ ) in grouped_tensors.items():
1035
+ device_params = cast(list[torch.Tensor], device_params_)
1036
+ device_grads = cast(list[torch.Tensor], device_grads_)
1037
+ device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
1038
+ device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
1039
+ device_state_steps = cast(list[torch.Tensor], device_state_steps_)
1040
+
1041
+ if lr_dict is not None and device not in lr_dict:
1042
+ lr_dict[device] = lr.to(
1043
+ device=device,
1044
+ non_blocking=True) # type: ignore[union-attr]
1045
+ lr = lr_dict[device]
1046
+ torch._foreach_add_(device_state_steps, 1)
1047
+ func = torch._fused_adamw_
1048
+ func(
1049
+ device_params,
1050
+ device_grads,
1051
+ device_exp_avgs,
1052
+ device_exp_avg_sqs,
1053
+ device_max_exp_avg_sqs, # type: ignore[arg-type]
1054
+ device_state_steps,
1055
+ amsgrad=amsgrad,
1056
+ lr=lr, # type: ignore[arg-type]
1057
+ beta1=beta1,
1058
+ beta2=beta2,
1059
+ weight_decay=weight_decay,
1060
+ eps=eps,
1061
+ maximize=maximize,
1062
+ )
1063
+
1064
+ def _step_muon(self, group, qk_logits=None):
1065
+ params = group["params"]
1066
+ lr = group["lr"]
1067
+ weight_decay = group["weight_decay"]
1068
+ momentum = group["momentum"]
1069
+ names = group["names"]
1070
+
1071
+ param_dtensors = []
1072
+ name_dtensors = []
1073
+
1074
+ param_tensors = []
1075
+ name_tensors = []
1076
+
1077
+ param_dtensors_small = []
1078
+ name_dtensors_small = []
1079
+
1080
+ if self.use_distributed_muon:
1081
+ self.distributed_muon(names=names,
1082
+ params=params,
1083
+ group=group,
1084
+ lr=lr,
1085
+ weight_decay=weight_decay,
1086
+ momentum=momentum,
1087
+ qk_logits=qk_logits)
1088
+ return
1089
+
1090
+ # For simplicity, we use distributed Muon for small parameters
1091
+ # whose number of elements is below a threshold.
1092
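+ # Fully replicated DTensors and plain tensors take the single-device `base` path;
+ # sharded DTensors above the threshold are handled by `parallel` (parallel Muon).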
+ for n, p in zip(names, params):
1093
+ if p is None or p.grad is None:
1094
+ continue
1095
+ if isinstance(p.data, DTensor):
1096
+ if all(
1097
+ isinstance(placement, Replicate)
1098
+ for placement in p.placements):
1099
+ param_tensors.append(p)
1100
+ name_tensors.append(n)
1101
+ elif p.data.numel() <= self.small_param_numel_threshold:
1102
+ param_dtensors_small.append(p)
1103
+ name_dtensors_small.append(n)
1104
+ else:
1105
+ param_dtensors.append(p)
1106
+ name_dtensors.append(n)
1107
+ elif isinstance(p.data, torch.Tensor):
1108
+ param_tensors.append(p)
1109
+ name_tensors.append(n)
1110
+ else:
1111
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
1112
+
1113
+ logger.debug(
1114
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, "
1115
+ f"{len(param_dtensors_small)} Small DTensors")
1116
+
1117
+ def group_dtensors(dtensors, names):
1118
+ # To support different placements, we group parameters by placements
1119
+ # and run parallel Muon on each group.
1120
+
1121
+ placement_to_params = defaultdict(lambda: ([], []))
1122
+ # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]]
1123
+
1124
+ assert len(dtensors) == len(names)
1125
+ for p, n in zip(dtensors, names):
1126
+ placement_to_params[tuple([p.placements,
1127
+ p.device_mesh])][0].append(n)
1128
+ placement_to_params[tuple([p.placements,
1129
+ p.device_mesh])][1].append(p)
1130
+ return placement_to_params
1131
+
1132
+ if len(param_dtensors_small) > 0:
1133
+ if not dist.is_initialized():
1134
+ raise RuntimeError(
1135
+ "Parallel Muon requires torch.distributed to be initialized."
1136
+ )
1137
+
1138
+ self.distributed_muon(
1139
+ params=param_dtensors_small,
1140
+ names=name_dtensors_small,
1141
+ group=group,
1142
+ lr=lr,
1143
+ weight_decay=weight_decay,
1144
+ momentum=momentum,
1145
+ qk_logits=qk_logits,
1146
+ )
1147
+
1148
+ if len(param_dtensors) > 0:
1149
+ if not dist.is_initialized():
1150
+ raise RuntimeError(
1151
+ "Parallel Muon requires torch.distributed to be initialized."
1152
+ )
1153
+
1154
+ dtensor_group = group_dtensors(param_dtensors, name_dtensors)
1155
+ for _, (names, params) in dtensor_group.items():
1156
+ self.parallel(
1157
+ names,
1158
+ params,
1159
+ group,
1160
+ lr=lr,
1161
+ weight_decay=weight_decay,
1162
+ momentum=momentum,
1163
+ qk_logits=qk_logits,
1164
+ )
1165
+
1166
+ if len(param_tensors) > 0:
1167
+ self.base(
1168
+ name_tensors,
1169
+ param_tensors,
1170
+ group,
1171
+ lr=lr,
1172
+ weight_decay=weight_decay,
1173
+ momentum=momentum,
1174
+ qk_logits=qk_logits,
1175
+ )
1176
+
1177
+     def _step_adamw_params(self, params, group):
+         params_with_grads = []
+         grads = []
+         moment1 = []
+         moment2 = []
+         max_exp_avg_sqs = []
+         state_steps = []
+         lr = group["lr"]
+         beta1, beta2 = group["adamw_betas"]
+         eps = group["adamw_eps"]
+         weight_decay = group["weight_decay"]
+
+         for p in params:
+             g = p.grad
+             if g is None:
+                 continue
+             state = self.state[p]
+             params_with_grads.append(p)
+             grads.append(g)
+             if "step" not in state:
+                 state["step"] = (torch.zeros((),
+                                              dtype=torch.float32,
+                                              device=p.device))
+                 state["moment1"] = torch.zeros_like(g)
+                 state["moment2"] = torch.zeros_like(g)
+             moment1.append(state["moment1"])
+             moment2.append(state["moment2"])
+             if not isinstance(state["step"], torch.Tensor):
+                 step_tensor = torch.tensor(state["step"],
+                                            dtype=torch.float32,
+                                            device=p.device)
+             else:
+                 step_tensor = state["step"]
+             state_steps.append(step_tensor)
+
+         self._fused_adamw(
+             params_with_grads,
+             grads,
+             moment1,
+             moment2,
+             max_exp_avg_sqs,
+             state_steps,
+             amsgrad=False,
+             beta1=beta1,
+             beta2=beta2,
+             lr=lr,
+             weight_decay=weight_decay,
+             eps=eps,
+             maximize=False,
+         )
+
+     def _step_adamw(self, group):
+         params = group["params"]
+
+         # Group params by their type and placement.
+         placement_to_params: dict[tuple[Placement | type, DeviceMesh | None],
+                                   list[torch.Tensor]] = defaultdict(list)
+         for p in params:
+             match p:
+                 case DTensor():
+                     placement_to_params[tuple([p.placements,
+                                                p.device_mesh])].append(p)
+                 case torch.Tensor():
+                     placement_to_params[tuple([torch.Tensor, None])].append(p)
+
+         for params in placement_to_params.values():
+             self._step_adamw_params(params, group)
+
+     @torch.no_grad
+     def step(self, closure=None, qk_logits=None):
+         """Perform a single optimization step.
+
+         Args:
+             closure (Callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+             qk_logits (dict[int, Tensor], optional): A dictionary mapping layer
+                 indices to 1D tensors of shape (num_heads,), representing the
+                 maximum QK logits across all tokens, computed as
+                 (1 / sqrt(head_dim)) * (Q @ K^T).
+         """
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             if group["use_muon"]:
+                 self._step_muon(group, qk_logits=qk_logits)
+             else:
+                 self._step_adamw(group)
+
+         return loss
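
For context, below is a minimal, hypothetical sketch of building the qk_logits argument in the format that the step() docstring above describes. The attention shapes and the single layer index are assumptions for illustration only; nothing in this sketch is part of the commit.

# Hypothetical sketch (not part of this commit): building the qk_logits dict
# in the format described by Muon.step()'s docstring.
import math
import torch

num_heads, seq_len, head_dim = 8, 128, 64      # assumed shapes
q = torch.randn(num_heads, seq_len, head_dim)  # per-head queries
k = torch.randn(num_heads, seq_len, head_dim)  # per-head keys

# (1 / sqrt(head_dim)) * (Q @ K^T), then the maximum over all token pairs,
# giving one value per head, i.e. a tensor of shape (num_heads,).
scores = (q @ k.transpose(-1, -2)) / math.sqrt(head_dim)
max_qk_per_head = scores.amax(dim=(-1, -2))

qk_logits = {0: max_qk_per_head}  # layer index -> (num_heads,) tensor

# After backward(), pass it to an already-constructed Muon optimizer `opt`:
# opt.step(qk_logits=qk_logits)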
build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py CHANGED
@@ -1,5 +1,26 @@
- from .muon import Muon
-
- __all__ = [
-     "Muon",
- ]
+ import ctypes
+ import sys
+
+ import importlib.util
+ from pathlib import Path
+ from types import ModuleType
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So, we make a module name that
+     # depends on the path (the hex-encoded hash of the path) so that it is
+     # unique.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
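
The rewritten optimizer/__init__.py above no longer defines Muon itself; it loads the parent directory's __init__.py under a path-derived module name and re-exports that module's globals, so existing imports of the optimizer subpackage keep resolving. A rough standalone sketch of the same load-by-path technique follows; the helper name and the target file path are invented for illustration.

# Standalone sketch of importing a Python file under a unique, path-derived
# module name, mirroring the _import_from_path helper above.
import ctypes
import importlib.util
import sys
from pathlib import Path
from types import ModuleType

def load_by_path(file_path: Path) -> ModuleType:
    # Hash the absolute path so distinct files never collide in sys.modules.
    name = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    spec = importlib.util.spec_from_file_location(name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module

# Example: re-export everything another file defines, as the shim above does.
# mod = load_by_path(Path("/path/to/some_module.py"))  # hypothetical path
# globals().update(vars(mod))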
build/torch28-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .muon import Muon
+
+ __all__ = [
+     "Muon",
+ ]
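
Together with the subpackage shim shown earlier, this newly added top-level __init__.py means Muon is exposed both at the variant root and at the legacy optimizer subpath. A hypothetical smoke test, where variant_pkg stands in for however the built variant directory ends up being imported (not shown in this diff):

# Hypothetical check; `variant_pkg` is a placeholder for the imported variant.
import variant_pkg
import variant_pkg.optimizer as legacy

assert hasattr(variant_pkg, "Muon")  # new top-level export
assert hasattr(legacy, "Muon")       # forwarded by the _import_from_path shim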