| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |
| ops.def("rmsnorm_metal_forward(Tensor! input, Tensor! weight, Tensor! output) -> ()"); | |
| ops.impl("rmsnorm_metal_forward", torch::kMPS, rmsnorm_metal_forward); | |
| } | |
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |