# megablocks-hip / build.py
# Commit 1e407f0 (leonardlin): Add ROCm build artifacts and HIP backend
import os
import pathlib
import shutil
from torch.utils.cpp_extension import load
# `kernels` is optional: build_variant is only called further down to pick the
# staging directory name, and staging is skipped when it is None.
try:
    from kernels.utils import build_variant
except ImportError:  # fallback when kernels is unavailable
    build_variant = None
# Directory that contains this script; every other path is derived from it.
repo = pathlib.Path(__file__).resolve().parent

# Keep torch's JIT extension cache inside the checkout unless the caller has
# already pointed TORCH_EXTENSIONS_DIR somewhere else.
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions"))

# Translation units handed to the extension build, relative to the repo root.
_SOURCE_FILES = (
    "torch-ext/torch_binding.cpp",
    "csrc/new_cumsum.cu",
    "csrc/new_histogram.cu",
    "csrc/new_indices.cu",
    "csrc/new_replicate.cu",
    "csrc/new_sort.cu",
    "csrc/grouped_gemm/grouped_gemm.cu",
)
sources = [repo / relative for relative in _SOURCE_FILES]
# JIT-compile the C++/HIP sources into a single shared library. Because
# is_python_module=False, the result is a plain operator library rather than
# an importable Python extension module.
_EXT_NAME = "_megablocks_rocm"
mod = load(
    name=_EXT_NAME,
    sources=[str(s) for s in sources],
    extra_include_paths=[str(repo / "csrc")],
    extra_cflags=["-O3", "-std=c++17"],
    extra_cuda_cflags=["-O3"],  # torch switches this to hipcc flags on ROCm builds
    extra_ldflags=["-lhipblaslt"],
    verbose=True,
    is_python_module=False,
)

# Work out where the built library lives. load() may hand back the path as a
# string or an object exposing __file__; the torch docs also allow it to
# return nothing for is_python_module=False, in which case fall back to the
# conventional <TORCH_EXTENSIONS_DIR>/<name>/<name>.so location.
if isinstance(mod, str):
    module_path = pathlib.Path(mod)
elif getattr(mod, "__file__", None):
    module_path = pathlib.Path(mod.__file__)
else:
    # TORCH_EXTENSIONS_DIR is always set: this script defaults it near the top.
    module_path = (
        pathlib.Path(os.environ["TORCH_EXTENSIONS_DIR"])
        / _EXT_NAME
        / f"{_EXT_NAME}.so"
    )
    if not module_path.is_file():
        raise FileNotFoundError(
            f"load() did not report the library path and {module_path} does not exist"
        )
print("built:", module_path)
# Stage an importable `megablocks` package next to the freshly built library.
# This needs `kernels` to name the build variant; without it only the build runs.
if build_variant is None:
    print("kernels not available; skipping package staging")
else:
    variant = build_variant()
    package_root = repo / "build" / variant / "megablocks"

    # Start from an empty directory so files from an earlier staging run
    # cannot survive into this one.
    if package_root.exists():
        shutil.rmtree(package_root)
    shutil.copytree(
        repo / "torch-ext" / "megablocks",
        package_root,
        ignore=shutil.ignore_patterns("__pycache__"),
    )

    # Generate _ops.py (replacing any copy that came from torch-ext/megablocks)
    # so the staged package loads the shared library sitting beside it.
    # @LIB_NAME@ is filled with the real file name so the loader always
    # matches the file copied below.
    ops_py = package_root / "_ops.py"
    ops_template = '''
import torch
from pathlib import Path
_LIB_NAME = "@LIB_NAME@"
def _load_ops():
    lib_path = Path(__file__).with_name(_LIB_NAME)
    torch.ops.load_library(str(lib_path))
    return torch.ops._megablocks_rocm
ops = _load_ops()
def add_op_namespace_prefix(op_name: str) -> str:
    return f"_megablocks_rocm::{op_name}"
'''
    ops_py.write_text(
        ops_template.replace("@LIB_NAME@", module_path.name),
        encoding="utf-8",
    )

    shutil.copy2(module_path, package_root / module_path.name)
    print(f"staged local kernel under build/{variant}/megablocks")