|
|
|
|
|
import os
import pathlib
import shutil
import textwrap

from torch.utils.cpp_extension import load

# `kernels` is an optional dependency; build_variant is None when it is missing.
try:
    from kernels.utils import build_variant
except ImportError:
    build_variant = None
|
|
|
|
|
# Root of the checkout: the directory that contains this build script.
repo = pathlib.Path(__file__).resolve().parent

# Unless the caller already picked a location, point PyTorch's JIT extension
# root (TORCH_EXTENSIONS_DIR) at <repo>/.torch_extensions.
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions"))

# Files compiled into the extension: the torch-ext binding source first,
# followed by the kernel sources under csrc/.
_kernel_dir = repo / "csrc"
sources = [repo / "torch-ext" / "torch_binding.cpp"]
sources.extend(
    _kernel_dir / filename
    for filename in (
        "new_cumsum.cu",
        "new_histogram.cu",
        "new_indices.cu",
        "new_replicate.cu",
        "new_sort.cu",
    )
)
sources.append(_kernel_dir / "grouped_gemm" / "grouped_gemm.cu")
|
|
|
|
|
# Extension name handed to the JIT loader; reused below to locate the built
# library when load() does not report its path.
ext_name = "_megablocks_rocm"

# JIT-compile the sources and load the resulting shared library into this
# process. is_python_module=False means the library is loaded for its
# torch.ops registrations rather than imported as a Python module.
mod = load(
    name=ext_name,
    sources=[str(s) for s in sources],
    extra_include_paths=[str(repo / "csrc")],
    extra_cflags=["-O3", "-std=c++17"],
    extra_cuda_cflags=["-O3"],
    extra_ldflags=["-lhipblaslt"],
    verbose=True,
    is_python_module=False,
)

# What load() returns here differs across PyTorch versions: some releases
# return the library path as a str, while the documented contract for
# is_python_module=False is to return nothing (None). A module object is
# handled too, in case the flag is ever flipped.
if isinstance(mod, str):
    module_path = pathlib.Path(mod)
elif mod is None:
    # NOTE: assumes PyTorch's default build layout
    # <TORCH_EXTENSIONS_DIR>/<name>/<name>.so (Linux/ROCm shared-library suffix).
    module_path = (
        pathlib.Path(os.environ["TORCH_EXTENSIONS_DIR"]) / ext_name / f"{ext_name}.so"
    )
else:
    module_path = pathlib.Path(mod.__file__)

# Fail here with a clear message rather than later during staging/copying.
if not module_path.is_file():
    raise FileNotFoundError(f"built extension library not found: {module_path}")

print("built:", module_path)
|
|
|
|
|
# Stage an importable copy of the package under build/<variant>/megablocks.
# This needs the optional `kernels` helper (build_variant); without it only
# the build above runs.
if build_variant is None:
    print("kernels not available; skipping package staging")
else:
    variant = build_variant()
    package_root = repo / "build" / variant / "megablocks"

    # shutil.copytree refuses an existing destination; removing it first also
    # drops stale files left over from a previous run.
    if package_root.exists():
        shutil.rmtree(package_root)
    shutil.copytree(
        repo / "torch-ext" / "megablocks",
        package_root,
        ignore=shutil.ignore_patterns("__pycache__"),
    )

    # Overwrite _ops.py with a loader for the locally built library. _LIB_NAME
    # is taken from the library file copied next to it below, so the loader and
    # the staged file name cannot drift apart. dedent lets the template stay
    # indented here while the generated module has top-level code at column 0.
    ops_py = package_root / "_ops.py"
    ops_source = textwrap.dedent(
        f'''\
        import torch
        from pathlib import Path

        _LIB_NAME = {module_path.name!r}


        def _load_ops():
            lib_path = Path(__file__).with_name(_LIB_NAME)
            torch.ops.load_library(str(lib_path))
            return torch.ops._megablocks_rocm


        ops = _load_ops()


        def add_op_namespace_prefix(op_name: str) -> str:
            return f"_megablocks_rocm::{{op_name}}"
        '''
    )
    ops_py.write_text(ops_source, encoding="utf-8")

    # Ship the built shared library alongside the generated loader.
    shutil.copy2(module_path, package_root / module_path.name)
    print(f"staged local kernel under build/{variant}/megablocks")
|
|
|