File size: 2,022 Bytes
1e407f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os
import pathlib
import shutil
from torch.utils.cpp_extension import load
# `kernels` is an optional dependency here: `build_variant` is only used
# further down to pick the build/<variant>/ directory when staging the package.
try:
    from kernels.utils import build_variant
except ImportError:  # fallback when kernels is unavailable
    # Sentinel checked by the staging step, which is skipped when this is None.
    build_variant = None
# Directory that contains this script; every path below is anchored to it.
repo = pathlib.Path(__file__).resolve().parent

# Keep the extension build artifacts inside the checkout, unless the caller
# has already chosen a location through the environment.
if "TORCH_EXTENSIONS_DIR" not in os.environ:
    os.environ["TORCH_EXTENSIONS_DIR"] = str(repo / ".torch_extensions")

# Translation units compiled into the extension: the torch binding first,
# followed by the device kernels under csrc/.
sources = [
    repo.joinpath(*parts)
    for parts in (
        ("torch-ext", "torch_binding.cpp"),
        ("csrc", "new_cumsum.cu"),
        ("csrc", "new_histogram.cu"),
        ("csrc", "new_indices.cu"),
        ("csrc", "new_replicate.cu"),
        ("csrc", "new_sort.cu"),
        ("csrc", "grouped_gemm", "grouped_gemm.cu"),
    )
]
# JIT-compile the extension. With ``is_python_module=False`` torch loads the
# result as a plain dynamic library instead of importing it as a Python module.
mod = load(
    name="_megablocks_rocm",
    sources=[str(s) for s in sources],
    extra_include_paths=[str(repo / "csrc")],
    extra_cflags=["-O3", "-std=c++17"],
    extra_cuda_cflags=["-O3"],  # torch switches this to hipcc flags on ROCm builds
    extra_ldflags=["-lhipblaslt"],
    verbose=True,
    is_python_module=False,
)

# Work out where the freshly built shared library lives. Depending on the
# PyTorch release, ``load(..., is_python_module=False)`` hands back the library
# path as a string or returns nothing at all; a module object is still
# accepted in case the call is ever switched to ``is_python_module=True``.
if isinstance(mod, str):
    module_path = pathlib.Path(mod)
elif mod is not None:
    module_path = pathlib.Path(mod.__file__)
else:
    # Nothing was returned: fall back to torch's build layout, which is
    # $TORCH_EXTENSIONS_DIR/<name>/<name>.so when that variable is set
    # (this script sets a default for it before building).
    module_path = (
        pathlib.Path(os.environ["TORCH_EXTENSIONS_DIR"])
        / "_megablocks_rocm"
        / "_megablocks_rocm.so"
    )
    if not module_path.is_file():
        raise FileNotFoundError(
            f"extension was built but its library was not found at {module_path}"
        )
print("built:", module_path)
# Stage an importable ``megablocks`` package under build/<variant>/ that
# bundles the Python sources, a generated ``_ops.py`` loader and the compiled
# library. Skipped entirely when ``kernels`` (and so ``build_variant``) is
# unavailable.
if build_variant is None:
    print("kernels not available; skipping package staging")
else:
    variant = build_variant()
    package_root = repo / "build" / variant / "megablocks"

    # Start from an empty directory so files left by an earlier staging run
    # cannot leak into the new package.
    if package_root.exists():
        shutil.rmtree(package_root)
    shutil.copytree(
        repo / "torch-ext" / "megablocks",
        package_root,
        ignore=shutil.ignore_patterns("__pycache__"),
    )

    # Generated loader: it looks for the shared library next to itself and
    # exposes the registered operator namespace as ``ops``.
    ops_py = package_root / "_ops.py"
    ops_py.write_text(
        '''
import torch
from pathlib import Path
_LIB_NAME = "_megablocks_rocm.so"
def _load_ops():
    lib_path = Path(__file__).with_name(_LIB_NAME)
    torch.ops.load_library(str(lib_path))
    return torch.ops._megablocks_rocm
ops = _load_ops()
def add_op_namespace_prefix(op_name: str) -> str:
    return f"_megablocks_rocm::{op_name}"
''',
        encoding="utf-8",
    )

    # Copy the binary under the exact file name the generated loader opens
    # (``_LIB_NAME`` above), whatever name the build step gave the artifact.
    staged_lib_name = "_megablocks_rocm.so"
    shutil.copy2(module_path, package_root / staged_lib_name)
    print(f"staged local kernel under build/{variant}/megablocks")
|