diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..5472131e8a1dee26ee1859bafcdf26e0c7a83665 100755 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.glb filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..dd471ef2535ee6a5e331965dfdd3039797529151 --- /dev/null +++ b/.gitignore @@ -0,0 +1,215 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml + +results/ +weights/ +.gradio/ +demo/segment_result.glb \ No newline at end of file diff --git a/P3-SAM/demo/assets/1.glb b/P3-SAM/demo/assets/1.glb new file mode 100644 index 0000000000000000000000000000000000000000..c38b6c9f790418e6c0aacbb00a01b1be732adffd --- /dev/null +++ b/P3-SAM/demo/assets/1.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b57626cf269cb1e5b949586b6b6b87efa552c66d0a45371fed0c9f47db4c3314 +size 29140044 diff --git a/P3-SAM/demo/assets/2.glb b/P3-SAM/demo/assets/2.glb new file mode 100644 index 0000000000000000000000000000000000000000..9e9e3a14a7ddb839118e4054d022cad55b3b31dd --- /dev/null +++ b/P3-SAM/demo/assets/2.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cedbb7365e8506d42809cc547fc0d69d9118af1d419c8b8748bae01b545634c +size 8529600 diff --git a/P3-SAM/demo/assets/3.glb b/P3-SAM/demo/assets/3.glb new file mode 100644 index 0000000000000000000000000000000000000000..e68996a70edd6b0e83ec13ea2bfb64fef3d904ac --- /dev/null +++ b/P3-SAM/demo/assets/3.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b179b1141dc0dd847e0be289c3f9b0dfd9446995b38b5050d41df7cbabbc516d +size 30475016 diff --git a/P3-SAM/demo/assets/4.glb b/P3-SAM/demo/assets/4.glb new file mode 100644 index 0000000000000000000000000000000000000000..0b014cfe11dfbc076749668d0664a3de3933a296 --- /dev/null +++ b/P3-SAM/demo/assets/4.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be9fea2e4b63233ab30f929ca9f735ccbd154b1f9652e9bbaacd87839829f02b +size 29621968 
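The four .glb files above are Git LFS pointers for the demo meshes consumed by P3-SAM/demo/auto_mask.py (added below). As a minimal usage sketch only — mirroring that script's __main__ block and its documented command line, and assuming the checkpoint sits at ../weights/last.ckpt relative to the demo directory and a CUDA device is available — one of these assets could be segmented like this:

    # Illustrative sketch, not part of the patch; run from P3-SAM/demo/.
    import trimesh
    from auto_mask import AutoMask, set_seed

    auto_mask = AutoMask(ckpt_path="../weights/last.ckpt")  # builds P3SAM and moves it to GPU
    mesh = trimesh.load("assets/1.glb", force="mesh")        # one of the LFS-tracked demo meshes
    set_seed(42)
    aabb, face_ids, mesh = auto_mask.predict_aabb(
        mesh, save_path="results/1", save_mid_res=False, show_info=True
    )
    # aabb: per-part axis-aligned bounding boxes; face_ids: one part id per mesh face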
diff --git a/P3-SAM/demo/auto_mask.py b/P3-SAM/demo/auto_mask.py new file mode 100755 index 0000000000000000000000000000000000000000..57f18494e11a136391a4437f5ecaf357df00ff92 --- /dev/null +++ b/P3-SAM/demo/auto_mask.py @@ -0,0 +1,1405 @@ +import os +import sys +import torch +import torch.nn as nn +import numpy as np +import argparse +import trimesh +from sklearn.decomposition import PCA +import fpsample +from tqdm import tqdm +import threading +import random + +# from tqdm.notebook import tqdm +import time +import copy +import shutil +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed +from collections import defaultdict + +import numba +from numba import njit + +sys.path.append('..') +from model import build_P3SAM, load_state_dict + +class P3SAM(nn.Module): + def __init__(self): + super().__init__() + build_P3SAM(self) + + def load_state_dict(self, + ckpt_path=None, + state_dict=None, + strict=True, + assign=False, + ignore_seg_mlp=False, + ignore_seg_s2_mlp=False, + ignore_iou_mlp=False): + load_state_dict(self, + ckpt_path=ckpt_path, + state_dict=state_dict, + strict=strict, + assign=assign, + ignore_seg_mlp=ignore_seg_mlp, + ignore_seg_s2_mlp=ignore_seg_s2_mlp, + ignore_iou_mlp=ignore_iou_mlp) + + def forward(self, feats, points, point_prompt, iter=1): + ''' + feats: [K, N, 512] + points: [K, N, 3] + point_prompt: [K, N, 3] + ''' + # print(feats.shape, points.shape, point_prompt.shape) + point_num = points.shape[1] + feats = feats.transpose(0, 1) # [N, K, 512] + points = points.transpose(0, 1) # [N, K, 3] + point_prompt = point_prompt.transpose(0, 1) # [N, K, 3] + feats_seg = torch.cat([feats, points, point_prompt], dim=-1) # [N, K, 512+3+3] + + # 预测mask stage-1 + pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1) # [N, K] + pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1) # [N, K] + pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1) # [N, K] + pred_mask = torch.stack( + [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1 + ) # [N, K, 3] + + for _ in range(iter): + # 预测mask stage-2 + feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1) # [N, K, 512+3+3+3] + feats_seg_global = self.seg_s2_mlp_g(feats_seg_2) # [N, K, 512] + feats_seg_global = torch.max(feats_seg_global, dim=0).values # [K, 512] + feats_seg_global = feats_seg_global.unsqueeze(0).repeat( + point_num, 1, 1 + ) # [N, K, 512] + feats_seg_3 = torch.cat( + [feats_seg_global, feats_seg_2], dim=-1 + ) # [N, K, 512+3+3+3+512] + pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2 = torch.stack( + [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1 + ) # [N,, K 3] + pred_mask = pred_mask_s2 + + mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32) # [N, K] + mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32) # [N, K] + mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32) # [N, K] + + feats_iou = torch.cat( + [feats_seg_global, feats_seg, pred_mask_s2], dim=-1 + ) # [N, K, 512+3+3+3+512] + feats_iou = self.iou_mlp(feats_iou) # [N, K, 512] + feats_iou = torch.max(feats_iou, dim=0).values # [K, 512] + pred_iou = self.iou_mlp_out(feats_iou) # [K, 3] + pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32) # [K, 3] + + mask_1 = mask_1.transpose(0, 1) # [K, N] + mask_2 = mask_2.transpose(0, 1) # [K, N] + mask_3 = mask_3.transpose(0, 1) # [K, N] + + return mask_1, 
mask_2, mask_3, pred_iou + + +def normalize_pc(pc): + """ + pc: (N, 3) + """ + max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0) + center = (max_ + min_) / 2 + scale = (max_ - min_) / 2 + scale = np.max(np.abs(scale)) + pc = (pc - center) / (scale + 1e-10) + return pc + + +@torch.no_grad() +def get_feat(model, points, normals): + data_dict = { + "coord": points, + "normal": normals, + "color": np.ones_like(points), + "batch": np.zeros(points.shape[0], dtype=np.int64), + } + data_dict = model.transform(data_dict) + for k in data_dict: + if isinstance(data_dict[k], torch.Tensor): + data_dict[k] = data_dict[k].cuda() + point = model.sonata(data_dict) + while "pooling_parent" in point.keys(): + assert "pooling_inverse" in point.keys() + parent = point.pop("pooling_parent") + inverse = point.pop("pooling_inverse") + parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1) + point = parent + feat = point.feat # [M, 1232] + feat = model.mlp(feat) # [M, 512] + feat = feat[point.inverse] # [N, 512] + feats = feat + return feats + + +@torch.no_grad() +def get_mask(model, feats, points, point_prompt, iter=1): + """ + feats: [N, 512] + points: [N, 3] + point_prompt: [K, 3] + """ + point_num = points.shape[0] + prompt_num = point_prompt.shape[0] + feats = feats.unsqueeze(1) # [N, 1, 512] + feats = feats.repeat(1, prompt_num, 1).cuda() # [N, K, 512] + points = torch.from_numpy(points).float().cuda().unsqueeze(1) # [N, 1, 3] + points = points.repeat(1, prompt_num, 1) # [N, K, 3] + prompt_coord = ( + torch.from_numpy(point_prompt).float().cuda().unsqueeze(0) + ) # [1, K, 3] + prompt_coord = prompt_coord.repeat(point_num, 1, 1) # [N, K, 3] + + feats = feats.transpose(0, 1) # [K, N, 512] + points = points.transpose(0, 1) # [K, N, 3] + prompt_coord = prompt_coord.transpose(0, 1) # [K, N, 3] + + mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter) + + mask_1 = mask_1.transpose(0, 1) # [N, K] + mask_2 = mask_2.transpose(0, 1) # [N, K] + mask_3 = mask_3.transpose(0, 1) # [N, K] + + mask_1 = mask_1.detach().cpu().numpy() > 0.5 + mask_2 = mask_2.detach().cpu().numpy() > 0.5 + mask_3 = mask_3.detach().cpu().numpy() > 0.5 + + org_iou = pred_iou.detach().cpu().numpy() # [K, 3] + + return mask_1, mask_2, mask_3, org_iou + + +def cal_iou(m1, m2): + return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2)) + + +def cal_single_iou(m1, m2): + return np.sum(np.logical_and(m1, m2)) / np.sum(m1) + + +def iou_3d(box1, box2, signle=None): + """ + 计算两个三维边界框的交并比 (IoU) + + 参数: + box1 (list): 第一个边界框的坐标 [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max] + box2 (list): 第二个边界框的坐标 [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max] + + 返回: + float: 交并比 (IoU) 值 + """ + # 计算交集的坐标 + intersection_xmin = max(box1[0], box2[0]) + intersection_ymin = max(box1[1], box2[1]) + intersection_zmin = max(box1[2], box2[2]) + intersection_xmax = min(box1[3], box2[3]) + intersection_ymax = min(box1[4], box2[4]) + intersection_zmax = min(box1[5], box2[5]) + + # 判断是否有交集 + if ( + intersection_xmin >= intersection_xmax + or intersection_ymin >= intersection_ymax + or intersection_zmin >= intersection_zmax + ): + return 0.0 # 无交集 + + # 计算交集的体积 + intersection_volume = ( + (intersection_xmax - intersection_xmin) + * (intersection_ymax - intersection_ymin) + * (intersection_zmax - intersection_zmin) + ) + + # 计算两个盒子的体积 + box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2]) + box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2]) + + if signle is None: + # 计算并集的体积 + 
union_volume = box1_volume + box2_volume - intersection_volume + elif signle == "1": + union_volume = box1_volume + elif signle == "2": + union_volume = box2_volume + else: + raise ValueError("signle must be None or 1 or 2") + + # 计算 IoU + iou = intersection_volume / union_volume if union_volume > 0 else 0.0 + return iou + + +def cal_point_bbox_iou(p1, p2, signle=None): + min_p1 = np.min(p1, axis=0) + max_p1 = np.max(p1, axis=0) + min_p2 = np.min(p2, axis=0) + max_p2 = np.max(p2, axis=0) + box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]] + box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]] + return iou_3d(box1, box2, signle) + + +def cal_bbox_iou(points, m1, m2): + p1 = points[m1] + p2 = points[m2] + return cal_point_bbox_iou(p1, p2) + + +def clean_mesh(mesh): + """ + mesh: trimesh.Trimesh + """ + # 1. 合并接近的顶点 + mesh.merge_vertices() + + # 2. 删除重复的顶点 + # 3. 删除重复的面片 + mesh.process(True) + return mesh + + +def remove_outliers_iqr(data, factor=1.5): + """ + 基于 IQR 去除离群值 + :param data: 输入的列表或 NumPy 数组 + :param factor: IQR 的倍数(默认 1.5) + :return: 去除离群值后的列表 + """ + data = np.array(data, dtype=np.float32) + q1 = np.percentile(data, 25) # 第一四分位数 + q3 = np.percentile(data, 75) # 第三四分位数 + iqr = q3 - q1 # 四分位距 + lower_bound = q1 - factor * iqr + upper_bound = q3 + factor * iqr + return data[(data >= lower_bound) & (data <= upper_bound)].tolist() + +def better_aabb(points): + x = points[:, 0] + y = points[:, 1] + z = points[:, 2] + x = remove_outliers_iqr(x) + y = remove_outliers_iqr(y) + z = remove_outliers_iqr(z) + min_xyz = np.array([np.min(x), np.min(y), np.min(z)]) + max_xyz = np.array([np.max(x), np.max(y), np.max(z)]) + return [min_xyz, max_xyz] + +def fix_label(face_ids, adjacent_faces, use_aabb=False, mesh=None, show_info=False): + if use_aabb: + def _cal_aabb(face_ids, i, _points_org): + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz, max_xyz = better_aabb(_points) + _part_mask = ( + (_points_org[:, 0] >= min_xyz[0]) + & (_points_org[:, 0] <= max_xyz[0]) + & (_points_org[:, 1] >= min_xyz[1]) + & (_points_org[:, 1] <= max_xyz[1]) + & (_points_org[:, 2] >= min_xyz[2]) + & (_points_org[:, 2] <= max_xyz[2]) + ) + _part_mask = np.reshape(_part_mask, (-1, 3)) + _part_mask = np.all(_part_mask, axis=1) + return i, [min_xyz, max_xyz], _part_mask + with Timer("计算aabb"): + aabb = {} + unique_ids = np.unique(face_ids) + # print(max(unique_ids)) + aabb_face_mask = {} + _faces = mesh.faces + _vertices = mesh.vertices + _faces = np.reshape(_faces, (-1)) + _points = _vertices[_faces] + with ThreadPoolExecutor(max_workers=20) as executor: + futures = [] + for i in unique_ids: + if i < 0: + continue + futures.append(executor.submit(_cal_aabb, face_ids, i, _points)) + for future in futures: + res = future.result() + aabb[res[0]] = res[1] + aabb_face_mask[res[0]] = res[2] + + # _faces = mesh.faces + # _vertices = mesh.vertices + # _faces = np.reshape(_faces, (-1)) + # _points = _vertices[_faces] + # aabb_face_mask = cal_aabb_mask(_points, face_ids) + + with Timer("合并mesh"): + loop_cnt = 1 + changed = True + progress = tqdm(disable=not show_info) + no_mask_ids = np.where(face_ids < 0)[0].tolist() + faces_max = adjacent_faces.shape[0] + while changed and loop_cnt <= 50: + changed = False + # 获取无色面片 + new_no_mask_ids = [] + for i in no_mask_ids: + # if face_ids[i] < 0: + # 找邻居 + if not (0 <= i < faces_max): + continue + _adj_faces = adjacent_faces[i] + _adj_ids = [] + for j 
in _adj_faces: + if j == -1: + break + if face_ids[j] >= 0: + _tar_id = face_ids[j] + if use_aabb: + _mask = aabb_face_mask[_tar_id] + if _mask[i]: + _adj_ids.append(_tar_id) + else: + _adj_ids.append(_tar_id) + if len(_adj_ids) == 0: + new_no_mask_ids.append(i) + continue + _max_id = np.argmax(np.bincount(_adj_ids)) + face_ids[i] = _max_id + changed = True + no_mask_ids = new_no_mask_ids + # print(loop_cnt) + progress.update(1) + # progress.set_description(f"合并mesh循环:{loop_cnt} {np.sum(face_ids < 0)}") + loop_cnt += 1 + return face_ids + + +def save_mesh(save_path, mesh, face_ids, color_map): + face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8) + for i in tqdm(range(len(mesh.faces)), disable=True): + _max_id = face_ids[i] + if _max_id == -2: + continue + face_colors[i, :3] = color_map[_max_id] + + mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) + mesh_save.visual.face_colors = face_colors + mesh_save.export(save_path) + mesh_save.export(save_path.replace(".glb", ".ply")) + # print('保存mesh完成') + + scene_mesh = trimesh.Scene() + scene_mesh.add_geometry(mesh_save) + unique_ids = np.unique(face_ids) + aabb = [] + for i in unique_ids: + if i == -1 or i == -2: + continue + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz, max_xyz = better_aabb(_points) + center = (min_xyz + max_xyz) / 2 + size = max_xyz - min_xyz + box = trimesh.path.creation.box_outline() + box.vertices *= size + box.vertices += center + box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]]) + box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8) + box.colors = box_color + scene_mesh.add_geometry(box) + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + aabb.append([min_xyz, max_xyz]) + scene_mesh.export(save_path.replace(".glb", "_aabb.glb")) + aabb = np.array(aabb) + np.save(save_path.replace(".glb", "_aabb.npy"), aabb) + np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids) + + +def get_aabb_from_face_ids(mesh, face_ids): + unique_ids = np.unique(face_ids) + aabb = [] + for i in unique_ids: + if i == -1 or i == -2: + continue + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + aabb.append([min_xyz, max_xyz]) + return np.array(aabb) + + +def calculate_face_areas(mesh): + """ + 计算每个三角形面片的面积 + :param mesh: trimesh.Trimesh 对象 + :return: 面片面积数组 (n_faces,) + """ + return mesh.area_faces + # # 提取顶点和面片索引 + # vertices = mesh.vertices + # faces = mesh.faces + + # # 获取所有三个顶点的坐标 + # v0 = vertices[faces[:, 0]] + # v1 = vertices[faces[:, 1]] + # v2 = vertices[faces[:, 2]] + + # # 计算两个边向量 + # edge1 = v1 - v0 + # edge2 = v2 - v0 + + # # 计算叉积的模长(向量面积的两倍) + # cross_product = np.cross(edge1, edge2) + # areas = 0.5 * np.linalg.norm(cross_product, axis=1) + + # return areas + + +def get_connected_region(face_ids, adjacent_faces, return_face_part_ids=False): + vis = [False] * len(face_ids) + parts = [] + face_part_ids = np.ones_like(face_ids) * -1 + for i in range(len(face_ids)): + if vis[i]: + continue + _part = [] + _queue = [i] + while len(_queue) > 0: + _cur_face = _queue.pop(0) + if vis[_cur_face]: + continue + vis[_cur_face] = True + _part.append(_cur_face) + face_part_ids[_cur_face] = len(parts) + if not (0 <= _cur_face < adjacent_faces.shape[0]): + continue + _cur_face_id = 
face_ids[_cur_face] + _adj_faces = adjacent_faces[_cur_face] + for j in _adj_faces: + if j == -1: + break + if not vis[j] and face_ids[j] == _cur_face_id: + _queue.append(j) + parts.append(_part) + if return_face_part_ids: + return parts, face_part_ids + else: + return parts + +def aabb_distance(box1, box2): + """ + 计算两个轴对齐包围盒(AABB)之间的最近距离。 + :param box1: 元组 (min_x, min_y, min_z, max_x, max_y, max_z) + :param box2: 元组 (min_x, min_y, min_z, max_x, max_y, max_z) + :return: 最近距离(浮点数) + """ + # 解包坐标 + min1, max1 = box1 + min2, max2 = box2 + + # 计算各轴上的分离距离 + dx = max(0, max2[0] - min1[0], max1[0] - min2[0]) # x轴分离距离 + dy = max(0, max2[1] - min1[1], max1[1] - min2[1]) # y轴分离距离 + dz = max(0, max2[2] - min1[2], max1[2] - min2[2]) # z轴分离距离 + + # 如果所有轴都重叠,则距离为0 + if dx == 0 and dy == 0 and dz == 0: + return 0.0 + + # 计算欧几里得距离 + return np.sqrt(dx**2 + dy**2 + dz**2) + +def aabb_volume(aabb): + """ + 计算轴对齐包围盒(AABB)的体积。 + :param aabb: 元组 (min_x, min_y, min_z, max_x, max_y, max_z) + :return: 体积(浮点数) + """ + # 解包坐标 + min_xyz, max_xyz = aabb + + # 计算体积 + dx = max_xyz[0] - min_xyz[0] + dy = max_xyz[1] - min_xyz[1] + dz = max_xyz[2] - min_xyz[2] + return dx * dy * dz + +def find_neighbor_part(parts, adjacent_faces, parts_aabb=None, parts_ids=None): + face2part = {} + for i, part in enumerate(parts): + for face in part: + face2part[face] = i + neighbor_parts = [] + for i, part in enumerate(parts): + neighbor_part = set() + for face in part: + if not (0 <= face < adjacent_faces.shape[0]): + continue + for adj_face in adjacent_faces[face]: + if adj_face == -1: + break + if adj_face not in face2part: + continue + if face2part[adj_face] == i: + continue + if parts_ids is not None and parts_ids[face2part[adj_face]] in [-1, -2]: + continue + neighbor_part.add(face2part[adj_face]) + neighbor_part = list(neighbor_part) + if parts_aabb is not None and parts_ids is not None and (parts_ids[i] == -1 or parts_ids[i] == -2) and len(neighbor_part) == 0: + min_dis = np.inf + min_idx = -1 + for j, _part in tqdm(enumerate(parts)): + if j == i: + continue + if parts_ids[j] == -1 or parts_ids[j] == -2: + continue + aabb_1 = parts_aabb[i] + aabb_2 = parts_aabb[j] + dis = aabb_distance(aabb_1, aabb_2) + if dis < min_dis: + min_dis = dis + min_idx = j + elif dis == min_dis: + if aabb_volume(parts_aabb[j]) < aabb_volume(parts_aabb[min_idx]): + min_idx = j + neighbor_part = [min_idx] + neighbor_parts.append(neighbor_part) + return neighbor_parts + + +def do_post_process(face_areas, parts, adjacent_faces, face_ids, threshold=0.95, show_info=False): + # # 获取邻接面片 + # mesh_save = mesh.copy() + # face_adjacency = mesh.face_adjacency + # adjacent_faces = {} + # for face1, face2 in face_adjacency: + # if face1 not in adjacent_faces: + # adjacent_faces[face1] = [] + # if face2 not in adjacent_faces: + # adjacent_faces[face2] = [] + # adjacent_faces[face1].append(face2) + # adjacent_faces[face2].append(face1) + + # parts = get_connected_region(face_ids, adjacent_faces) + + + unique_ids = np.unique(face_ids) + if show_info: + print(f"连通区域数量:{len(parts)}") + print(f"ID数量:{len(unique_ids)}") + + # face_areas = calculate_face_areas(mesh) + total_area = np.sum(face_areas) + if show_info: + print(f"总面积:{total_area}") + part_areas = [] + for i, part in enumerate(parts): + part_area = np.sum(face_areas[part]) + part_areas.append(float(part_area / total_area)) + + sorted_parts = sorted(zip(part_areas, parts), key=lambda x: x[0], reverse=True) + parts = [x[1] for x in sorted_parts] + part_areas = [x[0] for x in sorted_parts] + integral_part_areas = 
np.cumsum(part_areas)
+
+    neighbor_parts = find_neighbor_part(parts, adjacent_faces)
+
+    new_face_ids = face_ids.copy()
+
+    # Merge tiny regions in the cumulative-area tail into their largest neighbor inside the threshold
+    for i, part in enumerate(parts):
+        if integral_part_areas[i] > threshold and part_areas[i] < 0.01:
+            if len(neighbor_parts[i]) > 0:
+                max_area = 0
+                max_part = -1
+                for j in neighbor_parts[i]:
+                    if integral_part_areas[j] > threshold:
+                        continue
+                    if part_areas[j] > max_area:
+                        max_area = part_areas[j]
+                        max_part = j
+                if max_part != -1:
+                    if show_info:
+                        print(f"Merging region {i} into {max_part}")
+                    parts[max_part].extend(part)
+                    parts[i] = []
+                    target_face_id = face_ids[parts[max_part][0]]
+                    for face in part:
+                        new_face_ids[face] = target_face_id
+
+    return new_face_ids
+
+
+def do_no_mask_process(parts, face_ids):
+    # # Get adjacent faces
+    # mesh_save = mesh.copy()
+    # face_adjacency = mesh.face_adjacency
+    # adjacent_faces = {}
+    # for face1, face2 in face_adjacency:
+    #     if face1 not in adjacent_faces:
+    #         adjacent_faces[face1] = []
+    #     if face2 not in adjacent_faces:
+    #         adjacent_faces[face2] = []
+    #     adjacent_faces[face1].append(face2)
+    #     adjacent_faces[face2].append(face1)
+    # parts = get_connected_region(face_ids, adjacent_faces)
+
+    unique_ids = np.unique(face_ids)
+    max_id = np.max(unique_ids)
+    if -1 in unique_ids or -2 in unique_ids:
+        # Give every connected region that is still unlabeled (-1) or unassigned (-2) a fresh id
+        new_face_ids = face_ids.copy()
+        for i, part in enumerate(parts):
+            if face_ids[part[0]] == -1 or face_ids[part[0]] == -2:
+                for face in part:
+                    new_face_ids[face] = max_id + 1
+                max_id += 1
+        return new_face_ids
+    else:
+        return face_ids
+
+
+def union_aabb(aabb1, aabb2):
+    min_xyz1 = aabb1[0]
+    max_xyz1 = aabb1[1]
+    min_xyz2 = aabb2[0]
+    max_xyz2 = aabb2[1]
+    min_xyz = np.minimum(min_xyz1, min_xyz2)
+    max_xyz = np.maximum(max_xyz1, max_xyz2)
+    return [min_xyz, max_xyz]
+
+
+def aabb_increase(aabb1, aabb2):
+    min_xyz_before = aabb1[0]
+    max_xyz_before = aabb1[1]
+    min_xyz_after, max_xyz_after = union_aabb(aabb1, aabb2)
+    min_xyz_increase = np.abs(min_xyz_after - min_xyz_before) / np.abs(min_xyz_before)
+    max_xyz_increase = np.abs(max_xyz_after - max_xyz_before) / np.abs(max_xyz_before)
+    return min_xyz_increase, max_xyz_increase
+
+def sort_multi_list(multi_list, key=lambda x: x[0], reverse=False):
+    '''
+    multi_list: [list1, list2, list3, list4, ...], len(list1)=N, len(list2)=N, len(list3)=N, ...
+ key: 排序函数,默认按第一个元素排序 + reverse: 排序顺序,默认降序 + return: + [list1, list2, list3, list4, ...]: 按同一个顺序排序后的多个list + ''' + sorted_list = sorted(zip(*multi_list), key=key, reverse=reverse) + return zip(*sorted_list) + + +class Timer: + STATE = True + def __init__(self, name): + self.name = name + + def __enter__(self): + if not Timer.STATE: + return + self.start_time = time.time() + return self # 可以返回 self 以便在 with 块内访问 + + def __exit__(self, exc_type, exc_val, exc_tb): + if not Timer.STATE: + return + self.end_time = time.time() + self.elapsed_time = self.end_time - self.start_time + print(f">>>>>>代码{self.name} 运行时间: {self.elapsed_time:.4f} 秒") + +###################### NUMBA 加速 ###################### +@njit +def build_adjacent_faces_numba(face_adjacency): + """ + 使用 Numba 加速构建邻接面片数组。 + :param face_adjacency: (N, 2) numpy 数组,包含邻接面片对。 + :return: + - adj_list: 一维数组,存储所有邻接面片。 + - offsets: 一维数组,记录每个面片的邻接起始位置。 + """ + n_faces = np.max(face_adjacency) + 1 # 总面片数 + n_edges = face_adjacency.shape[0] # 总邻接边数 + + # 第一步:统计每个面片的邻接数量(度数) + degrees = np.zeros(n_faces, dtype=np.int32) + for i in range(n_edges): + f1, f2 = face_adjacency[i] + degrees[f1] += 1 + degrees[f2] += 1 + max_degree = np.max(degrees) # 最大度数 + + adjacent_faces = np.ones((n_faces, max_degree), dtype=np.int32) * -1 # 邻接面片数组 + adjacent_faces_count = np.zeros(n_faces, dtype=np.int32) # 邻接面片计数器 + for i in range(n_edges): + f1, f2 = face_adjacency[i] + adjacent_faces[f1, adjacent_faces_count[f1]] = f2 + adjacent_faces_count[f1] += 1 + adjacent_faces[f2, adjacent_faces_count[f2]] = f1 + adjacent_faces_count[f2] += 1 + return adjacent_faces +###################### NUMBA 加速 ###################### + +def mesh_sam( + model, + mesh, + save_path, + point_num=100000, + prompt_num=400, + save_mid_res=False, + show_info=False, + post_process=False, + threshold=0.95, + clean_mesh_flag=True, + seed=42, + prompt_bs=32, +): + with Timer("加载mesh"): + model, model_parallel = model + if clean_mesh_flag: + mesh = clean_mesh(mesh) + mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) + if show_info: + print(f"点数:{mesh.vertices.shape[0]} 面片数:{mesh.faces.shape[0]}") + + point_num = 100000 + prompt_num = 400 + with Timer("获取邻接面片"): + face_adjacency = mesh.face_adjacency + with Timer("处理邻接面片"): + adjacent_faces = build_adjacent_faces_numba(face_adjacency) + + + with Timer("采样点云"): + _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed) + _points_org = _points.copy() + _points = normalize_pc(_points) + normals = mesh.face_normals[face_idx] + if show_info: + print(f"点数:{point_num} 面片数:{mesh.faces.shape[0]}") + + with Timer("获取特征"): + _feats = get_feat(model, _points, normals) + if show_info: + print("预处理特征") + + if save_mid_res: + feat_save = _feats.float().detach().cpu().numpy() + data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True) + pca = PCA(n_components=3) + data_reduced = pca.fit_transform(data_scaled) + data_reduced = (data_reduced - data_reduced.min()) / ( + data_reduced.max() - data_reduced.min() + ) + _colors_pca = (data_reduced * 255).astype(np.uint8) + pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca) + pc_save.export(os.path.join(save_path, "point_pca.glb")) + pc_save.export(os.path.join(save_path, "point_pca.ply")) + if show_info: + print("PCA获取特征颜色") + + with Timer("FPS采样提示点"): + fps_idx = fpsample.fps_sampling(_points, prompt_num) + _point_prompts = _points[fps_idx] + if save_mid_res: + trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export( + 
os.path.join(save_path, "point_prompts_pca.glb") + ) + trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export( + os.path.join(save_path, "point_prompts_pca.ply") + ) + if show_info: + print("采样完成") + + with Timer("推理"): + bs = prompt_bs + step_num = prompt_num // bs + 1 + mask_res = [] + iou_res = [] + for i in tqdm(range(step_num), disable=not show_info): + cur_propmt = _point_prompts[bs * i : bs * (i + 1)] + pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask( + model_parallel, _feats, _points, cur_propmt + ) + pred_mask = np.stack( + [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1 + ) # [N, K, 3] + max_idx = np.argmax(pred_iou, axis=-1) # [K] + for j in range(max_idx.shape[0]): + mask_res.append(pred_mask[:, j, max_idx[j]]) + iou_res.append(pred_iou[j, max_idx[j]]) + mask_res = np.stack(mask_res, axis=-1) # [N, K] + if show_info: + print("prmopt 推理完成") + + with Timer("根据IOU排序"): + iou_res = np.array(iou_res).tolist() + mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)] + mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True) + mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)] + iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)] + + with Timer("NMS"): + clusters = defaultdict(list) + with ThreadPoolExecutor(max_workers=20) as executor: + for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info): + _mask = mask_sorted[i] + futures = [] + for j in clusters.keys(): + futures.append(executor.submit(cal_iou, _mask, mask_sorted[j])) + + for j, future in zip(clusters.keys(), futures): + if future.result() > 0.9: + clusters[j].append(i) + break + else: + clusters[i].append(i) + + if show_info: + print(f"NMS完成,mask数量:{len(clusters)}") + + if save_mid_res: + part_mask_save_path = os.path.join(save_path, "part_mask") + if os.path.exists(part_mask_save_path): + shutil.rmtree(part_mask_save_path) + os.makedirs(part_mask_save_path, exist_ok=True) + for i in tqdm(clusters.keys(), desc="保存mask", disable=not show_info): + cluster_num = len(clusters[i]) + cluster_iou = iou_sorted[i] + cluster_area = np.sum(mask_sorted[i]) + if cluster_num <= 2: + continue + mask_save = mask_sorted[i] + mask_save = np.expand_dims(mask_save, axis=-1) + mask_save = np.repeat(mask_save, 3, axis=-1) + mask_save = (mask_save * 255).astype(np.uint8) + point_save = trimesh.points.PointCloud(_points, colors=mask_save) + point_save.export( + os.path.join( + part_mask_save_path, + f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb", + ) + ) + + # 过滤只有一个mask的cluster + with Timer("过滤只有一个mask的cluster"): + filtered_clusters = [] + other_clusters = [] + for i in clusters.keys(): + if len(clusters[i]) > 2: + filtered_clusters.append(i) + else: + other_clusters.append(i) + if show_info: + print( + f"过滤前:{len(clusters)} 个cluster," + f"过滤后:{len(filtered_clusters)} 个cluster" + ) + + # 再次合并 + with Timer("再次合并"): + filtered_clusters_num = len(filtered_clusters) + cluster2 = {} + is_union = [False] * filtered_clusters_num + for i in range(filtered_clusters_num): + if is_union[i]: + continue + cur_cluster = filtered_clusters[i] + cluster2[cur_cluster] = [cur_cluster] + for j in range(i + 1, filtered_clusters_num): + if is_union[j]: + continue + tar_cluster = filtered_clusters[j] + if ( + cal_bbox_iou( + _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster] + ) + > 0.5 + ): + cluster2[cur_cluster].append(tar_cluster) + is_union[j] = True + if show_info: + print(f"再次合并,合并数量:{len(cluster2.keys())}") + + with 
Timer("计算没有mask的点"): + no_mask = np.ones(point_num) + for i in cluster2: + part_mask = mask_sorted[i] + no_mask[part_mask] = 0 + if show_info: + print( + f"{np.sum(no_mask == 1)} 个点没有mask," + f" 占比:{np.sum(no_mask == 1) / point_num:.4f}" + ) + + with Timer("修补遗漏mask"): + # 查询漏掉的mask + for i in tqdm(range(len(mask_sorted)), desc="漏掉mask", disable=not show_info): + if i in cluster2: + continue + part_mask = mask_sorted[i] + _iou = cal_single_iou(part_mask, no_mask) + if _iou > 0.7: + cluster2[i] = [i] + no_mask[part_mask] = 0 + if save_mid_res: + mask_save = mask_sorted[i] + mask_save = np.expand_dims(mask_save, axis=-1) + mask_save = np.repeat(mask_save, 3, axis=-1) + mask_save = (mask_save * 255).astype(np.uint8) + point_save = trimesh.points.PointCloud(_points, colors=mask_save) + cluster_iou = iou_sorted[i] + cluster_area = int(np.sum(mask_sorted[i])) + cluster_num = 1 + point_save.export( + os.path.join( + part_mask_save_path, + f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb", + ) + ) + if show_info: + print(f"修补遗漏mask:{len(cluster2.keys())}") + + with Timer("计算点云最终mask"): + final_mask = list(cluster2.keys()) + final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask] + final_mask_area = [ + [final_mask[i], final_mask_area[i]] for i in range(len(final_mask)) + ] + final_mask_area_sorted = sorted(final_mask_area, key=lambda x: x[1], reverse=True) + final_mask_sorted = [ + final_mask_area_sorted[i][0] for i in range(len(final_mask_area)) + ] + final_mask_area_sorted = [ + final_mask_area_sorted[i][1] for i in range(len(final_mask_area)) + ] + if show_info: + print(f"最终mask数量:{len(final_mask_sorted)}") + + with Timer("点云上色"): + # 生成color map + color_map = {} + for i in final_mask_sorted: + part_color = np.random.rand(3) * 255 + color_map[i] = part_color + # print(color_map) + + result_mask = -np.ones(point_num, dtype=np.int64) + for i in final_mask_sorted: + part_mask = mask_sorted[i] + result_mask[part_mask] = i + if save_mid_res: + # 保存点云结果 + result_colors = np.zeros_like(_colors_pca) + for i in final_mask_sorted: + part_color = color_map[i] + part_mask = mask_sorted[i] + result_colors[part_mask, :3] = part_color + trimesh.points.PointCloud(_points, colors=result_colors).export( + os.path.join(save_path, "auto_mask_cluster.glb") + ) + trimesh.points.PointCloud(_points, colors=result_colors).export( + os.path.join(save_path, "auto_mask_cluster.ply") + ) + if show_info: + print("保存点云完成") + + + with Timer("投影Mesh并统计label"): + # 保存mesh结果 + face_seg_res = {} + for i in final_mask_sorted: + _part_mask = result_mask == i + _face_idx = face_idx[_part_mask] + for k in _face_idx: + if k not in face_seg_res: + face_seg_res[k] = [] + face_seg_res[k].append(i) + _part_mask = result_mask == -1 + _face_idx = face_idx[_part_mask] + for k in _face_idx: + if k not in face_seg_res: + face_seg_res[k] = [] + face_seg_res[k].append(-1) + + face_ids = -np.ones(len(mesh.faces), dtype=np.int64) * 2 + for i in tqdm(face_seg_res, leave=False, disable=True): + _seg_ids = np.array(face_seg_res[i]) + # 获取最多的seg_id + _max_id = np.argmax(np.bincount(_seg_ids + 2)) - 2 + face_ids[i] = _max_id + face_ids_org = face_ids.copy() + if show_info: + print("生成face_ids完成") + + + with Timer("第一次修复face_ids"): + face_ids += 1 + face_ids = fix_label(face_ids, adjacent_faces, mesh=mesh, show_info=show_info) + face_ids -= 1 + if show_info: + print("修复face_ids完成") + + color_map[-1] = np.array([255, 0, 0], dtype=np.uint8) + + if save_mid_res: + save_mesh( + os.path.join(save_path, 
"auto_mask_mesh.glb"), mesh, face_ids, color_map + ) + save_mesh( + os.path.join(save_path, "auto_mask_mesh_org.glb"), + mesh, + face_ids_org, + color_map, + ) + if show_info: + print("保存mesh结果完成") + + with Timer("计算连通区域"): + face_areas = calculate_face_areas(mesh) + mesh_total_area = np.sum(face_areas) + parts = get_connected_region(face_ids, adjacent_faces) + connected_parts, _face_connected_parts_ids = get_connected_region(np.ones_like(face_ids), adjacent_faces, return_face_part_ids=True) + if show_info: + print(f"共{len(parts)}个mesh") + with Timer("排序连通区域"): + parts_cp_idx = [] + for x in parts: + _face_idx = x[0] + parts_cp_idx.append(_face_connected_parts_ids[_face_idx]) + parts_cp_idx = np.array(parts_cp_idx) + parts_areas = [float(np.sum(face_areas[x])) for x in parts] + connected_parts_areas = [float(np.sum(face_areas[x])) for x in connected_parts] + parts_cp_areas = [connected_parts_areas[x] for x in parts_cp_idx] + parts_sorted, parts_areas_sorted, parts_cp_areas_sorted = sort_multi_list([parts, parts_areas, parts_cp_areas], key=lambda x: x[1], reverse=True) + + with Timer("去除面积过小的区域"): + filtered_parts = [] + other_parts = [] + for i in range(len(parts_sorted)): + parts = parts_sorted[i] + area = parts_areas_sorted[i] + cp_area = parts_cp_areas_sorted[i] + if area / (cp_area+1e-7) > 0.001: + filtered_parts.append(i) + else: + other_parts.append(i) + if show_info: + print(f"保留{len(filtered_parts)}个mesh, 其他{len(other_parts)}个mesh") + + with Timer("去除面积过小区域的label"): + face_ids_2 = face_ids.copy() + part_num = len(cluster2.keys()) + for j in other_parts: + parts = parts_sorted[j] + for i in parts: + face_ids_2[i] = -1 + + with Timer("第二次修复face_ids"): + face_ids_3 = face_ids_2.copy() + face_ids_3 = fix_label(face_ids_3, adjacent_faces, mesh=mesh, show_info=show_info) + + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh_filtered_2.glb"), + mesh, + face_ids_3, + color_map, + ) + if show_info: + print("保存mesh结果完成") + + with Timer("第二次计算连通区域"): + parts_2 = get_connected_region(face_ids_3, adjacent_faces) + parts_areas_2 = [float(np.sum(face_areas[x])) for x in parts_2] + parts_ids_2 = [face_ids_3[x[0]] for x in parts_2] + + with Timer("添加过大的缺失part"): + color_map_2 = copy.deepcopy(color_map) + max_id = np.max(parts_ids_2) + for i in range(len(parts_2)): + _parts = parts_2[i] + _area = parts_areas_2[i] + _parts_id = face_ids_3[_parts[0]] + if _area / mesh_total_area > 0.001: + if _parts_id == -1 or _parts_id == -2: + parts_ids_2[i] = max_id + 1 + max_id += 1 + color_map_2[max_id] = np.random.rand(3) * 255 + if show_info: + print(f"新增part {max_id}") + # else: + # parts_ids_2[i] = -1 + + with Timer("赋值新的face_ids"): + face_ids_4 = face_ids_3.copy() + for i in range(len(parts_2)): + _parts = parts_2[i] + _parts_id = parts_ids_2[i] + for j in _parts: + face_ids_4[j] = _parts_id + with Timer("计算part和label的aabb"): + ids_aabb = {} + unique_ids = np.unique(face_ids_4) + for i in unique_ids: + if i < 0: + continue + _part_mask = face_ids_4 == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + ids_aabb[i] = [min_xyz, max_xyz] + + parts_2_aabb = [] + for i in range(len(parts_2)): + _parts = parts_2[i] + _faces = mesh.faces[_parts] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + parts_2_aabb.append([min_xyz, max_xyz]) + + with Timer("计算part的邻居"): + 
parts_2_neighbor = find_neighbor_part(parts_2, adjacent_faces, parts_2_aabb, parts_ids_2) + with Timer("合并无mask区域"): + for i in range(len(parts_2)): + _parts = parts_2[i] + _ids = parts_ids_2[i] + if _ids == -1 or _ids == -2: + _cur_aabb = parts_2_aabb[i] + _min_aabb_increase = 1e10 + _min_id = -1 + for j in parts_2_neighbor[i]: + if parts_ids_2[j] == -1 or parts_ids_2[j] == -2: + continue + _tar_id = parts_ids_2[j] + _tar_aabb = ids_aabb[_tar_id] + _min_increase, _max_increase = aabb_increase(_tar_aabb, _cur_aabb) + _increase = max(np.max(_min_increase), np.max(_max_increase)) + if _min_aabb_increase > _increase: + _min_aabb_increase = _increase + _min_id = _tar_id + if _min_id >= 0: + parts_ids_2[i] = _min_id + + + with Timer("再次赋值新的face_ids"): + face_ids_4 = face_ids_3.copy() + for i in range(len(parts_2)): + _parts = parts_2[i] + _parts_id = parts_ids_2[i] + for j in _parts: + face_ids_4[j] = _parts_id + + final_face_ids = face_ids_4 + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh_final.glb"), + mesh, + face_ids_4, + color_map_2, + ) + + if post_process: + parts = get_connected_region(final_face_ids, adjacent_faces) + final_face_ids = do_no_mask_process(parts, final_face_ids) + face_ids_5 = do_post_process(face_areas, parts, adjacent_faces, face_ids_4, threshold, show_info=show_info) + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh_final_post.glb"), + mesh, + face_ids_5, + color_map_2, + ) + final_face_ids = face_ids_5 + with Timer("计算最后的aabb"): + aabb = get_aabb_from_face_ids(mesh, final_face_ids) + return aabb, final_face_ids, mesh + + +class AutoMask: + def __init__( + self, + ckpt_path=None, + point_num=100000, + prompt_num=400, + threshold=0.95, + post_process=True, + ): + """ + ckpt_path: str, 模型路径 + point_num: int, 采样点数量 + prompt_num: int, 提示数量 + threshold: float, 阈值 + post_process: bool, 是否后处理 + """ + self.model = P3SAM() + self.model.load_state_dict(ckpt_path) + self.model.eval() + self.model_parallel = torch.nn.DataParallel(self.model) + self.model.cuda() + self.model_parallel.cuda() + self.point_num = point_num + self.prompt_num = prompt_num + self.threshold = threshold + self.post_process = post_process + + def predict_aabb( + self, mesh, point_num=None, prompt_num=None, threshold=None, post_process=None, save_path=None, save_mid_res=False, show_info=True, clean_mesh_flag=True, seed=42, is_parallel=True, prompt_bs=32 + ): + """ + Parameters: + mesh: trimesh.Trimesh, 输入网格 + point_num: int, 采样点数量 + prompt_num: int, 提示数量 + threshold: float, 阈值 + post_process: bool, 是否后处理 + Returns: + aabb: np.ndarray, 包围盒 + face_ids: np.ndarray, 面id + """ + point_num = point_num if point_num is not None else self.point_num + prompt_num = prompt_num if prompt_num is not None else self.prompt_num + threshold = threshold if threshold is not None else self.threshold + post_process = post_process if post_process is not None else self.post_process + return mesh_sam( + [self.model, self.model_parallel if is_parallel else self.model], + mesh, + save_path=save_path, + point_num=point_num, + prompt_num=prompt_num, + threshold=threshold, + post_process=post_process, + show_info=show_info, + save_mid_res=save_mid_res, + clean_mesh_flag=clean_mesh_flag, + seed=seed, + prompt_bs=prompt_bs, + ) + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = 
False + +if __name__ == '__main__': + argparser = argparse.ArgumentParser() + argparser.add_argument('--ckpt_path', type=str, default=None, help='模型路径') + argparser.add_argument('--mesh_path', type=str, default='assets/1.glb', help='输入网格路径') + argparser.add_argument('--output_path', type=str, default='results/1', help='保存路径') + argparser.add_argument('--point_num', type=int, default=100000, help='采样点数量') + argparser.add_argument('--prompt_num', type=int, default=400, help='提示数量') + argparser.add_argument('--threshold', type=float, default=0.95, help='阈值') + argparser.add_argument('--post_process', type=int, default=0, help='是否后处理') + argparser.add_argument('--save_mid_res', type=int, default=1, help='是否保存中间结果') + argparser.add_argument('--show_info', type=int, default=1, help='是否显示信息') + argparser.add_argument('--show_time_info', type=int, default=1, help='是否显示时间信息') + argparser.add_argument('--seed', type=int, default=42, help='随机种子') + argparser.add_argument('--parallel', type=int, default=1, help='是否使用多卡') + argparser.add_argument('--prompt_bs', type=int, default=32, help='提示点推理时的batch size大小') + argparser.add_argument('--clean_mesh', type=int, default=1, help='是否清洗网格') + args = argparser.parse_args() + Timer.STATE = args.show_time_info + + + output_path = args.output_path + os.makedirs(output_path, exist_ok=True) + ckpt_path = args.ckpt_path + auto_mask = AutoMask(ckpt_path) + mesh_path = args.mesh_path + if os.path.isdir(mesh_path): + for file in os.listdir(mesh_path): + if not (file.endswith('.glb') or file.endswith('.obj') or file.endswith('.ply')): + continue + _mesh_path = os.path.join(mesh_path, file) + _output_path = os.path.join(output_path, file[:-4]) + os.makedirs(_output_path, exist_ok=True) + mesh = trimesh.load(_mesh_path, force='mesh') + set_seed(args.seed) + aabb, face_ids, mesh = auto_mask.predict_aabb(mesh, + save_path=_output_path, + point_num=args.point_num, + prompt_num=args.prompt_num, + threshold=args.threshold, + post_process=args.post_process, + save_mid_res=args.save_mid_res, + show_info=args.show_info, + seed=args.seed, + is_parallel=args.parallel, + clean_mesh_flag=args.clean_mesh,) + else: + mesh = trimesh.load(mesh_path, force='mesh') + set_seed(args.seed) + aabb, face_ids, mesh = auto_mask.predict_aabb(mesh, + save_path=output_path, + point_num=args.point_num, + prompt_num=args.prompt_num, + threshold=args.threshold, + post_process=args.post_process, + save_mid_res=args.save_mid_res, + show_info=args.show_info, + seed=args.seed, + is_parallel=args.parallel, + clean_mesh_flag=args.clean_mesh,) + + ############################################### + ## 可以通过以下代码保存返回的结果 + ## You can save the returned result by the following code + ################# save result ################# + # color_map = {} + # unique_ids = np.unique(face_ids) + # for i in unique_ids: + # if i == -1: + # continue + # part_color = np.random.rand(3) * 255 + # color_map[i] = part_color + # face_colors = [] + # for i in face_ids: + # if i == -1: + # face_colors.append([0, 0, 0]) + # else: + # face_colors.append(color_map[i]) + # face_colors = np.array(face_colors).astype(np.uint8) + # mesh_save = mesh.copy() + # mesh_save.visual.face_colors = face_colors + # mesh_save.export(os.path.join(output_path, 'auto_mask_mesh.glb')) + # scene_mesh = trimesh.Scene() + # scene_mesh.add_geometry(mesh_save) + # for i in range(len(aabb)): + # min_xyz, max_xyz = aabb[i] + # center = (min_xyz + max_xyz) / 2 + # size = max_xyz - min_xyz + # box = trimesh.path.creation.box_outline() + # box.vertices *= size + # 
box.vertices += center + # scene_mesh.add_geometry(box) + # scene_mesh.export(os.path.join(output_path, 'auto_mask_aabb.glb')) + ################# save result ################# + +''' +python auto_mask.py --parallel 0 +python auto_mask.py --ckpt_path ../weights/last.ckpt --mesh_path assets/1.glb --output_path results/1 --parallel 0 +python auto_mask.py --ckpt_path ../weights/last.ckpt --mesh_path assets --output_path results/all +''' \ No newline at end of file diff --git a/P3-SAM/demo/auto_mask_no_postprocess.py b/P3-SAM/demo/auto_mask_no_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..83517b5c4a49d8d4e67e828cdf1061b57a862441 --- /dev/null +++ b/P3-SAM/demo/auto_mask_no_postprocess.py @@ -0,0 +1,943 @@ +import os +import sys +import torch +import torch.nn as nn +import numpy as np +import argparse +import trimesh +from sklearn.decomposition import PCA +import fpsample +from tqdm import tqdm +import threading +import random + +# from tqdm.notebook import tqdm +import time +import copy +import shutil +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed +from collections import defaultdict + +import numba +from numba import njit + +sys.path.append("..") +from model import build_P3SAM, load_state_dict + +from utils.chamfer3D.dist_chamfer_3D import chamfer_3DDist + +cmd_loss = chamfer_3DDist() + + +class P3SAM(nn.Module): + def __init__(self): + super().__init__() + build_P3SAM(self) + + def load_state_dict(self, + ckpt_path=None, + state_dict=None, + strict=True, + assign=False, + ignore_seg_mlp=False, + ignore_seg_s2_mlp=False, + ignore_iou_mlp=False): + load_state_dict(self, + ckpt_path=ckpt_path, + state_dict=state_dict, + strict=strict, + assign=assign, + ignore_seg_mlp=ignore_seg_mlp, + ignore_seg_s2_mlp=ignore_seg_s2_mlp, + ignore_iou_mlp=ignore_iou_mlp) + + def forward(self, feats, points, point_prompt, iter=1): + """ + feats: [K, N, 512] + points: [K, N, 3] + point_prompt: [K, N, 3] + """ + # print(feats.shape, points.shape, point_prompt.shape) + point_num = points.shape[1] + feats = feats.transpose(0, 1) # [N, K, 512] + points = points.transpose(0, 1) # [N, K, 3] + point_prompt = point_prompt.transpose(0, 1) # [N, K, 3] + feats_seg = torch.cat([feats, points, point_prompt], dim=-1) # [N, K, 512+3+3] + + # 预测mask stage-1 + pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1) # [N, K] + pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1) # [N, K] + pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1) # [N, K] + pred_mask = torch.stack( + [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1 + ) # [N, K, 3] + + for _ in range(iter): + # 预测mask stage-2 + feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1) # [N, K, 512+3+3+3] + feats_seg_global = self.seg_s2_mlp_g(feats_seg_2) # [N, K, 512] + feats_seg_global = torch.max(feats_seg_global, dim=0).values # [K, 512] + feats_seg_global = feats_seg_global.unsqueeze(0).repeat( + point_num, 1, 1 + ) # [N, K, 512] + feats_seg_3 = torch.cat( + [feats_seg_global, feats_seg_2], dim=-1 + ) # [N, K, 512+3+3+3+512] + pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2 = torch.stack( + [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1 + ) # [N,, K 3] + pred_mask = pred_mask_s2 + + mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32) # [N, K] + mask_2 = 
torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32) # [N, K] + mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32) # [N, K] + + feats_iou = torch.cat( + [feats_seg_global, feats_seg, pred_mask_s2], dim=-1 + ) # [N, K, 512+3+3+3+512] + feats_iou = self.iou_mlp(feats_iou) # [N, K, 512] + feats_iou = torch.max(feats_iou, dim=0).values # [K, 512] + pred_iou = self.iou_mlp_out(feats_iou) # [K, 3] + pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32) # [K, 3] + + mask_1 = mask_1.transpose(0, 1) # [K, N] + mask_2 = mask_2.transpose(0, 1) # [K, N] + mask_3 = mask_3.transpose(0, 1) # [K, N] + + return mask_1, mask_2, mask_3, pred_iou + + +def normalize_pc(pc): + """ + pc: (N, 3) + """ + max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0) + center = (max_ + min_) / 2 + scale = (max_ - min_) / 2 + scale = np.max(np.abs(scale)) + pc = (pc - center) / (scale + 1e-10) + return pc + + +@torch.no_grad() +def get_feat(model, points, normals): + data_dict = { + "coord": points, + "normal": normals, + "color": np.ones_like(points), + "batch": np.zeros(points.shape[0], dtype=np.int64), + } + data_dict = model.transform(data_dict) + for k in data_dict: + if isinstance(data_dict[k], torch.Tensor): + data_dict[k] = data_dict[k].cuda() + point = model.sonata(data_dict) + while "pooling_parent" in point.keys(): + assert "pooling_inverse" in point.keys() + parent = point.pop("pooling_parent") + inverse = point.pop("pooling_inverse") + parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1) + point = parent + feat = point.feat # [M, 1232] + feat = model.mlp(feat) # [M, 512] + feat = feat[point.inverse] # [N, 512] + feats = feat + return feats + + +@torch.no_grad() +def get_mask(model, feats, points, point_prompt, iter=1): + """ + feats: [N, 512] + points: [N, 3] + point_prompt: [K, 3] + """ + point_num = points.shape[0] + prompt_num = point_prompt.shape[0] + feats = feats.unsqueeze(1) # [N, 1, 512] + feats = feats.repeat(1, prompt_num, 1).cuda() # [N, K, 512] + points = torch.from_numpy(points).float().cuda().unsqueeze(1) # [N, 1, 3] + points = points.repeat(1, prompt_num, 1) # [N, K, 3] + prompt_coord = ( + torch.from_numpy(point_prompt).float().cuda().unsqueeze(0) + ) # [1, K, 3] + prompt_coord = prompt_coord.repeat(point_num, 1, 1) # [N, K, 3] + + feats = feats.transpose(0, 1) # [K, N, 512] + points = points.transpose(0, 1) # [K, N, 3] + prompt_coord = prompt_coord.transpose(0, 1) # [K, N, 3] + + mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter) + + mask_1 = mask_1.transpose(0, 1) # [N, K] + mask_2 = mask_2.transpose(0, 1) # [N, K] + mask_3 = mask_3.transpose(0, 1) # [N, K] + + mask_1 = mask_1.detach().cpu().numpy() > 0.5 + mask_2 = mask_2.detach().cpu().numpy() > 0.5 + mask_3 = mask_3.detach().cpu().numpy() > 0.5 + + org_iou = pred_iou.detach().cpu().numpy() # [K, 3] + + return mask_1, mask_2, mask_3, org_iou + + +def cal_iou(m1, m2): + return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2)) + + +def cal_single_iou(m1, m2): + return np.sum(np.logical_and(m1, m2)) / np.sum(m1) + + +def iou_3d(box1, box2, signle=None): + """ + 计算两个三维边界框的交并比 (IoU) + + 参数: + box1 (list): 第一个边界框的坐标 [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max] + box2 (list): 第二个边界框的坐标 [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max] + + 返回: + float: 交并比 (IoU) 值 + """ + # 计算交集的坐标 + intersection_xmin = max(box1[0], box2[0]) + intersection_ymin = max(box1[1], box2[1]) + intersection_zmin = max(box1[2], box2[2]) + intersection_xmax = min(box1[3], box2[3]) + intersection_ymax = 
min(box1[4], box2[4]) + intersection_zmax = min(box1[5], box2[5]) + + # 判断是否有交集 + if ( + intersection_xmin >= intersection_xmax + or intersection_ymin >= intersection_ymax + or intersection_zmin >= intersection_zmax + ): + return 0.0 # 无交集 + + # 计算交集的体积 + intersection_volume = ( + (intersection_xmax - intersection_xmin) + * (intersection_ymax - intersection_ymin) + * (intersection_zmax - intersection_zmin) + ) + + # 计算两个盒子的体积 + box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2]) + box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2]) + + if signle is None: + # 计算并集的体积 + union_volume = box1_volume + box2_volume - intersection_volume + elif signle == "1": + union_volume = box1_volume + elif signle == "2": + union_volume = box2_volume + else: + raise ValueError("signle must be None or 1 or 2") + + # 计算 IoU + iou = intersection_volume / union_volume if union_volume > 0 else 0.0 + return iou + + +def cal_point_bbox_iou(p1, p2, signle=None): + min_p1 = np.min(p1, axis=0) + max_p1 = np.max(p1, axis=0) + min_p2 = np.min(p2, axis=0) + max_p2 = np.max(p2, axis=0) + box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]] + box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]] + return iou_3d(box1, box2, signle) + + +def cal_bbox_iou(points, m1, m2): + p1 = points[m1] + p2 = points[m2] + return cal_point_bbox_iou(p1, p2) + + +def clean_mesh(mesh): + """ + mesh: trimesh.Trimesh + """ + # 1. 合并接近的顶点 + mesh.merge_vertices() + + # 2. 删除重复的顶点 + # 3. 删除重复的面片 + mesh.process(True) + return mesh + + +def get_aabb_from_face_ids(mesh, face_ids): + unique_ids = np.unique(face_ids) + aabb = [] + for i in unique_ids: + if i == -1 or i == -2: + continue + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + aabb.append([min_xyz, max_xyz]) + return np.array(aabb) + + +class Timer: + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start_time = time.time() + return self # 可以返回 self 以便在 with 块内访问 + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() + self.elapsed_time = self.end_time - self.start_time + print(f">>>>>>代码{self.name} 运行时间: {self.elapsed_time:.4f} 秒") + + +def sample_points_pre_face(vertices, faces, n_point_per_face=2000): + n_f = faces.shape[0] # 面片数量 + + # 生成随机数 u, v + u = np.sqrt(np.random.rand(n_f, n_point_per_face, 1)) # (n_f, n_point_per_face, 1) + v = np.random.rand(n_f, n_point_per_face, 1) # (n_f, n_point_per_face, 1) + + # 计算 barycentric 坐标 + w0 = 1 - u + w1 = u * (1 - v) + w2 = u * v # (n_f, n_point_per_face, 1) + + # 从顶点中提取每个面的三个顶点 + face_v_0 = vertices[faces[:, 0].reshape(-1)] # (n_f, 3) + face_v_1 = vertices[faces[:, 1].reshape(-1)] # (n_f, 3) + face_v_2 = vertices[faces[:, 2].reshape(-1)] # (n_f, 3) + + # 扩展维度以匹配 w0, w1, w2 的形状 + face_v_0 = face_v_0.reshape(n_f, 1, 3) # (n_f, 1, 3) + face_v_1 = face_v_1.reshape(n_f, 1, 3) # (n_f, 1, 3) + face_v_2 = face_v_2.reshape(n_f, 1, 3) # (n_f, 1, 3) + + # 计算每个点的坐标 + points = w0 * face_v_0 + w1 * face_v_1 + w2 * face_v_2 # (n_f, n_point_per_face, 3) + + return points + + +def cal_cd_batch(p1, p2, pn=100000): + p1_n = p1.shape[0] + batch_num = (p1_n + pn - 1) // pn + p2_cuda = torch.from_numpy(p2).cuda().float().unsqueeze(0) + p1_cuda = torch.from_numpy(p1).cuda().float().unsqueeze(0) + cd_res = [] + for i in tqdm(range(batch_num)): + start_idx = i * pn + end_idx 
= min((i + 1) * pn, p1_n) + _p1_cuda = p1_cuda[:, start_idx:end_idx, :] + _, _, idx, _ = cmd_loss(_p1_cuda, p2_cuda) + idx = idx[0].detach().cpu().numpy() + cd_res.append(idx) + cd_res = np.concatenate(cd_res, axis=0) + return cd_res + + +def remove_outliers_iqr(data, factor=1.5): + """ + 基于 IQR 去除离群值 + :param data: 输入的列表或 NumPy 数组 + :param factor: IQR 的倍数(默认 1.5) + :return: 去除离群值后的列表 + """ + data = np.array(data, dtype=np.float32) + q1 = np.percentile(data, 25) # 第一四分位数 + q3 = np.percentile(data, 75) # 第三四分位数 + iqr = q3 - q1 # 四分位距 + lower_bound = q1 - factor * iqr + upper_bound = q3 + factor * iqr + return data[(data >= lower_bound) & (data <= upper_bound)].tolist() + + +def better_aabb(points): + x = points[:, 0] + y = points[:, 1] + z = points[:, 2] + x = remove_outliers_iqr(x) + y = remove_outliers_iqr(y) + z = remove_outliers_iqr(z) + min_xyz = np.array([np.min(x), np.min(y), np.min(z)]) + max_xyz = np.array([np.max(x), np.max(y), np.max(z)]) + return [min_xyz, max_xyz] + + +def save_mesh(save_path, mesh, face_ids, color_map): + face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8) + for i in tqdm(range(len(mesh.faces)), disable=True): + _max_id = face_ids[i] + if _max_id == -2: + continue + face_colors[i, :3] = color_map[_max_id] + + mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) + mesh_save.visual.face_colors = face_colors + mesh_save.export(save_path) + mesh_save.export(save_path.replace(".glb", ".ply")) + # print('保存mesh完成') + + scene_mesh = trimesh.Scene() + scene_mesh.add_geometry(mesh_save) + unique_ids = np.unique(face_ids) + aabb = [] + for i in unique_ids: + if i == -1 or i == -2: + continue + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz, max_xyz = better_aabb(_points) + center = (min_xyz + max_xyz) / 2 + size = max_xyz - min_xyz + box = trimesh.path.creation.box_outline() + box.vertices *= size + box.vertices += center + box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]]) + box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8) + box.colors = box_color + scene_mesh.add_geometry(box) + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + aabb.append([min_xyz, max_xyz]) + scene_mesh.export(save_path.replace(".glb", "_aabb.glb")) + aabb = np.array(aabb) + np.save(save_path.replace(".glb", "_aabb.npy"), aabb) + np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids) + + +def mesh_sam( + model, + mesh, + save_path, + point_num=100000, + prompt_num=400, + save_mid_res=False, + show_info=False, + post_process=False, + threshold=0.95, + clean_mesh_flag=True, + seed=42, + prompt_bs=32, +): + with Timer("加载mesh"): + model, model_parallel = model + if clean_mesh_flag: + mesh = clean_mesh(mesh) + mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=False) + if show_info: + print(f"点数:{mesh.vertices.shape[0]} 面片数:{mesh.faces.shape[0]}") + + point_num = 100000 + prompt_num = 400 + + with Timer("采样点云"): + _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed) + _points_org = _points.copy() + _points = normalize_pc(_points) + normals = mesh.face_normals[face_idx] + # _points = _points + np.random.normal(0, 1, size=_points.shape) * 0.01 + # normals = normals * 0. 
# debug no normal + if show_info: + print(f"点数:{point_num} 面片数:{mesh.faces.shape[0]}") + + with Timer("获取特征"): + _feats = get_feat(model, _points, normals) + if show_info: + print("预处理特征") + + if save_mid_res: + feat_save = _feats.float().detach().cpu().numpy() + data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True) + pca = PCA(n_components=3) + data_reduced = pca.fit_transform(data_scaled) + data_reduced = (data_reduced - data_reduced.min()) / ( + data_reduced.max() - data_reduced.min() + ) + _colors_pca = (data_reduced * 255).astype(np.uint8) + pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca) + pc_save.export(os.path.join(save_path, "point_pca.glb")) + pc_save.export(os.path.join(save_path, "point_pca.ply")) + if show_info: + print("PCA获取特征颜色") + + with Timer("FPS采样提示点"): + fps_idx = fpsample.fps_sampling(_points, prompt_num) + _point_prompts = _points[fps_idx] + if save_mid_res: + trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export( + os.path.join(save_path, "point_prompts_pca.glb") + ) + trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export( + os.path.join(save_path, "point_prompts_pca.ply") + ) + if show_info: + print("采样完成") + + with Timer("推理"): + bs = prompt_bs + step_num = prompt_num // bs + 1 + mask_res = [] + iou_res = [] + for i in tqdm(range(step_num), disable=not show_info): + cur_propmt = _point_prompts[bs * i : bs * (i + 1)] + pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask( + model_parallel, _feats, _points, cur_propmt + ) + pred_mask = np.stack( + [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1 + ) # [N, K, 3] + max_idx = np.argmax(pred_iou, axis=-1) # [K] + for j in range(max_idx.shape[0]): + mask_res.append(pred_mask[:, j, max_idx[j]]) + iou_res.append(pred_iou[j, max_idx[j]]) + mask_res = np.stack(mask_res, axis=-1) # [N, K] + if show_info: + print("prmopt 推理完成") + + with Timer("根据IOU排序"): + iou_res = np.array(iou_res).tolist() + mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)] + mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True) + mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)] + iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)] + + # clusters = {} + # for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info): + # _mask = mask_sorted[i] + # union_flag = False + # for j in clusters.keys(): + # if cal_iou(_mask, mask_sorted[j]) > 0.9: + # clusters[j].append(i) + # union_flag = True + # break + # if not union_flag: + # clusters[i] = [i] + with Timer("NMS"): + clusters = defaultdict(list) + with ThreadPoolExecutor(max_workers=20) as executor: + for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info): + _mask = mask_sorted[i] + futures = [] + for j in clusters.keys(): + futures.append(executor.submit(cal_iou, _mask, mask_sorted[j])) + + for j, future in zip(clusters.keys(), futures): + if future.result() > 0.9: + clusters[j].append(i) + break + else: + clusters[i].append(i) + + # print(clusters) + if show_info: + print(f"NMS完成,mask数量:{len(clusters)}") + + if save_mid_res: + part_mask_save_path = os.path.join(save_path, "part_mask") + if os.path.exists(part_mask_save_path): + shutil.rmtree(part_mask_save_path) + os.makedirs(part_mask_save_path, exist_ok=True) + for i in tqdm(clusters.keys(), desc="保存mask", disable=not show_info): + cluster_num = len(clusters[i]) + cluster_iou = iou_sorted[i] + cluster_area = np.sum(mask_sorted[i]) + if cluster_num <= 2: + continue + mask_save = 
mask_sorted[i] + mask_save = np.expand_dims(mask_save, axis=-1) + mask_save = np.repeat(mask_save, 3, axis=-1) + mask_save = (mask_save * 255).astype(np.uint8) + point_save = trimesh.points.PointCloud(_points, colors=mask_save) + point_save.export( + os.path.join( + part_mask_save_path, + f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb", + ) + ) + + # 过滤只有一个mask的cluster + with Timer("过滤只有一个mask的cluster"): + filtered_clusters = [] + other_clusters = [] + for i in clusters.keys(): + if len(clusters[i]) > 2: + filtered_clusters.append(i) + else: + other_clusters.append(i) + if show_info: + print( + f"过滤前:{len(clusters)} 个cluster," + f"过滤后:{len(filtered_clusters)} 个cluster" + ) + + # 再次合并 + with Timer("再次合并"): + filtered_clusters_num = len(filtered_clusters) + cluster2 = {} + is_union = [False] * filtered_clusters_num + for i in range(filtered_clusters_num): + if is_union[i]: + continue + cur_cluster = filtered_clusters[i] + cluster2[cur_cluster] = [cur_cluster] + for j in range(i + 1, filtered_clusters_num): + if is_union[j]: + continue + tar_cluster = filtered_clusters[j] + # if cal_single_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.9: + # if cal_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.5: + if ( + cal_bbox_iou( + _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster] + ) + > 0.5 + ): + cluster2[cur_cluster].append(tar_cluster) + is_union[j] = True + if show_info: + print(f"再次合并,合并数量:{len(cluster2.keys())}") + + with Timer("计算没有mask的点"): + no_mask = np.ones(point_num) + for i in cluster2: + part_mask = mask_sorted[i] + no_mask[part_mask] = 0 + if show_info: + print( + f"{np.sum(no_mask == 1)} 个点没有mask," + f" 占比:{np.sum(no_mask == 1) / point_num:.4f}" + ) + + with Timer("修补遗漏mask"): + # 查询漏掉的mask + for i in tqdm(range(len(mask_sorted)), desc="漏掉mask", disable=not show_info): + if i in cluster2: + continue + part_mask = mask_sorted[i] + _iou = cal_single_iou(part_mask, no_mask) + if _iou > 0.7: + cluster2[i] = [i] + no_mask[part_mask] = 0 + if save_mid_res: + mask_save = mask_sorted[i] + mask_save = np.expand_dims(mask_save, axis=-1) + mask_save = np.repeat(mask_save, 3, axis=-1) + mask_save = (mask_save * 255).astype(np.uint8) + point_save = trimesh.points.PointCloud(_points, colors=mask_save) + cluster_iou = iou_sorted[i] + cluster_area = int(np.sum(mask_sorted[i])) + cluster_num = 1 + point_save.export( + os.path.join( + part_mask_save_path, + f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb", + ) + ) + # print(cluster2) + # print(len(cluster2.keys())) + if show_info: + print(f"修补遗漏mask:{len(cluster2.keys())}") + + with Timer("计算点云最终mask"): + final_mask = list(cluster2.keys()) + final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask] + final_mask_area = [ + [final_mask[i], final_mask_area[i]] for i in range(len(final_mask)) + ] + final_mask_area_sorted = sorted( + final_mask_area, key=lambda x: x[1], reverse=True + ) + final_mask_sorted = [ + final_mask_area_sorted[i][0] for i in range(len(final_mask_area)) + ] + final_mask_area_sorted = [ + final_mask_area_sorted[i][1] for i in range(len(final_mask_area)) + ] + # print(final_mask_sorted) + # print(final_mask_area_sorted) + if show_info: + print(f"最终mask数量:{len(final_mask_sorted)}") + + with Timer("点云上色"): + # 生成color map + color_map = {} + for i in final_mask_sorted: + part_color = np.random.rand(3) * 255 + color_map[i] = part_color + # print(color_map) + + result_mask = -np.ones(point_num, dtype=np.int64) + for i in 
final_mask_sorted: + part_mask = mask_sorted[i] + result_mask[part_mask] = i + if save_mid_res: + # 保存点云结果 + result_colors = np.zeros_like(_colors_pca) + for i in final_mask_sorted: + part_color = color_map[i] + part_mask = mask_sorted[i] + result_colors[part_mask, :3] = part_color + trimesh.points.PointCloud(_points, colors=result_colors).export( + os.path.join(save_path, "auto_mask_cluster.glb") + ) + trimesh.points.PointCloud(_points, colors=result_colors).export( + os.path.join(save_path, "auto_mask_cluster.ply") + ) + if show_info: + print("保存点云完成") + + with Timer("后处理"): + valid_mask = result_mask >= 0 + _org = _points_org[valid_mask] + _results = result_mask[valid_mask] + pre_face = 10 + _face_points = sample_points_pre_face( + mesh.vertices, mesh.faces, n_point_per_face=pre_face + ) + _face_points = np.reshape(_face_points, (len(mesh.faces) * pre_face, 3)) + _idx = cal_cd_batch(_face_points, _org) + _idx_res = _results[_idx] + _idx_res = np.reshape(_idx_res, (-1, pre_face)) + + face_ids = [] + for i in range(len(mesh.faces)): + _label = np.argmax(np.bincount(_idx_res[i] + 2)) - 2 + face_ids.append(_label) + final_face_ids = np.array(face_ids) + + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh_final.glb"), + mesh, + final_face_ids, + color_map, + ) + + with Timer("计算最后的aabb"): + aabb = get_aabb_from_face_ids(mesh, final_face_ids) + return aabb, final_face_ids, mesh + + +class AutoMask: + def __init__( + self, + ckpt_path=None, + point_num=100000, + prompt_num=400, + threshold=0.95, + post_process=True, + automask_instance=None, + ): + """ + ckpt_path: str, 模型路径 + point_num: int, 采样点数量 + prompt_num: int, 提示数量 + threshold: float, 阈值 + post_process: bool, 是否后处理 + """ + if automask_instance is not None: + self.model = automask_instance.model + self.model_parallel = automask_instance.model_parallel + else: + self.model = P3SAM() + self.model.load_state_dict(ckpt_path) + self.model.eval() + self.model_parallel = torch.nn.DataParallel(self.model) + self.model.cuda() + self.model_parallel.cuda() + self.point_num = point_num + self.prompt_num = prompt_num + self.threshold = threshold + self.post_process = post_process + + def predict_aabb( + self, + mesh, + point_num=None, + prompt_num=None, + threshold=None, + post_process=None, + save_path=None, + save_mid_res=False, + show_info=True, + clean_mesh_flag=True, + seed=42, + is_parallel=True, + prompt_bs=32, + ): + """ + Parameters: + mesh: trimesh.Trimesh, 输入网格 + point_num: int, 采样点数量 + prompt_num: int, 提示数量 + threshold: float, 阈值 + post_process: bool, 是否后处理 + Returns: + aabb: np.ndarray, 包围盒 + face_ids: np.ndarray, 面id + """ + point_num = point_num if point_num is not None else self.point_num + prompt_num = prompt_num if prompt_num is not None else self.prompt_num + threshold = threshold if threshold is not None else self.threshold + post_process = post_process if post_process is not None else self.post_process + return mesh_sam( + [self.model, self.model_parallel if is_parallel else self.model], + mesh, + save_path=save_path, + point_num=point_num, + prompt_num=prompt_num, + threshold=threshold, + post_process=post_process, + show_info=show_info, + save_mid_res=save_mid_res, + clean_mesh_flag=clean_mesh_flag, + seed=seed, + prompt_bs=prompt_bs, + ) + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = 
False + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument( + "--ckpt_path", type=str, default=None, help="模型路径" + ) + argparser.add_argument( + "--mesh_path", type=str, default="assets/1.glb", help="输入网格路径" + ) + argparser.add_argument( + "--output_path", type=str, default="results/1", help="保存路径" + ) + argparser.add_argument("--point_num", type=int, default=100000, help="采样点数量") + argparser.add_argument("--prompt_num", type=int, default=400, help="提示数量") + argparser.add_argument("--threshold", type=float, default=0.95, help="阈值") + argparser.add_argument("--post_process", type=int, default=0, help="是否后处理") + argparser.add_argument( + "--save_mid_res", type=int, default=1, help="是否保存中间结果" + ) + argparser.add_argument("--show_info", type=int, default=1, help="是否显示信息") + argparser.add_argument( + "--show_time_info", type=int, default=1, help="是否显示时间信息" + ) + argparser.add_argument("--seed", type=int, default=42, help="随机种子") + argparser.add_argument("--parallel", type=int, default=1, help="是否使用多卡") + argparser.add_argument( + "--prompt_bs", type=int, default=32, help="提示点推理时的batch size大小" + ) + argparser.add_argument("--clean_mesh", type=int, default=1, help="是否清洗网格") + args = argparser.parse_args() + Timer.STATE = args.show_time_info + + output_path = args.output_path + os.makedirs(output_path, exist_ok=True) + ckpt_path = args.ckpt_path + auto_mask = AutoMask(ckpt_path) + mesh_path = args.mesh_path + if os.path.isdir(mesh_path): + for file in os.listdir(mesh_path): + if not ( + file.endswith(".glb") or file.endswith(".obj") or file.endswith(".ply") + ): + continue + _mesh_path = os.path.join(mesh_path, file) + _output_path = os.path.join(output_path, file[:-4]) + os.makedirs(_output_path, exist_ok=True) + mesh = trimesh.load(_mesh_path, force="mesh") + set_seed(args.seed) + aabb, face_ids, mesh = auto_mask.predict_aabb( + mesh, + save_path=_output_path, + point_num=args.point_num, + prompt_num=args.prompt_num, + threshold=args.threshold, + post_process=args.post_process, + save_mid_res=args.save_mid_res, + show_info=args.show_info, + seed=args.seed, + is_parallel=args.parallel, + clean_mesh_flag=args.clean_mesh, + ) + else: + mesh = trimesh.load(mesh_path, force="mesh") + set_seed(args.seed) + aabb, face_ids, mesh = auto_mask.predict_aabb( + mesh, + save_path=output_path, + point_num=args.point_num, + prompt_num=args.prompt_num, + threshold=args.threshold, + post_process=args.post_process, + save_mid_res=args.save_mid_res, + show_info=args.show_info, + seed=args.seed, + is_parallel=args.parallel, + clean_mesh_flag=args.clean_mesh, + ) + + ############################################### + ## 可以通过以下代码保存返回的结果 + ## You can save the returned result by the following code + ################# save result ################# + # color_map = {} + # unique_ids = np.unique(face_ids) + # for i in unique_ids: + # if i == -1: + # continue + # part_color = np.random.rand(3) * 255 + # color_map[i] = part_color + # face_colors = [] + # for i in face_ids: + # if i == -1: + # face_colors.append([0, 0, 0]) + # else: + # face_colors.append(color_map[i]) + # face_colors = np.array(face_colors).astype(np.uint8) + # mesh_save = mesh.copy() + # mesh_save.visual.face_colors = face_colors + # mesh_save.export(os.path.join(output_path, 'auto_mask_mesh.glb')) + # scene_mesh = trimesh.Scene() + # scene_mesh.add_geometry(mesh_save) + # for i in range(len(aabb)): + # min_xyz, max_xyz = aabb[i] + # center = (min_xyz + max_xyz) / 2 + # size = max_xyz - min_xyz + # box = 
trimesh.path.creation.box_outline() + # box.vertices *= size + # box.vertices += center + # scene_mesh.add_geometry(box) + # scene_mesh.export(os.path.join(output_path, 'auto_mask_aabb.glb')) + ################# save result ################# + +""" +python auto_mask_no_postprocess.py --parallel 0 +python auto_mask_no_postprocess.py --ckpt_path ../weights/p3sam.ckpt --mesh_path assets/1.glb --output_path results/1 --parallel 0 +python auto_mask_no_postprocess.py --ckpt_path ../weights/p3sam.ckpt --mesh_path assets --output_path results/all_no_postprocess +""" diff --git a/P3-SAM/model.py b/P3-SAM/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5df8de1addd57b728b8bd8d5e934a5e20f94d7 --- /dev/null +++ b/P3-SAM/model.py @@ -0,0 +1,156 @@ +import os +import sys +import torch +import torch.nn as nn +sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'XPart/partgen')) +from models import sonata +from utils.misc import smart_load_model + +''' +This is the P3-SAM model. +The model is composed of three parts: +1. Sonata: a 3D-CNN model for point cloud feature extraction. +2. SEG1+SEG2: a two-stage multi-head segmentor +3. IoU prediction: an IoU predictor +''' +def build_P3SAM(self): + ######################## Sonata ######################## + self.sonata = sonata.load("sonata", repo_id="facebook/sonata", download_root='/root/sonata') + self.mlp = nn.Sequential( + nn.Linear(1232, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, 512), + ) + self.transform = sonata.transform.default() + ######################## Sonata ######################## + + ######################## SEG1 ######################## + self.seg_mlp_1 = nn.Sequential( + nn.Linear(512+3+3, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, 1), + ) + self.seg_mlp_2 = nn.Sequential( + nn.Linear(512+3+3, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, 1), + ) + self.seg_mlp_3 = nn.Sequential( + nn.Linear(512+3+3, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, 1), + ) + ######################## SEG1 ######################## + + ######################## SEG2 ######################## + self.seg_s2_mlp_g = nn.Sequential( + nn.Linear(512+3+3+3, 256), + nn.GELU(), + nn.Linear(256, 256), + nn.GELU(), + nn.Linear(256, 256), + ) + self.seg_s2_mlp_1 = nn.Sequential( + nn.Linear(512+3+3+3+256, 256), + nn.GELU(), + nn.Linear(256, 256), + nn.GELU(), + nn.Linear(256, 1), + ) + self.seg_s2_mlp_2 = nn.Sequential( + nn.Linear(512+3+3+3+256, 256), + nn.GELU(), + nn.Linear(256, 256), + nn.GELU(), + nn.Linear(256, 1), + ) + self.seg_s2_mlp_3 = nn.Sequential( + nn.Linear(512+3+3+3+256, 256), + nn.GELU(), + nn.Linear(256, 256), + nn.GELU(), + nn.Linear(256, 1), + ) + ######################## SEG2 ######################## + + + self.iou_mlp = nn.Sequential( + nn.Linear(512+3+3+3+256, 256), + nn.GELU(), + nn.Linear(256, 256), + nn.GELU(), + nn.Linear(256, 256), + ) + self.iou_mlp_out = nn.Sequential( + nn.Linear(256, 256), + nn.GELU(), + nn.Linear(256, 256), + nn.GELU(), + nn.Linear(256, 3), + ) + self.iou_criterion = torch.nn.MSELoss() + +''' +Load the P3-SAM model from a checkpoint. +If ckpt_path is not None, load the checkpoint from the given path. +If state_dict is not None, load the state_dict from the given state_dict. +If both ckpt_path and state_dict are None, download the model from huggingface and load the checkpoint. 
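+Keys prefixed with 'dit.' are stripped before matching; unexpected or shape-mismatched keys are reported, and missing seg_mlp / seg_s2_mlp / iou_mlp keys can be ignored via the ignore_* flags.
+
+Illustrative usage (a sketch only: it assumes a thin wrapper class, named P3SAM in the demo code, whose __init__ calls build_P3SAM(self)):
+    model = P3SAM()
+    model.load_state_dict()                                    # no args: download p3sam.ckpt from Hugging Face
+    model.load_state_dict(ckpt_path='../weights/p3sam.ckpt')   # or load a local checkpoint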
+''' +def load_state_dict(self, + ckpt_path=None, + state_dict=None, + strict=True, + assign=False, + ignore_seg_mlp=False, + ignore_seg_s2_mlp=False, + ignore_iou_mlp=False): + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + elif state_dict is None: + # download from huggingface + print(f'trying to download model from huggingface...') + from huggingface_hub import hf_hub_download + ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam.ckpt", local_dir='/cache/P3-SAM/') + print(f'download model from huggingface to: {ckpt_path}') + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + local_state_dict = self.state_dict() + seen_keys = {k: False for k in local_state_dict.keys()} + for k, v in state_dict.items(): + if k.startswith("dit."): + k = k[4:] + if k in local_state_dict: + seen_keys[k] = True + if local_state_dict[k].shape == v.shape: + local_state_dict[k].copy_(v) + else: + print(f"mismatching shape for key {k}: loaded {local_state_dict[k].shape} but model has {v.shape}") + else: + print(f"unexpected key {k} in loaded state dict") + seg_mlp_flag = False + seg_s2_mlp_flag = False + iou_mlp_flag = False + for k in seen_keys: + if not seen_keys[k]: + if ignore_seg_mlp and 'seg_mlp' in k: + seg_mlp_flag = True + elif ignore_seg_s2_mlp and'seg_s2_mlp' in k: + seg_s2_mlp_flag = True + elif ignore_iou_mlp and 'iou_mlp' in k: + iou_mlp_flag = True + else: + print(f"missing key {k} in loaded state dict") + if ignore_seg_mlp and seg_mlp_flag: + print("seg_mlp is missing in loaded state dict, ignore seg_mlp in loaded state dict") + if ignore_seg_s2_mlp and seg_s2_mlp_flag: + print("seg_s2_mlp is missing in loaded state dict, ignore seg_s2_mlp in loaded state dict") + if ignore_iou_mlp and iou_mlp_flag: + print("iou_mlp is missing in loaded state dict, ignore iou_mlp in loaded state dict") diff --git a/P3-SAM/utils/chamfer3D/chamfer3D.cu b/P3-SAM/utils/chamfer3D/chamfer3D.cu new file mode 100644 index 0000000000000000000000000000000000000000..d5b886dff11733be30519247d1fdb784818bff4a --- /dev/null +++ b/P3-SAM/utils/chamfer3D/chamfer3D.cu @@ -0,0 +1,196 @@ + +#include +#include + +#include +#include + +#include + + + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ + + const auto batch_size = xyz1.size(0); + const auto n = xyz1.size(1); //num_points point cloud A + const auto m = xyz2.size(1); //num_points point cloud B + + NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); + NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + + +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * 
grad_xyz2){ + for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); + NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); + //THError("aborting"); + return 0; + } + return 1; + +} + diff --git a/P3-SAM/utils/chamfer3D/chamfer_cuda.cpp b/P3-SAM/utils/chamfer3D/chamfer_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28098210dcd1e428077d0bcae11a22abe8116d09 --- /dev/null +++ b/P3-SAM/utils/chamfer3D/chamfer_cuda.cpp @@ -0,0 +1,29 @@ +#include +#include + +/// TMP +// #include "common.h" +/// NOT TMP + +int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, + at::Tensor idx1, at::Tensor idx2); + +int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, + at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, + at::Tensor idx1, at::Tensor idx2); + +int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, + at::Tensor idx1, at::Tensor idx2) { + return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + +int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, + at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { + + return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); + m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); +} \ No newline at end of file diff --git a/P3-SAM/utils/chamfer3D/dist_chamfer_3D.py b/P3-SAM/utils/chamfer3D/dist_chamfer_3D.py new file mode 100644 index 0000000000000000000000000000000000000000..de26d2c29e24330af6fa0193bebe5fce5ffef527 --- /dev/null +++ b/P3-SAM/utils/chamfer3D/dist_chamfer_3D.py @@ -0,0 +1,81 @@ +from torch import nn +from torch.autograd import Function +import torch +import importlib +import os +chamfer_found = importlib.find_loader("chamfer_3D") is not None +if not chamfer_found: + ## Cool trick from https://github.com/chrdiller + print("Jitting Chamfer 3D") + cur_path = os.path.dirname(os.path.abspath(__file__)) + build_path = cur_path.replace('chamfer3D', 'tmp') + os.makedirs(build_path, exist_ok=True) + + from torch.utils.cpp_extension import load + chamfer_3D = load(name="chamfer_3D", + sources=[ + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), + "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), + ], build_directory=build_path) + print("Loaded JIT 3D CUDA chamfer distance") + +else: + import chamfer_3D + print("Loaded compiled 3D CUDA chamfer distance") + + +# Chamfer's distance module @thibaultgroueix +# GPU tensors only +class chamfer_3DFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, dim = xyz1.size() + assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()" + _, m, dim = xyz2.size() + assert dim==3, "Wrong last dimension for the chamfer distance 's input! 
Check with .size()" + device = xyz1.device + + device = xyz1.device + + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) + idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) + + dist1 = dist1.to(device) + dist2 = dist2.to(device) + idx1 = idx1.to(device) + idx2 = idx2.to(device) + torch.cuda.set_device(device) + + chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + device = graddist1.device + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + gradxyz1 = gradxyz1.to(device) + gradxyz2 = gradxyz2.to(device) + chamfer_3D.backward( + xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 + ) + return gradxyz1, gradxyz2 + + +class chamfer_3DDist(nn.Module): + def __init__(self): + super(chamfer_3DDist, self).__init__() + + def forward(self, input1, input2): + input1 = input1.contiguous() + input2 = input2.contiguous() + return chamfer_3DFunction.apply(input1, input2) diff --git a/P3-SAM/utils/chamfer3D/setup.py b/P3-SAM/utils/chamfer3D/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9a23aadadde026eb8c3db68a43d63086f6be856a --- /dev/null +++ b/P3-SAM/utils/chamfer3D/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='chamfer_3D', + ext_modules=[ + CUDAExtension('chamfer_3D', [ + "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), + "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), + ]), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/XPart/data/000.glb b/XPart/data/000.glb new file mode 100755 index 0000000000000000000000000000000000000000..9f3605789560591d64163048a7d5d09d86ef9f7c --- /dev/null +++ b/XPart/data/000.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1b728ba92d353d87c5acd2289d9f19a0dc9ab6ceacdde488e7c5d3b456b2ff8 +size 9000484 diff --git a/XPart/data/001.glb b/XPart/data/001.glb new file mode 100644 index 0000000000000000000000000000000000000000..e8122db2d31c9ba3143c15ba0d7e5fc2fbf7309c --- /dev/null +++ b/XPart/data/001.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98601f642c444d8466007b5b35e33a57d3f0bade873c9a7d7c57db039ea7a318 +size 9000676 diff --git a/XPart/data/002.glb b/XPart/data/002.glb new file mode 100644 index 0000000000000000000000000000000000000000..f8558a2f4ed79f0d9eba686582a219201bb500a2 --- /dev/null +++ b/XPart/data/002.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5dbd8408e37e41be78358bbca5bc1ef4d81e6413c92a19bc9c2f136760c0248a +size 8999812 diff --git a/XPart/data/003.glb b/XPart/data/003.glb new file mode 100644 index 0000000000000000000000000000000000000000..f050c9928be8faacef77cd493c339c56b7cb51b3 --- /dev/null +++ b/XPart/data/003.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad8390bdafd83b8aea0f7e0159b24d9a8070c58140cd1706b80c6217fe8cf14d +size 9000880 diff --git a/XPart/data/004.glb b/XPart/data/004.glb new file mode 100644 index 0000000000000000000000000000000000000000..b70468cd1579bf29efbd378a39906c3de0d8c7be --- /dev/null +++ b/XPart/data/004.glb @@ -0,0 +1,3 
@@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fb21265f371436d18437a9cb420e012fe07cf1e50fb20c1473223c921a101dc +size 9000796 diff --git a/XPart/partgen/bbox_estimator/auto_mask_api.py b/XPart/partgen/bbox_estimator/auto_mask_api.py new file mode 100755 index 0000000000000000000000000000000000000000..1bc3a70e540fd7aadf6b7e284e51ac1555060379 --- /dev/null +++ b/XPart/partgen/bbox_estimator/auto_mask_api.py @@ -0,0 +1,1417 @@ +import os +import sys +import torch +import torch.nn as nn +import numpy as np +import argparse +import trimesh +from sklearn.decomposition import PCA +import fpsample +from tqdm import tqdm +from collections import defaultdict + +# from tqdm.notebook import tqdm +import time +import copy +import shutil +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor + +from numba import njit + +################################# +# 修改sonata import路径 +from ..models import sonata + +################################# +sys.path.append("../P3-SAM") +from model import build_P3SAM, load_state_dict + + +class YSAM(nn.Module): + def __init__(self): + super().__init__() + build_P3SAM(self) + + def load_state_dict( + self, + state_dict=None, + strict=True, + assign=False, + ignore_seg_mlp=False, + ignore_seg_s2_mlp=False, + ignore_iou_mlp=False, + ): + load_state_dict( + self, + state_dict=state_dict, + strict=strict, + assign=assign, + ignore_seg_mlp=ignore_seg_mlp, + ignore_seg_s2_mlp=ignore_seg_s2_mlp, + ignore_iou_mlp=ignore_iou_mlp, + ) + + def forward(self, feats, points, point_prompt, iter=1): + """ + feats: [K, N, 512] + points: [K, N, 3] + point_prompt: [K, N, 3] + """ + # print(feats.shape, points.shape, point_prompt.shape) + point_num = points.shape[1] + feats = feats.transpose(0, 1) # [N, K, 512] + points = points.transpose(0, 1) # [N, K, 3] + point_prompt = point_prompt.transpose(0, 1) # [N, K, 3] + feats_seg = torch.cat([feats, points, point_prompt], dim=-1) # [N, K, 512+3+3] + + # 预测mask stage-1 + pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1) # [N, K] + pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1) # [N, K] + pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1) # [N, K] + pred_mask = torch.stack( + [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1 + ) # [N, K, 3] + + for _ in range(iter): + # 预测mask stage-2 + feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1) # [N, K, 512+3+3+3] + feats_seg_global = self.seg_s2_mlp_g(feats_seg_2) # [N, K, 512] + feats_seg_global = torch.max(feats_seg_global, dim=0).values # [K, 512] + feats_seg_global = feats_seg_global.unsqueeze(0).repeat( + point_num, 1, 1 + ) # [N, K, 512] + feats_seg_3 = torch.cat( + [feats_seg_global, feats_seg_2], dim=-1 + ) # [N, K, 512+3+3+3+512] + pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1) # [N, K] + pred_mask_s2 = torch.stack( + [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1 + ) # [N,, K 3] + pred_mask = pred_mask_s2 + + mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32) # [N, K] + mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32) # [N, K] + mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32) # [N, K] + + feats_iou = torch.cat( + [feats_seg_global, feats_seg, pred_mask_s2], dim=-1 + ) # [N, K, 512+3+3+3+512] + feats_iou = self.iou_mlp(feats_iou) # [N, K, 512] + feats_iou = torch.max(feats_iou, dim=0).values # [K, 512] + pred_iou = self.iou_mlp_out(feats_iou) # [K, 3] 
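+        # The IoU head max-pools the per-point features into a single vector per prompt
+        # and regresses one IoU estimate for each of the three mask branches; mesh_sam
+        # later keeps, for every prompt, the branch with the highest predicted IoU.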
+ pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32) # [K, 3] + + mask_1 = mask_1.transpose(0, 1) # [K, N] + mask_2 = mask_2.transpose(0, 1) # [K, N] + mask_3 = mask_3.transpose(0, 1) # [K, N] + + return mask_1, mask_2, mask_3, pred_iou + + +def normalize_pc(pc): + """ + pc: (N, 3) + """ + max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0) + center = (max_ + min_) / 2 + scale = (max_ - min_) / 2 + scale = np.max(np.abs(scale)) + pc = (pc - center) / (scale + 1e-10) + return pc + + +@torch.no_grad() +def get_feat(model, points, normals): + data_dict = { + "coord": points, + "normal": normals, + "color": np.ones_like(points), + "batch": np.zeros(points.shape[0], dtype=np.int64), + } + data_dict = model.transform(data_dict) + for k in data_dict: + if isinstance(data_dict[k], torch.Tensor): + data_dict[k] = data_dict[k].cuda() + point = model.sonata(data_dict) + while "pooling_parent" in point.keys(): + assert "pooling_inverse" in point.keys() + parent = point.pop("pooling_parent") + inverse = point.pop("pooling_inverse") + parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1) + point = parent + feat = point.feat # [M, 1232] + feat = model.mlp(feat) # [M, 512] + feat = feat[point.inverse] # [N, 512] + feats = feat + return feats + + +@torch.no_grad() +def get_mask(model, feats, points, point_prompt, iter=1): + """ + feats: [N, 512] + points: [N, 3] + point_prompt: [K, 3] + """ + point_num = points.shape[0] + prompt_num = point_prompt.shape[0] + feats = feats.unsqueeze(1) # [N, 1, 512] + feats = feats.repeat(1, prompt_num, 1).cuda() # [N, K, 512] + points = torch.from_numpy(points).float().cuda().unsqueeze(1) # [N, 1, 3] + points = points.repeat(1, prompt_num, 1) # [N, K, 3] + prompt_coord = ( + torch.from_numpy(point_prompt).float().cuda().unsqueeze(0) + ) # [1, K, 3] + prompt_coord = prompt_coord.repeat(point_num, 1, 1) # [N, K, 3] + + feats = feats.transpose(0, 1) # [K, N, 512] + points = points.transpose(0, 1) # [K, N, 3] + prompt_coord = prompt_coord.transpose(0, 1) # [K, N, 3] + + mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter) + + mask_1 = mask_1.transpose(0, 1) # [N, K] + mask_2 = mask_2.transpose(0, 1) # [N, K] + mask_3 = mask_3.transpose(0, 1) # [N, K] + + mask_1 = mask_1.detach().cpu().numpy() > 0.5 + mask_2 = mask_2.detach().cpu().numpy() > 0.5 + mask_3 = mask_3.detach().cpu().numpy() > 0.5 + + org_iou = pred_iou.detach().cpu().numpy() # [K, 3] + + return mask_1, mask_2, mask_3, org_iou + + +def cal_iou(m1, m2): + return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2)) + + +def cal_single_iou(m1, m2): + return np.sum(np.logical_and(m1, m2)) / np.sum(m1) + + +def iou_3d(box1, box2, signle=None): + """ + 计算两个三维边界框的交并比 (IoU) + + 参数: + box1 (list): 第一个边界框的坐标 [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max] + box2 (list): 第二个边界框的坐标 [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max] + + 返回: + float: 交并比 (IoU) 值 + """ + # 计算交集的坐标 + intersection_xmin = max(box1[0], box2[0]) + intersection_ymin = max(box1[1], box2[1]) + intersection_zmin = max(box1[2], box2[2]) + intersection_xmax = min(box1[3], box2[3]) + intersection_ymax = min(box1[4], box2[4]) + intersection_zmax = min(box1[5], box2[5]) + + # 判断是否有交集 + if ( + intersection_xmin >= intersection_xmax + or intersection_ymin >= intersection_ymax + or intersection_zmin >= intersection_zmax + ): + return 0.0 # 无交集 + + # 计算交集的体积 + intersection_volume = ( + (intersection_xmax - intersection_xmin) + * (intersection_ymax - intersection_ymin) + * (intersection_zmax - intersection_zmin) + 
) + + # 计算两个盒子的体积 + box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2]) + box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2]) + + if signle is None: + # 计算并集的体积 + union_volume = box1_volume + box2_volume - intersection_volume + elif signle == "1": + union_volume = box1_volume + elif signle == "2": + union_volume = box2_volume + else: + raise ValueError("signle must be None or 1 or 2") + + # 计算 IoU + iou = intersection_volume / union_volume if union_volume > 0 else 0.0 + return iou + + +def cal_point_bbox_iou(p1, p2, signle=None): + min_p1 = np.min(p1, axis=0) + max_p1 = np.max(p1, axis=0) + min_p2 = np.min(p2, axis=0) + max_p2 = np.max(p2, axis=0) + box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]] + box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]] + return iou_3d(box1, box2, signle) + + +def cal_bbox_iou(points, m1, m2): + p1 = points[m1] + p2 = points[m2] + return cal_point_bbox_iou(p1, p2) + + +def clean_mesh(mesh): + """ + mesh: trimesh.Trimesh + """ + # 1. 合并接近的顶点 + mesh.merge_vertices() + + # 2. 删除重复的顶点 + # 3. 删除重复的面片 + mesh.process(True) + return mesh + + +# @njit +def remove_outliers_iqr(data, factor=1.5): + """ + 基于 IQR 去除离群值 + :param data: 输入的列表或 NumPy 数组 + :param factor: IQR 的倍数(默认 1.5) + :return: 去除离群值后的列表 + """ + data = np.array(data, dtype=np.float32) + q1 = np.percentile(data, 25) # 第一四分位数 + q3 = np.percentile(data, 75) # 第三四分位数 + iqr = q3 - q1 # 四分位距 + lower_bound = q1 - factor * iqr + upper_bound = q3 + factor * iqr + return data[(data >= lower_bound) & (data <= upper_bound)].tolist() + + +# @njit +def better_aabb(points): + x = points[:, 0] + y = points[:, 1] + z = points[:, 2] + x = remove_outliers_iqr(x) + y = remove_outliers_iqr(y) + z = remove_outliers_iqr(z) + min_xyz = np.array([np.min(x), np.min(y), np.min(z)]) + max_xyz = np.array([np.max(x), np.max(y), np.max(z)]) + return [min_xyz, max_xyz] + + +def fix_label(face_ids, adjacent_faces, use_aabb=False, mesh=None, show_info=False): + if use_aabb: + + def _cal_aabb(face_ids, i, _points_org): + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz, max_xyz = better_aabb(_points) + _part_mask = ( + (_points_org[:, 0] >= min_xyz[0]) + & (_points_org[:, 0] <= max_xyz[0]) + & (_points_org[:, 1] >= min_xyz[1]) + & (_points_org[:, 1] <= max_xyz[1]) + & (_points_org[:, 2] >= min_xyz[2]) + & (_points_org[:, 2] <= max_xyz[2]) + ) + _part_mask = np.reshape(_part_mask, (-1, 3)) + _part_mask = np.all(_part_mask, axis=1) + return i, [min_xyz, max_xyz], _part_mask + + with Timer("计算aabb"): + aabb = {} + unique_ids = np.unique(face_ids) + # print(max(unique_ids)) + aabb_face_mask = {} + _faces = mesh.faces + _vertices = mesh.vertices + _faces = np.reshape(_faces, (-1)) + _points = _vertices[_faces] + with ThreadPoolExecutor(max_workers=20) as executor: + futures = [] + for i in unique_ids: + if i < 0: + continue + futures.append(executor.submit(_cal_aabb, face_ids, i, _points)) + for future in futures: + res = future.result() + aabb[res[0]] = res[1] + aabb_face_mask[res[0]] = res[2] + + # _faces = mesh.faces + # _vertices = mesh.vertices + # _faces = np.reshape(_faces, (-1)) + # _points = _vertices[_faces] + # aabb_face_mask = cal_aabb_mask(_points, face_ids) + + with Timer("合并mesh"): + loop_cnt = 1 + changed = True + progress = tqdm(disable=not show_info) + no_mask_ids = np.where(face_ids < 0)[0].tolist() + faces_max = adjacent_faces.shape[0] + 
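# Label repair: repeatedly give each unlabeled face the majority label among its
+            # already-labeled neighbors (with use_aabb, a neighbor's label is only adopted
+            # when the face lies inside that part's IQR-filtered bounding box), for at most
+            # 50 passes or until no face changes.
+ 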
while changed and loop_cnt <= 50: + changed = False + # 获取无色面片 + new_no_mask_ids = [] + for i in no_mask_ids: + # if face_ids[i] < 0: + # 找邻居 + if not (0 <= i < faces_max): + continue + _adj_faces = adjacent_faces[i] + _adj_ids = [] + for j in _adj_faces: + if j == -1: + break + if face_ids[j] >= 0: + _tar_id = face_ids[j] + if use_aabb: + _mask = aabb_face_mask[_tar_id] + if _mask[i]: + _adj_ids.append(_tar_id) + else: + _adj_ids.append(_tar_id) + if len(_adj_ids) == 0: + new_no_mask_ids.append(i) + continue + _max_id = np.argmax(np.bincount(_adj_ids)) + face_ids[i] = _max_id + changed = True + no_mask_ids = new_no_mask_ids + # print(loop_cnt) + progress.update(1) + # progress.set_description(f"合并mesh循环:{loop_cnt} {np.sum(face_ids < 0)}") + loop_cnt += 1 + return face_ids + + +def save_mesh(save_path, mesh, face_ids, color_map): + face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8) + for i in tqdm(range(len(mesh.faces)), disable=True): + _max_id = face_ids[i] + if _max_id == -2: + continue + face_colors[i, :3] = color_map[_max_id] + + mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) + mesh_save.visual.face_colors = face_colors + mesh_save.export(save_path) + mesh_save.export(save_path.replace(".glb", ".ply")) + # print('保存mesh完成') + + scene_mesh = trimesh.Scene() + scene_mesh.add_geometry(mesh_save) + unique_ids = np.unique(face_ids) + aabb = [] + for i in unique_ids: + if i == -1 or i == -2: + continue + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz, max_xyz = better_aabb(_points) + center = (min_xyz + max_xyz) / 2 + size = max_xyz - min_xyz + box = trimesh.path.creation.box_outline() + box.vertices *= size + box.vertices += center + box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]]) + box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8) + box.colors = box_color + scene_mesh.add_geometry(box) + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + aabb.append([min_xyz, max_xyz]) + scene_mesh.export(save_path.replace(".glb", "_aabb.glb")) + aabb = np.array(aabb) + np.save(save_path.replace(".glb", "_aabb.npy"), aabb) + np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids) + + +def get_aabb_from_face_ids(mesh, face_ids): + unique_ids = np.unique(face_ids) + aabb = [] + for i in unique_ids: + if i == -1 or i == -2: + continue + _part_mask = face_ids == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + aabb.append([min_xyz, max_xyz]) + return np.array(aabb) + + +def calculate_face_areas(mesh): + """ + 计算每个三角形面片的面积 + :param mesh: trimesh.Trimesh 对象 + :return: 面片面积数组 (n_faces,) + """ + return mesh.area_faces + # # 提取顶点和面片索引 + # vertices = mesh.vertices + # faces = mesh.faces + + # # 获取所有三个顶点的坐标 + # v0 = vertices[faces[:, 0]] + # v1 = vertices[faces[:, 1]] + # v2 = vertices[faces[:, 2]] + + # # 计算两个边向量 + # edge1 = v1 - v0 + # edge2 = v2 - v0 + + # # 计算叉积的模长(向量面积的两倍) + # cross_product = np.cross(edge1, edge2) + # areas = 0.5 * np.linalg.norm(cross_product, axis=1) + + # return areas + + +def get_connected_region(face_ids, adjacent_faces, return_face_part_ids=False): + vis = [False] * len(face_ids) + parts = [] + face_part_ids = np.ones_like(face_ids) * -1 + for i in range(len(face_ids)): + if vis[i]: + continue + _part = [] + _queue = [i] + while len(_queue) > 0: + 
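# Breadth-first flood fill: collect every face reachable from face i through
+                # neighbors carrying the same label; rows of `adjacent_faces` are padded
+                # with -1 to mark the end of each neighbor list.
+ 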
_cur_face = _queue.pop(0) + if vis[_cur_face]: + continue + vis[_cur_face] = True + _part.append(_cur_face) + face_part_ids[_cur_face] = len(parts) + if not (0 <= _cur_face < adjacent_faces.shape[0]): + continue + _cur_face_id = face_ids[_cur_face] + _adj_faces = adjacent_faces[_cur_face] + for j in _adj_faces: + if j == -1: + break + if not vis[j] and face_ids[j] == _cur_face_id: + _queue.append(j) + parts.append(_part) + if return_face_part_ids: + return parts, face_part_ids + else: + return parts + + +def aabb_distance(box1, box2): + """ + 计算两个轴对齐包围盒(AABB)之间的最近距离。 + :param box1: 元组 (min_x, min_y, min_z, max_x, max_y, max_z) + :param box2: 元组 (min_x, min_y, min_z, max_x, max_y, max_z) + :return: 最近距离(浮点数) + """ + # 解包坐标 + min1, max1 = box1 + min2, max2 = box2 + + # 计算各轴上的分离距离 + dx = max(0, max2[0] - min1[0], max1[0] - min2[0]) # x轴分离距离 + dy = max(0, max2[1] - min1[1], max1[1] - min2[1]) # y轴分离距离 + dz = max(0, max2[2] - min1[2], max1[2] - min2[2]) # z轴分离距离 + + # 如果所有轴都重叠,则距离为0 + if dx == 0 and dy == 0 and dz == 0: + return 0.0 + + # 计算欧几里得距离 + return np.sqrt(dx**2 + dy**2 + dz**2) + + +def aabb_volume(aabb): + """ + 计算轴对齐包围盒(AABB)的体积。 + :param aabb: 元组 (min_x, min_y, min_z, max_x, max_y, max_z) + :return: 体积(浮点数) + """ + # 解包坐标 + min_xyz, max_xyz = aabb + + # 计算体积 + dx = max_xyz[0] - min_xyz[0] + dy = max_xyz[1] - min_xyz[1] + dz = max_xyz[2] - min_xyz[2] + return dx * dy * dz + + +def find_neighbor_part(parts, adjacent_faces, parts_aabb=None, parts_ids=None): + face2part = {} + for i, part in enumerate(parts): + for face in part: + face2part[face] = i + neighbor_parts = [] + for i, part in enumerate(parts): + neighbor_part = set() + for face in part: + if not (0 <= face < adjacent_faces.shape[0]): + continue + for adj_face in adjacent_faces[face]: + if adj_face == -1: + break + if adj_face not in face2part: + continue + if face2part[adj_face] == i: + continue + if parts_ids is not None and parts_ids[face2part[adj_face]] in [-1, -2]: + continue + neighbor_part.add(face2part[adj_face]) + neighbor_part = list(neighbor_part) + if ( + parts_aabb is not None + and parts_ids is not None + and (parts_ids[i] == -1 or parts_ids[i] == -2) + and len(neighbor_part) == 0 + ): + min_dis = np.inf + min_idx = -1 + for j, _part in enumerate(parts): + if j == i: + continue + if parts_ids[j] == -1 or parts_ids[j] == -2: + continue + aabb_1 = parts_aabb[i] + aabb_2 = parts_aabb[j] + dis = aabb_distance(aabb_1, aabb_2) + if dis < min_dis: + min_dis = dis + min_idx = j + elif dis == min_dis: + if aabb_volume(parts_aabb[j]) < aabb_volume(parts_aabb[min_idx]): + min_idx = j + neighbor_part = [min_idx] + neighbor_parts.append(neighbor_part) + return neighbor_parts + + +def do_post_process( + face_areas, parts, adjacent_faces, face_ids, threshold=0.95, show_info=False +): + # # 获取邻接面片 + # mesh_save = mesh.copy() + # face_adjacency = mesh.face_adjacency + # adjacent_faces = {} + # for face1, face2 in face_adjacency: + # if face1 not in adjacent_faces: + # adjacent_faces[face1] = [] + # if face2 not in adjacent_faces: + # adjacent_faces[face2] = [] + # adjacent_faces[face1].append(face2) + # adjacent_faces[face2].append(face1) + + # parts = get_connected_region(face_ids, adjacent_faces) + + unique_ids = np.unique(face_ids) + if show_info: + print(f"连通区域数量:{len(parts)}") + print(f"ID数量:{len(unique_ids)}") + + # face_areas = calculate_face_areas(mesh) + total_area = np.sum(face_areas) + if show_info: + print(f"总面积:{total_area}") + part_areas = [] + for i, part in enumerate(parts): + part_area = 
np.sum(face_areas[part]) + part_areas.append(float(part_area / total_area)) + + sorted_parts = sorted(zip(part_areas, parts), key=lambda x: x[0], reverse=True) + parts = [x[1] for x in sorted_parts] + part_areas = [x[0] for x in sorted_parts] + integral_part_areas = np.cumsum(part_areas) + + neighbor_parts = find_neighbor_part(parts, adjacent_faces) + + new_face_ids = face_ids.copy() + + for i, part in enumerate(parts): + if integral_part_areas[i] > threshold and part_areas[i] < 0.01: + if len(neighbor_parts[i]) > 0: + max_area = 0 + max_part = -1 + for j in neighbor_parts[i]: + if integral_part_areas[j] > threshold: + continue + if part_areas[j] > max_area: + max_area = part_areas[j] + max_part = j + if max_part != -1: + if show_info: + print(f"合并mesh:{i} {max_part}") + parts[max_part].extend(part) + parts[i] = [] + target_face_id = face_ids[parts[max_part][0]] + for face in part: + new_face_ids[face] = target_face_id + + return new_face_ids + + +def do_no_mask_process(parts, face_ids): + # # 获取邻接面片 + # mesh_save = mesh.copy() + # face_adjacency = mesh.face_adjacency + # adjacent_faces = {} + # for face1, face2 in face_adjacency: + # if face1 not in adjacent_faces: + # adjacent_faces[face1] = [] + # if face2 not in adjacent_faces: + # adjacent_faces[face2] = [] + # adjacent_faces[face1].append(face2) + # adjacent_faces[face2].append(face1) + # parts = get_connected_region(face_ids, adjacent_faces) + + unique_ids = np.unique(face_ids) + max_id = np.max(unique_ids) + if -1 or -2 in unique_ids: + new_face_ids = face_ids.copy() + for i, part in enumerate(parts): + if face_ids[part[0]] == -1 or face_ids[part[0]] == -2: + for face in part: + new_face_ids[face] = max_id + 1 + max_id += 1 + return new_face_ids + else: + return face_ids + + +def union_aabb(aabb1, aabb2): + min_xyz1 = aabb1[0] + max_xyz1 = aabb1[1] + min_xyz2 = aabb2[0] + max_xyz2 = aabb2[1] + min_xyz = np.minimum(min_xyz1, min_xyz2) + max_xyz = np.maximum(max_xyz1, max_xyz2) + return [min_xyz, max_xyz] + + +def aabb_increase(aabb1, aabb2): + min_xyz_before = aabb1[0] + max_xyz_before = aabb1[1] + min_xyz_after, max_xyz_after = union_aabb(aabb1, aabb2) + min_xyz_increase = np.abs(min_xyz_after - min_xyz_before) / np.abs(min_xyz_before) + max_xyz_increase = np.abs(max_xyz_after - max_xyz_before) / np.abs(max_xyz_before) + return min_xyz_increase, max_xyz_increase + + +def sort_multi_list(multi_list, key=lambda x: x[0], reverse=False): + """ + multi_list: [list1, list2, list3, list4, ...], len(list1)=N, len(list2)=N, len(list3)=N, ... 
+ key: 排序函数,默认按第一个元素排序 + reverse: 排序顺序,默认降序 + return: + [list1, list2, list3, list4, ...]: 按同一个顺序排序后的多个list + """ + sorted_list = sorted(zip(*multi_list), key=key, reverse=reverse) + return zip(*sorted_list) + + +# def sample_mesh(mesh, adjacent_faces, point_num=100000): +# connected_parts = get_connected_region(np.ones(len(mesh.faces)), adjacent_faces) +# _points, face_idx = trimesh.sample.sample_surface(mesh, point_num) +# face_sampled = np.zeros(len(mesh.faces), dtype=np.bool) +# face_sampled[face_idx] = True +# for parts in connected_parts + +# def parallel_run(model_parallel, feats, points, prompts): +# bs = prompts.shape[0] +# prompts_1 = prompts[:bs//2] +# prompts_2 = prompts[bs//2:] +# device_1 = 'cuda:0' +# device_2 = 'cuda:1' +# pred_mask_1_1, pred_mask_2_1, pred_mask_3_1, pred_iou_1 = get_mask( +# model_parallel.module.to(device_1), feats, points, prompts_1, device=device_1 +# ) +# pred_mask_1_2, pred_mask_2_2, pred_mask_3_2, pred_iou_2 = get_mask( +# model_parallel.module.to(device_2), feats, points, prompts_2, device=device_2 +# ) +# pred_mask_1 = np.concatenate([pred_mask_1_1, pred_mask_1_2], axis=1) +# pred_mask_2 = np.concatenate([pred_mask_2_1, pred_mask_2_2], axis=1) +# pred_mask_3 = np.concatenate([pred_mask_3_1, pred_mask_3_2], axis=1) +# pred_iou = np.concatenate([pred_iou_1, pred_iou_2], axis=0) +# return pred_mask_1, pred_mask_2, pred_mask_3, pred_iou + +############################################################################################ + + +class Timer: + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start_time = time.time() + return self # 可以返回 self 以便在 with 块内访问 + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() + self.elapsed_time = self.end_time - self.start_time + print(f">>>>>>代码{self.name} 运行时间: {self.elapsed_time:.4f} 秒") + + +###################### NUMBA 加速 ###################### +@njit +def build_adjacent_faces_numba(face_adjacency): + """ + 使用 Numba 加速构建邻接面片数组。 + :param face_adjacency: (N, 2) numpy 数组,包含邻接面片对。 + :return: + - adj_list: 一维数组,存储所有邻接面片。 + - offsets: 一维数组,记录每个面片的邻接起始位置。 + """ + n_faces = np.max(face_adjacency) + 1 # 总面片数 + n_edges = face_adjacency.shape[0] # 总邻接边数 + + # 第一步:统计每个面片的邻接数量(度数) + degrees = np.zeros(n_faces, dtype=np.int32) + for i in range(n_edges): + f1, f2 = face_adjacency[i] + degrees[f1] += 1 + degrees[f2] += 1 + max_degree = np.max(degrees) # 最大度数 + + adjacent_faces = np.ones((n_faces, max_degree), dtype=np.int32) * -1 # 邻接面片数组 + adjacent_faces_count = np.zeros(n_faces, dtype=np.int32) # 邻接面片计数器 + for i in range(n_edges): + f1, f2 = face_adjacency[i] + adjacent_faces[f1, adjacent_faces_count[f1]] = f2 + adjacent_faces_count[f1] += 1 + adjacent_faces[f2, adjacent_faces_count[f2]] = f1 + adjacent_faces_count[f2] += 1 + return adjacent_faces + + +###################### NUMBA 加速 ###################### + + +def mesh_sam( + model, + mesh, + save_path, + point_num=100000, + prompt_num=400, + save_mid_res=False, + show_info=False, + post_process=False, + threshold=0.95, + clean_mesh_flag=True, + seed=42, +): + with Timer("加载mesh"): + model, model_parallel = model + if clean_mesh_flag: + mesh = clean_mesh(mesh) + mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=False) + if show_info: + print(f"点数:{mesh.vertices.shape[0]} 面片数:{mesh.faces.shape[0]}") + + point_num = 100000 + prompt_num = 400 + with Timer("获取邻接面片"): + # 获取邻接面片 + face_adjacency = mesh.face_adjacency + with Timer("处理邻接面片"): + # adjacent_faces = defaultdict(list) + # for face1, 
face2 in face_adjacency: + # adjacent_faces[face1].append(face2) + # adjacent_faces[face2].append(face1) + # adj_list, offsets = build_adjacent_faces_numba(face_adjacency) + adjacent_faces = build_adjacent_faces_numba(face_adjacency) + # with Timer("处理邻接面片2"): + # adjacent_faces = to_adj_dict(adj_list, offsets) + + with Timer("采样点云"): + _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed) + _points_org = _points.copy() + _points = normalize_pc(_points) + normals = mesh.face_normals[face_idx] + # _points = _points + np.random.normal(0, 1, size=_points.shape) * 0.01 + # normals = normals * 0. # debug no normal + if show_info: + print(f"点数:{point_num} 面片数:{mesh.faces.shape[0]}") + + with Timer("获取特征"): + _feats = get_feat(model, _points, normals) + if show_info: + print("预处理特征") + + if save_mid_res: + feat_save = _feats.float().detach().cpu().numpy() + data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True) + pca = PCA(n_components=3) + data_reduced = pca.fit_transform(data_scaled) + data_reduced = (data_reduced - data_reduced.min()) / ( + data_reduced.max() - data_reduced.min() + ) + _colors_pca = (data_reduced * 255).astype(np.uint8) + pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca) + pc_save.export(os.path.join(save_path, "point_pca.glb")) + pc_save.export(os.path.join(save_path, "point_pca.ply")) + if show_info: + print("PCA获取特征颜色") + + with Timer("FPS采样提示点"): + fps_idx = fpsample.fps_sampling(_points, prompt_num) + _point_prompts = _points[fps_idx] + if save_mid_res: + trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export( + os.path.join(save_path, "point_prompts_pca.glb") + ) + trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export( + os.path.join(save_path, "point_prompts_pca.ply") + ) + if show_info: + print("采样完成") + + with Timer("推理"): + bs = 64 + step_num = prompt_num // bs + 1 + mask_res = [] + iou_res = [] + for i in tqdm(range(step_num), disable=not show_info): + cur_propmt = _point_prompts[bs * i : bs * (i + 1)] + # pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask( + # model, _feats, _points, cur_propmt + # ) + # pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = model_parallel( + # _feats, _points, cur_propmt + # ) + # pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = parallel_run( + # model_parallel, _feats, _points, cur_propmt + # ) + pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask( + model_parallel, _feats, _points, cur_propmt + ) + # print(pred_mask_1.shape, pred_mask_2.shape, pred_mask_3.shape, pred_iou.shape) + pred_mask = np.stack( + [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1 + ) # [N, K, 3] + max_idx = np.argmax(pred_iou, axis=-1) # [K] + for j in range(max_idx.shape[0]): + mask_res.append(pred_mask[:, j, max_idx[j]]) + iou_res.append(pred_iou[j, max_idx[j]]) + mask_res = np.stack(mask_res, axis=-1) # [N, K] + if show_info: + print("prmopt 推理完成") + + with Timer("根据IOU排序"): + iou_res = np.array(iou_res).tolist() + mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)] + mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True) + mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)] + iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)] + + # clusters = {} + # for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info): + # _mask = mask_sorted[i] + # union_flag = False + # for j in clusters.keys(): + # if cal_iou(_mask, mask_sorted[j]) > 0.9: + # clusters[j].append(i) + # union_flag = 
True + # break + # if not union_flag: + # clusters[i] = [i] + with Timer("NMS"): + clusters = defaultdict(list) + with ThreadPoolExecutor(max_workers=20) as executor: + for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info): + _mask = mask_sorted[i] + futures = [] + for j in clusters.keys(): + futures.append(executor.submit(cal_iou, _mask, mask_sorted[j])) + + for j, future in zip(clusters.keys(), futures): + if future.result() > 0.9: + clusters[j].append(i) + break + else: + clusters[i].append(i) + + # print(clusters) + if show_info: + print(f"NMS完成,mask数量:{len(clusters)}") + + if save_mid_res: + part_mask_save_path = os.path.join(save_path, "part_mask") + if os.path.exists(part_mask_save_path): + shutil.rmtree(part_mask_save_path) + os.makedirs(part_mask_save_path, exist_ok=True) + for i in tqdm(clusters.keys(), desc="保存mask", disable=not show_info): + cluster_num = len(clusters[i]) + cluster_iou = iou_sorted[i] + cluster_area = np.sum(mask_sorted[i]) + if cluster_num <= 2: + continue + mask_save = mask_sorted[i] + mask_save = np.expand_dims(mask_save, axis=-1) + mask_save = np.repeat(mask_save, 3, axis=-1) + mask_save = (mask_save * 255).astype(np.uint8) + point_save = trimesh.points.PointCloud(_points, colors=mask_save) + point_save.export( + os.path.join( + part_mask_save_path, + f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb", + ) + ) + + # 过滤只有一个mask的cluster + with Timer("过滤只有一个mask的cluster"): + filtered_clusters = [] + other_clusters = [] + for i in clusters.keys(): + if len(clusters[i]) > 2: + filtered_clusters.append(i) + else: + other_clusters.append(i) + if show_info: + print( + f"过滤前:{len(clusters)} 个cluster," + f"过滤后:{len(filtered_clusters)} 个cluster" + ) + + # 再次合并 + with Timer("再次合并"): + filtered_clusters_num = len(filtered_clusters) + cluster2 = {} + is_union = [False] * filtered_clusters_num + for i in range(filtered_clusters_num): + if is_union[i]: + continue + cur_cluster = filtered_clusters[i] + cluster2[cur_cluster] = [cur_cluster] + for j in range(i + 1, filtered_clusters_num): + if is_union[j]: + continue + tar_cluster = filtered_clusters[j] + # if cal_single_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.9: + # if cal_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.5: + if ( + cal_bbox_iou( + _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster] + ) + > 0.5 + ): + cluster2[cur_cluster].append(tar_cluster) + is_union[j] = True + if show_info: + print(f"再次合并,合并数量:{len(cluster2.keys())}") + + with Timer("计算没有mask的点"): + no_mask = np.ones(point_num) + for i in cluster2: + part_mask = mask_sorted[i] + no_mask[part_mask] = 0 + if show_info: + print( + f"{np.sum(no_mask == 1)} 个点没有mask," + f" 占比:{np.sum(no_mask == 1) / point_num:.4f}" + ) + + with Timer("修补遗漏mask"): + # 查询漏掉的mask + for i in tqdm(range(len(mask_sorted)), desc="漏掉mask", disable=not show_info): + if i in cluster2: + continue + part_mask = mask_sorted[i] + _iou = cal_single_iou(part_mask, no_mask) + if _iou > 0.7: + cluster2[i] = [i] + no_mask[part_mask] = 0 + if save_mid_res: + mask_save = mask_sorted[i] + mask_save = np.expand_dims(mask_save, axis=-1) + mask_save = np.repeat(mask_save, 3, axis=-1) + mask_save = (mask_save * 255).astype(np.uint8) + point_save = trimesh.points.PointCloud(_points, colors=mask_save) + cluster_iou = iou_sorted[i] + cluster_area = int(np.sum(mask_sorted[i])) + cluster_num = 1 + point_save.export( + os.path.join( + part_mask_save_path, + 
f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb", + ) + ) + # print(cluster2) + # print(len(cluster2.keys())) + if show_info: + print(f"修补遗漏mask:{len(cluster2.keys())}") + + with Timer("计算点云最终mask"): + final_mask = list(cluster2.keys()) + final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask] + final_mask_area = [ + [final_mask[i], final_mask_area[i]] for i in range(len(final_mask)) + ] + final_mask_area_sorted = sorted( + final_mask_area, key=lambda x: x[1], reverse=True + ) + final_mask_sorted = [ + final_mask_area_sorted[i][0] for i in range(len(final_mask_area)) + ] + final_mask_area_sorted = [ + final_mask_area_sorted[i][1] for i in range(len(final_mask_area)) + ] + # print(final_mask_sorted) + # print(final_mask_area_sorted) + if show_info: + print(f"最终mask数量:{len(final_mask_sorted)}") + + with Timer("点云上色"): + # 生成color map + color_map = {} + for i in final_mask_sorted: + part_color = np.random.rand(3) * 255 + color_map[i] = part_color + # print(color_map) + + result_mask = -np.ones(point_num, dtype=np.int64) + for i in final_mask_sorted: + part_mask = mask_sorted[i] + result_mask[part_mask] = i + if save_mid_res: + # 保存点云结果 + result_colors = np.zeros_like(_colors_pca) + for i in final_mask_sorted: + part_color = color_map[i] + part_mask = mask_sorted[i] + result_colors[part_mask, :3] = part_color + trimesh.points.PointCloud(_points, colors=result_colors).export( + os.path.join(save_path, "auto_mask_cluster.glb") + ) + trimesh.points.PointCloud(_points, colors=result_colors).export( + os.path.join(save_path, "auto_mask_cluster.ply") + ) + if show_info: + print("保存点云完成") + + with Timer("投影Mesh并统计label"): + # 保存mesh结果 + face_seg_res = {} + for i in final_mask_sorted: + _part_mask = result_mask == i + _face_idx = face_idx[_part_mask] + for k in _face_idx: + if k not in face_seg_res: + face_seg_res[k] = [] + face_seg_res[k].append(i) + _part_mask = result_mask == -1 + _face_idx = face_idx[_part_mask] + for k in _face_idx: + if k not in face_seg_res: + face_seg_res[k] = [] + face_seg_res[k].append(-1) + + face_ids = -np.ones(len(mesh.faces), dtype=np.int64) * 2 + for i in tqdm(face_seg_res, leave=False, disable=True): + _seg_ids = np.array(face_seg_res[i]) + # 获取最多的seg_id + _max_id = np.argmax(np.bincount(_seg_ids + 2)) - 2 + face_ids[i] = _max_id + face_ids_org = face_ids.copy() + if show_info: + print("生成face_ids完成") + + # 获取邻接面片 + # face_adjacency = mesh.face_adjacency + # adjacent_faces = {} + # for face1, face2 in face_adjacency: + # if face1 not in adjacent_faces: + # adjacent_faces[face1] = [] + # if face2 not in adjacent_faces: + # adjacent_faces[face2] = [] + # adjacent_faces[face1].append(face2) + # adjacent_faces[face2].append(face1) + + with Timer("第一次修复face_ids"): + face_ids += 1 + # face_ids = fix_label(face_ids, adjacent_faces, use_aabb=True, mesh=mesh, show_info=show_info) + face_ids = fix_label(face_ids, adjacent_faces, mesh=mesh, show_info=show_info) + face_ids -= 1 + if show_info: + print("修复face_ids完成") + + color_map[-1] = np.array([255, 0, 0], dtype=np.uint8) + + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh.glb"), mesh, face_ids, color_map + ) + save_mesh( + os.path.join(save_path, "auto_mask_mesh_org.glb"), + mesh, + face_ids_org, + color_map, + ) + if show_info: + print("保存mesh结果完成") + + with Timer("计算连通区域"): + face_areas = calculate_face_areas(mesh) + mesh_total_area = np.sum(face_areas) + parts = get_connected_region(face_ids, adjacent_faces) + connected_parts, _face_connected_parts_ids = 
get_connected_region( + np.ones_like(face_ids), adjacent_faces, return_face_part_ids=True + ) + if show_info: + print(f"共{len(parts)}个mesh") + with Timer("排序连通区域"): + parts_cp_idx = [] + for x in parts: + _face_idx = x[0] + parts_cp_idx.append(_face_connected_parts_ids[_face_idx]) + parts_cp_idx = np.array(parts_cp_idx) + parts_areas = [float(np.sum(face_areas[x])) for x in parts] + connected_parts_areas = [float(np.sum(face_areas[x])) for x in connected_parts] + parts_cp_areas = [connected_parts_areas[x] for x in parts_cp_idx] + parts_sorted, parts_areas_sorted, parts_cp_areas_sorted = sort_multi_list( + [parts, parts_areas, parts_cp_areas], key=lambda x: x[1], reverse=True + ) + + with Timer("去除面积过小的区域"): + filtered_parts = [] + other_parts = [] + for i in range(len(parts_sorted)): + parts = parts_sorted[i] + area = parts_areas_sorted[i] + cp_area = parts_cp_areas_sorted[i] + if area / (cp_area + 1e-7) > 0.001: + filtered_parts.append(i) + else: + other_parts.append(i) + if show_info: + print(f"保留{len(filtered_parts)}个mesh, 其他{len(other_parts)}个mesh") + + with Timer("去除面积过小区域的label"): + face_ids_2 = face_ids.copy() + part_num = len(cluster2.keys()) + for j in other_parts: + parts = parts_sorted[j] + for i in parts: + face_ids_2[i] = -1 + + with Timer("第二次修复face_ids"): + face_ids_3 = face_ids_2.copy() + # face_ids_3 = fix_label(face_ids_3, adjacent_faces, use_aabb=True, mesh=mesh, show_info=show_info) + face_ids_3 = fix_label( + face_ids_3, adjacent_faces, mesh=mesh, show_info=show_info + ) + + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh_filtered_2.glb"), + mesh, + face_ids_3, + color_map, + ) + if show_info: + print("保存mesh结果完成") + + with Timer("第二次计算连通区域"): + parts_2 = get_connected_region(face_ids_3, adjacent_faces) + parts_areas_2 = [float(np.sum(face_areas[x])) for x in parts_2] + parts_ids_2 = [face_ids_3[x[0]] for x in parts_2] + + with Timer("添加过大的缺失part"): + color_map_2 = copy.deepcopy(color_map) + max_id = np.max(parts_ids_2) + for i in range(len(parts_2)): + _parts = parts_2[i] + _area = parts_areas_2[i] + _parts_id = face_ids_3[_parts[0]] + if _area / mesh_total_area > 0.001: + if _parts_id == -1 or _parts_id == -2: + parts_ids_2[i] = max_id + 1 + max_id += 1 + color_map_2[max_id] = np.random.rand(3) * 255 + if show_info: + print(f"新增part {max_id}") + # else: + # parts_ids_2[i] = -1 + + with Timer("赋值新的face_ids"): + face_ids_4 = face_ids_3.copy() + for i in range(len(parts_2)): + _parts = parts_2[i] + _parts_id = parts_ids_2[i] + for j in _parts: + face_ids_4[j] = _parts_id + with Timer("计算part和label的aabb"): + ids_aabb = {} + unique_ids = np.unique(face_ids_4) + for i in unique_ids: + if i < 0: + continue + _part_mask = face_ids_4 == i + _faces = mesh.faces[_part_mask] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + ids_aabb[i] = [min_xyz, max_xyz] + + parts_2_aabb = [] + for i in range(len(parts_2)): + _parts = parts_2[i] + _faces = mesh.faces[_parts] + _faces = np.reshape(_faces, (-1)) + _points = mesh.vertices[_faces] + min_xyz = np.min(_points, axis=0) + max_xyz = np.max(_points, axis=0) + parts_2_aabb.append([min_xyz, max_xyz]) + + with Timer("计算part的邻居"): + parts_2_neighbor = find_neighbor_part( + parts_2, adjacent_faces, parts_2_aabb, parts_ids_2 + ) + with Timer("合并无mask区域"): + for i in range(len(parts_2)): + _parts = parts_2[i] + _ids = parts_ids_2[i] + if _ids == -1 or _ids == -2: + _cur_aabb = parts_2_aabb[i] + _min_aabb_increase = 1e10 + 
_min_id = -1 + for j in parts_2_neighbor[i]: + if parts_ids_2[j] == -1 or parts_ids_2[j] == -2: + continue + _tar_id = parts_ids_2[j] + _tar_aabb = ids_aabb[_tar_id] + _min_increase, _max_increase = aabb_increase(_tar_aabb, _cur_aabb) + _increase = max(np.max(_min_increase), np.max(_max_increase)) + if _min_aabb_increase > _increase: + _min_aabb_increase = _increase + _min_id = _tar_id + if _min_id >= 0: + parts_ids_2[i] = _min_id + + with Timer("再次赋值新的face_ids"): + face_ids_4 = face_ids_3.copy() + for i in range(len(parts_2)): + _parts = parts_2[i] + _parts_id = parts_ids_2[i] + for j in _parts: + face_ids_4[j] = _parts_id + + final_face_ids = face_ids_4 + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh_final.glb"), + mesh, + face_ids_4, + color_map_2, + ) + + if post_process: + parts = get_connected_region(final_face_ids, adjacent_faces) + final_face_ids = do_no_mask_process(parts, final_face_ids) + face_ids_5 = do_post_process( + face_areas, + parts, + adjacent_faces, + face_ids_4, + threshold, + show_info=show_info, + ) + if save_mid_res: + save_mesh( + os.path.join(save_path, "auto_mask_mesh_final_post.glb"), + mesh, + face_ids_5, + color_map_2, + ) + final_face_ids = face_ids_5 + with Timer("计算最后的aabb"): + aabb = get_aabb_from_face_ids(mesh, final_face_ids) + return aabb, final_face_ids, mesh + + +class AutoMask: + def __init__( + self, + ckpt_path, + point_num=100000, + prompt_num=400, + threshold=0.95, + post_process=True, + ): + """ + ckpt_path: str, 模型路径 + point_num: int, 采样点数量 + prompt_num: int, 提示数量 + threshold: float, 阈值 + post_process: bool, 是否后处理 + """ + self.model = YSAM() + self.model.load_state_dict( + state_dict=torch.load(ckpt_path, map_location="cpu")["state_dict"] + ) + self.model.eval() + self.model_parallel = torch.nn.DataParallel(self.model) + self.model.cuda() + self.model_parallel.cuda() + self.point_num = point_num + self.prompt_num = prompt_num + self.threshold = threshold + self.post_process = post_process + + def predict_aabb( + self, + mesh, + point_num=None, + prompt_num=None, + threshold=None, + post_process=None, + save_path=None, + save_mid_res=False, + show_info=True, + clean_mesh_flag=True, + seed=42, + ): + """ + Parameters: + mesh: trimesh.Trimesh, 输入网格 + point_num: int, 采样点数量 + prompt_num: int, 提示数量 + threshold: float, 阈值 + post_process: bool, 是否后处理 + Returns: + aabb: np.ndarray, 包围盒 + face_ids: np.ndarray, 面id + """ + point_num = point_num if point_num is not None else self.point_num + prompt_num = prompt_num if prompt_num is not None else self.prompt_num + threshold = threshold if threshold is not None else self.threshold + post_process = post_process if post_process is not None else self.post_process + return mesh_sam( + [self.model, self.model_parallel], + mesh, + save_path=save_path, + point_num=point_num, + prompt_num=prompt_num, + threshold=threshold, + post_process=post_process, + show_info=show_info, + save_mid_res=save_mid_res, + clean_mesh_flag=clean_mesh_flag, + seed=seed, + ) diff --git a/XPart/partgen/config/infer.yaml b/XPart/partgen/config/infer.yaml new file mode 100755 index 0000000000000000000000000000000000000000..f553e7fe88bd9e442d561061d4e3aa9f4fecb188 --- /dev/null +++ b/XPart/partgen/config/infer.yaml @@ -0,0 +1,122 @@ +name: "Xpart Pipeline release" + +ckpt_path: checkpoints/xpart.pt + +shapevae: + target: partgen.models.autoencoders.VolumeDecoderShapeVAE + params: + num_latents: &num_latents 1024 + embed_dim: 64 + num_freqs: 8 + include_pi: false + heads: 16 + width: 1024 + num_encoder_layers: 8 + 
num_decoder_layers: 16 + qkv_bias: false + qk_norm: true + scale_factor: &z_scale_factor 1.0039506158752403 + geo_decoder_mlp_expand_ratio: 4 + geo_decoder_downsample_ratio: 1 + geo_decoder_ln_post: true + point_feats: 4 + pc_size: &pc_size 81920 + pc_sharpedge_size: &pc_sharpedge_size 0 + +bbox_predictor: + target: partgen.bbox_estimator.auto_mask_api.AutoMask + params: + ckpt_path: checkpoints/p3sam.ckpt +conditioner: + target: partgen.models.conditioner.condioner_release.Conditioner + params: + use_geo: true + use_obj: true + use_seg_feat: true + geo_cfg: + target: partgen.models.conditioner.part_encoders.PartEncoder + output_dim: &cross2_output_dim 1024 + params: + use_local: true + local_feat_type: latents_shape # [latents,miche-point-query-structural-vae] + num_tokens_cond: &num_tokens_cond 4096 # num_tokens :2048 for holopart conditioner + local_geo_cfg: + target: partgen.models.autoencoders.VolumeDecoderShapeVAE + params: + num_latents: *num_tokens_cond + embed_dim: 64 + num_freqs: 8 + include_pi: false + heads: 16 + width: 1024 + num_encoder_layers: 8 + num_decoder_layers: 16 + qkv_bias: false + qk_norm: true + scale_factor: *z_scale_factor + geo_decoder_mlp_expand_ratio: 4 + geo_decoder_downsample_ratio: 1 + geo_decoder_ln_post: true + point_feats: 4 + pc_size: &pc_size_bbox 81920 + pc_sharpedge_size: &pc_sharpedge_size_bbox 0 + + obj_encoder_cfg: + target: partgen.models.autoencoders.VolumeDecoderShapeVAE + output_dim: &cross1_output_dim 1024 + params: + num_latents: 4096 + embed_dim: 64 + num_freqs: 8 + include_pi: false + heads: 16 + width: 1024 + num_encoder_layers: 8 + num_decoder_layers: 16 + qkv_bias: false + qk_norm: true + scale_factor: 1.0039506158752403 + geo_decoder_mlp_expand_ratio: 4 + geo_decoder_downsample_ratio: 1 + geo_decoder_ln_post: true + point_feats: 4 + pc_size: *pc_size + pc_sharpedge_size: *pc_sharpedge_size + seg_feat_cfg: + target: partgen.models.conditioner.sonata_extractor.SonataFeatureExtractor + +model: + target: partgen.models.partformer_dit.PartFormerDITPlain + params: + use_self_attention: true + use_cross_attention: true + use_cross_attention_2: true + # cond + use_bbox_cond: false + num_freqs: 8 + use_part_embed: true + valid_num: 50 #*valid_num + # para + input_size: *num_latents + in_channels: 64 + hidden_size: 2048 + encoder_hidden_dim: *cross1_output_dim # for object mesh + encoder_hidden2_dim: *cross2_output_dim # for part in bbox + depth: 21 + num_heads: 16 + qk_norm: true + qkv_bias: false + qk_norm_type: 'rms' + with_decoupled_ca: false + decoupled_ca_dim: *num_tokens_cond + decoupled_ca_weight: 1.0 + use_attention_pooling: false + use_pos_emb: false + num_moe_layers: 6 + num_experts: 8 + moe_top_k: 2 + +scheduler: + target: partgen.models.diffusion.schedulers.FlowMatchEulerDiscreteScheduler + params: + num_train_timesteps: 1000 \ No newline at end of file diff --git a/XPart/partgen/config/sonata.json b/XPart/partgen/config/sonata.json new file mode 100755 index 0000000000000000000000000000000000000000..3070793d6f38ac68aeba9e27f6b64989fcd51df3 --- /dev/null +++ b/XPart/partgen/config/sonata.json @@ -0,0 +1,58 @@ +{ + "in_channels": 9, + "order": [ + "z", + "z-trans", + "hilbert", + "hilbert-trans" + ], + "stride": [ + 2, + 2, + 2, + 2 + ], + "enc_depths": [ + 3, + 3, + 3, + 12, + 3 + ], + "enc_channels": [ + 48, + 96, + 192, + 384, + 512 + ], + "enc_num_head": [ + 3, + 6, + 12, + 24, + 32 + ], + "enc_patch_size": [ + 1024, + 1024, + 1024, + 1024, + 1024 + ], + "mlp_ratio": 4, + "qkv_bias": true, + "qk_scale": null, + "attn_drop": 0.0, 
+ "proj_drop": 0.0, + "drop_path": 0.3, + "shuffle_orders": true, + "pre_norm": true, + "enable_rpe": false, + "enable_flash": true, + "upcast_attention": false, + "upcast_softmax": false, + "traceable": true, + "enc_mode": true, + "mask_token": true +} \ No newline at end of file diff --git a/XPart/partgen/models/autoencoders/__init__.py b/XPart/partgen/models/autoencoders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..91c43c3ea1e46c0dabf41492e586f63b035d5b6c --- /dev/null +++ b/XPart/partgen/models/autoencoders/__init__.py @@ -0,0 +1,29 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +from .attention_blocks import CrossAttentionDecoder +from .attention_processors import ( + CrossAttentionProcessor, +) +from .model import VectsetVAE, VolumeDecoderShapeVAE + +from .surface_extractors import ( + SurfaceExtractors, + MCSurfaceExtractor, + DMCSurfaceExtractor, + Latent2MeshOutput, +) +from .volume_decoders import ( + VanillaVolumeDecoder, +) diff --git a/XPart/partgen/models/autoencoders/attention_blocks.py b/XPart/partgen/models/autoencoders/attention_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..cf11b3201a3420ea2799c0c089239bc06008d494 --- /dev/null +++ b/XPart/partgen/models/autoencoders/attention_blocks.py @@ -0,0 +1,770 @@ +# Open Source Model Licensed under the Apache License Version 2.0 +# and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited +# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. 
+ +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + + +import os +from typing import Optional, Union, List + +import torch +import torch.nn as nn +from einops import rearrange +from torch import Tensor + +from .attention_processors import CrossAttentionProcessor +from ...utils.misc import logger + +scaled_dot_product_attention = nn.functional.scaled_dot_product_attention + +if os.environ.get("USE_SAGEATTN", "0") == "1": + try: + from sageattention import sageattn + except ImportError: + raise ImportError( + 'Please install the package "sageattention" to use this USE_SAGEATTN.' + ) + scaled_dot_product_attention = sageattn + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__( + self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True, + ) -> None: + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward process. 
+ + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view( + *x.shape[:-1], -1 + ) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob, 3):0.3f}" + + +class MLP(nn.Module): + def __init__( + self, + *, + width: int, + expand_ratio: int = 4, + output_width: int = None, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * expand_ratio) + self.c_proj = nn.Linear( + width * expand_ratio, output_width if output_width is not None else width + ) + self.gelu = nn.GELU() + self.drop_path = ( + DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + ) + + def forward(self, x): + return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.heads = heads + self.q_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + + self.attn_processor = CrossAttentionProcessor() + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k, v = map( + lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v) + ) + out = self.attn_processor(self, q, k, v) + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + data_width: Optional[int] = None, + norm_layer=nn.LayerNorm, + qk_norm: 
bool = False, + kv_cache: bool = False, + ): + super().__init__() + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadCrossAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm, + ) + self.kv_cache = kv_cache + self.data = None + + def forward(self, x, data): + x = self.c_q(x) + if self.kv_cache: + if self.data is None: + self.data = self.c_kv(data) + logger.info( + "Save kv cache,this should be called only once for one mesh" + ) + data = self.data + else: + data = self.c_kv(data) + x = self.attention(x, data) + x = self.c_proj(x) + return x + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + data_width: Optional[int] = None, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6) + self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__( + self, *, heads: int, width=None, qk_norm=False, norm_layer=nn.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + norm_layer(width // heads, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k, v = map( + lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v) + ) + out = ( + scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) + ) + return out + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm, + ) + self.drop_path = ( + DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + ) + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.drop_path(self.c_proj(x)) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.attn = 
MultiheadAttention( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate, + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, drop_path_rate=drop_path_rate) + self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + norm_layer=nn.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate, + ) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class CrossAttentionDecoder(nn.Module): + + def __init__( + self, + *, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + downsample_ratio: int = 1, + enable_ln_post: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + ): + super().__init__() + + self.enable_ln_post = enable_ln_post + self.fourier_embedder = fourier_embedder + self.downsample_ratio = downsample_ratio + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width) + if self.downsample_ratio != 1: + self.latents_proj = nn.Linear(width * downsample_ratio, width) + if self.enable_ln_post == False: + qk_norm = False + self.cross_attn_decoder = ResidualCrossAttentionBlock( + width=width, + mlp_expand_ratio=mlp_expand_ratio, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + ) + + if self.enable_ln_post: + self.ln_post = nn.LayerNorm(width) + self.output_proj = nn.Linear(width, out_channels) + self.label_type = label_type + self.count = 0 + + def set_cross_attention_processor(self, processor): + self.cross_attn_decoder.attn.attention.attn_processor = processor + + # def set_default_cross_attention_processor(self): + # self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor + + def forward(self, queries=None, query_embeddings=None, latents=None): + if query_embeddings is None: + query_embeddings = self.query_proj( + self.fourier_embedder(queries).to(latents.dtype) + ) + self.count += query_embeddings.shape[1] + if self.downsample_ratio != 1: + latents = self.latents_proj(latents) + x = self.cross_attn_decoder(query_embeddings, latents) + if self.enable_ln_post: + x = self.ln_post(x) + occ = self.output_proj(x) + return occ + + +def fps( + src: torch.Tensor, + batch: Optional[Tensor] = None, + ratio: Optional[Union[Tensor, float]] = None, + random_start: bool = True, + batch_size: Optional[int] = None, + ptr: Optional[Union[Tensor, List[int]]] = None, +): + src = src.float() + from torch_cluster import fps as fps_fn + + output = fps_fn(src, batch, ratio, random_start, batch_size, ptr) + return output + + +class PointCrossAttentionEncoder(nn.Module): + + def __init__( + self, + *, + num_latents: int, + downsample_ratio: float, + pc_size: int, + pc_sharpedge_size: int, + fourier_embedder: FourierEmbedder, + point_feats: int, + width: int, + heads: int, + layers: int, + normal_pe: bool = False, + qkv_bias: bool = True, + use_ln_post: bool = 
False, + use_checkpoint: bool = False, + qk_norm: bool = False, + ): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + self.downsample_ratio = downsample_ratio + self.point_feats = point_feats + self.normal_pe = normal_pe + + if pc_sharpedge_size == 0: + print( + f"PointCrossAttentionEncoder INFO: pc_sharpedge_size is not given," + f" using pc_size as pc_sharpedge_size" + ) + else: + print( + "PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using" + f" pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}" + ) + + self.pc_size = pc_size + self.pc_sharpedge_size = pc_sharpedge_size + + self.fourier_embedder = fourier_embedder + + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width) + self.cross_attn = ResidualCrossAttentionBlock( + width=width, heads=heads, qkv_bias=qkv_bias, qk_norm=qk_norm + ) + + self.self_attn = None + if layers > 0: + self.self_attn = Transformer( + width=width, + layers=layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width) + else: + self.ln_post = None + + def sample_points_and_latents( + self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None + ): + B, N, D = pc.shape + num_pts = self.num_latents * self.downsample_ratio + + # Compute number of latents + num_latents = int(num_pts / self.downsample_ratio) + + # Compute the number of random and sharpedge latents + num_random_query = ( + self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents + ) + num_sharpedge_query = num_latents - num_random_query + + # Split random and sharpedge surface points + random_pc, sharpedge_pc = torch.split( + pc, [self.pc_size, self.pc_sharpedge_size], dim=1 + ) + assert ( + random_pc.shape[1] <= self.pc_size + ), "Random surface points size must be less than or equal to pc_size" + assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, ( + "Sharpedge surface points size must be less than or equal to" + " pc_sharpedge_size" + ) + + # Randomly select random surface points and random query points + input_random_pc_size = int(num_random_query * self.downsample_ratio) + random_query_ratio = num_random_query / input_random_pc_size + idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[ + :input_random_pc_size + ] + input_random_pc = random_pc[:, idx_random_pc, :] + flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D) + N_down = int(flatten_input_random_pc.shape[0] / B) + batch_down = torch.arange(B).to(pc.device) + batch_down = torch.repeat_interleave(batch_down, N_down) + idx_query_random = fps( + flatten_input_random_pc, batch_down, ratio=random_query_ratio + ) + query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D) + + # Randomly select sharpedge surface points and sharpedge query points + input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio) + if input_sharpedge_pc_size == 0: + input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to( + pc.device + ) + query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to( + pc.device + ) + else: + sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size + idx_sharpedge_pc = torch.randperm( + sharpedge_pc.shape[1], device=sharpedge_pc.device + )[:input_sharpedge_pc_size] + input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :] + flatten_input_sharpedge_surface_points = input_sharpedge_pc.view( + B * input_sharpedge_pc_size, D + ) + 
N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B) + batch_down = torch.arange(B).to(pc.device) + batch_down = torch.repeat_interleave(batch_down, N_down) + idx_query_sharpedge = fps( + flatten_input_sharpedge_surface_points, + batch_down, + ratio=sharpedge_query_ratio, + ) + query_sharpedge_pc = flatten_input_sharpedge_surface_points[ + idx_query_sharpedge + ].view(B, -1, D) + + # Concatenate random and sharpedge surface points and query points + query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1) + input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1) + + # PE + query = self.fourier_embedder(query_pc) + data = self.fourier_embedder(input_pc) + + # Concat normal if given + if self.point_feats != 0: + + random_surface_feats, sharpedge_surface_feats = torch.split( + feats, [self.pc_size, self.pc_sharpedge_size], dim=1 + ) + input_random_surface_feats = random_surface_feats[:, idx_random_pc, :] + flatten_input_random_surface_feats = input_random_surface_feats.view( + B * input_random_pc_size, -1 + ) + query_random_feats = flatten_input_random_surface_feats[ + idx_query_random + ].view(B, -1, flatten_input_random_surface_feats.shape[-1]) + + if input_sharpedge_pc_size == 0: + input_sharpedge_surface_feats = torch.zeros( + B, 0, self.point_feats, dtype=input_random_surface_feats.dtype + ).to(pc.device) + query_sharpedge_feats = torch.zeros( + B, 0, self.point_feats, dtype=query_random_feats.dtype + ).to(pc.device) + else: + input_sharpedge_surface_feats = sharpedge_surface_feats[ + :, idx_sharpedge_pc, : + ] + flatten_input_sharpedge_surface_feats = ( + input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size, -1) + ) + query_sharpedge_feats = flatten_input_sharpedge_surface_feats[ + idx_query_sharpedge + ].view(B, -1, flatten_input_sharpedge_surface_feats.shape[-1]) + + query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1) + input_feats = torch.cat( + [input_random_surface_feats, input_sharpedge_surface_feats], dim=1 + ) + + if self.normal_pe: + query_normal_pe = self.fourier_embedder(query_feats[..., :3]) + input_normal_pe = self.fourier_embedder(input_feats[..., :3]) + query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1) + input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1) + + query = torch.cat([query, query_feats], dim=-1) + data = torch.cat([data, input_feats], dim=-1) + + if input_sharpedge_pc_size == 0: + query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device) + input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device) + + # print(f'query_pc: {query_pc.shape}') + # print(f'input_pc: {input_pc.shape}') + # print(f'query_random_pc: {query_random_pc.shape}') + # print(f'input_random_pc: {input_random_pc.shape}') + # print(f'query_sharpedge_pc: {query_sharpedge_pc.shape}') + # print(f'input_sharpedge_pc: {input_sharpedge_pc.shape}') + + return ( + query.view(B, -1, query.shape[-1]), + data.view(B, -1, data.shape[-1]), + [ + query_pc, + input_pc, + query_random_pc, + input_random_pc, + query_sharpedge_pc, + input_sharpedge_pc, + ], + ) + + def forward(self, pc, feats): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + + """ + + query, data, pc_infos = self.sample_points_and_latents(pc, feats) + + query = self.input_proj(query) + query = query + data = self.input_proj(data) + data = data + + latents = self.cross_attn(query, data) + if self.self_attn is not None: + latents = self.self_attn(latents) + + if self.ln_post is 
not None: + latents = self.ln_post(latents) + + return latents, pc_infos diff --git a/XPart/partgen/models/autoencoders/attention_processors.py b/XPart/partgen/models/autoencoders/attention_processors.py new file mode 100755 index 0000000000000000000000000000000000000000..2f24054e84137ab40f584a5e0111cb50fc93fd99 --- /dev/null +++ b/XPart/partgen/models/autoencoders/attention_processors.py @@ -0,0 +1,32 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +import os + +import torch +import torch.nn.functional as F + +scaled_dot_product_attention = F.scaled_dot_product_attention +if os.environ.get('CA_USE_SAGEATTN', '0') == '1': + try: + from sageattention import sageattn + except ImportError: + raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.') + scaled_dot_product_attention = sageattn + + +class CrossAttentionProcessor: + def __call__(self, attn, q, k, v): + out = scaled_dot_product_attention(q, k, v) + return out diff --git a/XPart/partgen/models/autoencoders/model.py b/XPart/partgen/models/autoencoders/model.py new file mode 100755 index 0000000000000000000000000000000000000000..ae47c3aa731fbf88b49ee2b0018e54185253835a --- /dev/null +++ b/XPart/partgen/models/autoencoders/model.py @@ -0,0 +1,452 @@ +# Open Source Model Licensed under the Apache License Version 2.0 +# and Other Licenses of the Third-Party Components therein: +# The below Model in this distribution may have been modified by THL A29 Limited +# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited. + +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# The below software and/or models in this distribution may have been +# modified by THL A29 Limited ("Tencent Modifications"). +# All Tencent Modifications are Copyright (C) THL A29 Limited. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. 
+ +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + + +import os +from typing import Tuple, List, Union + +from functools import partial + +import copy +import numpy as np +import torch +import torch.nn as nn +import yaml + +from .attention_blocks import ( + FourierEmbedder, + Transformer, + CrossAttentionDecoder, + PointCrossAttentionEncoder, +) +from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors, Latent2MeshOutput +from .volume_decoders import ( + VanillaVolumeDecoder, +) +from ...utils.misc import logger, synchronize_timer, smart_load_model +from ...utils.mesh_utils import extract_geometry_fast + + +class DiagonalGaussianDistribution(object): + def __init__( + self, + parameters: Union[torch.Tensor, List[torch.Tensor]], + deterministic=False, + feat_dim=1, + ): + """ + Initialize a diagonal Gaussian distribution with mean and log-variance parameters. + + Args: + parameters (Union[torch.Tensor, List[torch.Tensor]]): + Either a single tensor containing concatenated mean and log-variance along `feat_dim`, + or a list of two tensors [mean, logvar]. + deterministic (bool, optional): If True, the distribution is deterministic (zero variance). + Default is False. feat_dim (int, optional): Dimension along which mean and logvar are + concatenated if parameters is a single tensor. Default is 1. + """ + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + """ + Sample from the diagonal Gaussian distribution. + + Returns: + torch.Tensor: A sample tensor with the same shape as the mean. + """ + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2, 3)): + """ + Compute the Kullback-Leibler (KL) divergence between this distribution and another. + + If `other` is None, compute KL divergence to a standard normal distribution N(0, I). + + Args: + other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution. + dims (tuple, optional): Dimensions along which to compute the mean KL divergence. + Default is (1, 2, 3). + + Returns: + torch.Tensor: The mean KL divergence value. 
+ """ + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=dims, + ) + + def nll(self, sample, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +class VectsetVAE(nn.Module): + + @classmethod + @synchronize_timer("VectsetVAE Model Loading") + def from_single_file( + cls, + ckpt_path, + config_path, + device="cuda", + dtype=torch.float16, + use_safetensors=None, + **kwargs, + ): + # load config + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + # load ckpt + if use_safetensors: + ckpt_path = ckpt_path.replace(".ckpt", ".safetensors") + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Model file {ckpt_path} not found") + + logger.info(f"Loading model from {ckpt_path}") + if use_safetensors: + import safetensors.torch + + ckpt = safetensors.torch.load_file(ckpt_path, device="cpu") + else: + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) + + model_kwargs = config["params"] + model_kwargs.update(kwargs) + + model = cls(**model_kwargs) + model.load_state_dict(ckpt) + model.to(device=device, dtype=dtype) + return model + + @classmethod + def from_pretrained( + cls, + model_path, + device="cuda", + dtype=torch.float16, + use_safetensors=False, + variant="fp16", + subfolder="hunyuan3d-vae-v2-1", + **kwargs, + ): + config_path, ckpt_path = smart_load_model( + model_path, + subfolder=subfolder, + use_safetensors=use_safetensors, + variant=variant, + ) + + return cls.from_single_file( + ckpt_path, + config_path, + device=device, + dtype=dtype, + use_safetensors=use_safetensors, + **kwargs, + ) + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu") + state_dict = state_dict.get("state_dict", state_dict) + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and" + f" {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def __init__(self, volume_decoder=None, surface_extractor=None): + super().__init__() + if volume_decoder is None: + volume_decoder = VanillaVolumeDecoder() + if surface_extractor is None: + surface_extractor = MCSurfaceExtractor() + self.volume_decoder = volume_decoder + self.surface_extractor = surface_extractor + + def latents2mesh(self, latents: torch.FloatTensor, **kwargs): + with synchronize_timer("Volume decoding"): + grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs) + with synchronize_timer("Surface extraction"): + outputs = self.surface_extractor(grid_logits, **kwargs) + return outputs + + +class VolumeDecoderShapeVAE(VectsetVAE): + def __init__( + self, + *, + num_latents: int, + embed_dim: int, + width: int, + heads: int, + num_decoder_layers: int, + num_encoder_layers: int = 8, + pc_size: int = 5120, + pc_sharpedge_size: int = 5120, + 
point_feats: int = 3, + downsample_ratio: int = 20, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + include_pi: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + drop_path_rate: float = 0.0, + scale_factor: float = 1.0, + use_ln_post: bool = True, + ckpt_path=None, + volume_decoder=None, + surface_extractor=None, + ): + super().__init__(volume_decoder, surface_extractor) + self.geo_decoder_ln_post = geo_decoder_ln_post + self.downsample_ratio = downsample_ratio + + self.fourier_embedder = FourierEmbedder( + num_freqs=num_freqs, include_pi=include_pi + ) + + self.encoder = PointCrossAttentionEncoder( + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + downsample_ratio=self.downsample_ratio, + pc_size=pc_size, + pc_sharpedge_size=pc_sharpedge_size, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + qkv_bias=qkv_bias, + use_ln_post=use_ln_post, + qk_norm=qk_norm, + ) + + self.pre_kl = nn.Linear(width, embed_dim * 2) + self.post_kl = nn.Linear(embed_dim, width) + + self.transformer = Transformer( + width=width, + layers=num_decoder_layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate, + ) + + self.geo_decoder = CrossAttentionDecoder( + fourier_embedder=self.fourier_embedder, + out_channels=1, + mlp_expand_ratio=geo_decoder_mlp_expand_ratio, + downsample_ratio=geo_decoder_downsample_ratio, + enable_ln_post=self.geo_decoder_ln_post, + width=width // geo_decoder_downsample_ratio, + heads=heads // geo_decoder_downsample_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + label_type=label_type, + ) + + self.scale_factor = scale_factor + self.latent_shape = (num_latents, embed_dim) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def forward(self, latents): + latents = self.post_kl(latents) + latents = self.transformer(latents) + return latents + + def encode(self, surface, sample_posterior=True, return_pc_info=False): + pc, feats = surface[:, :, :3], surface[:, :, 3:] + latents, pc_infos = self.encoder(pc, feats) + # print(latents.shape, self.pre_kl.weight.shape) + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + if sample_posterior: + latents = posterior.sample() + else: + latents = posterior.mode() + if return_pc_info: + return latents, pc_infos + else: + return latents + + def encode_shape(self, surface, return_pc_info=False): + pc, feats = surface[:, :, :3], surface[:, :, 3:] + latents, pc_infos = self.encoder(pc, feats) + if return_pc_info: + return latents, pc_infos + else: + return latents + + def decode(self, latents): + latents = self.post_kl(latents) + latents = self.transformer(latents) + return latents + + def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + logits = self.geo_decoder(queries=queries, latents=latents).squeeze(-1) + return logits + + def latents2mesh(self, latents: torch.FloatTensor, **kwargs): + coarse_kwargs = copy.deepcopy(kwargs) + coarse_kwargs["octree_resolution"] = 256 + + with synchronize_timer("Coarse Volume decoding"): + coarse_grid_logits = self.volume_decoder( + latents, self.geo_decoder, **coarse_kwargs + ) + with synchronize_timer("Coarse Surface extraction"): + coarse_mesh = self.surface_extractor(coarse_grid_logits, **coarse_kwargs) + + assert len(coarse_mesh) == 1 + bbox_gen_by_coarse_matching_cube_mesh = np.stack( + 
[coarse_mesh[0].mesh_v.max(0), coarse_mesh[0].mesh_v.min(0)] + ) + bbox_gen_by_coarse_matching_cube_mesh_range = ( + bbox_gen_by_coarse_matching_cube_mesh[0] + - bbox_gen_by_coarse_matching_cube_mesh[1] + ) + + # extend by 10% + bbox_gen_by_coarse_matching_cube_mesh[0] += ( + bbox_gen_by_coarse_matching_cube_mesh_range * 0.1 + ) + bbox_gen_by_coarse_matching_cube_mesh[1] -= ( + bbox_gen_by_coarse_matching_cube_mesh_range * 0.1 + ) + with synchronize_timer("Fine-grained Volume decoding"): + grid_logits = self.volume_decoder( + latents, + self.geo_decoder, + bbox_corner=bbox_gen_by_coarse_matching_cube_mesh[None], + **kwargs, + ) + with synchronize_timer("Fine-grained Surface extraction"): + outputs = self.surface_extractor( + grid_logits, + bbox_corner=bbox_gen_by_coarse_matching_cube_mesh[None], + **kwargs, + ) + + return outputs + + def latent2mesh_2( + self, + latents: torch.FloatTensor, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000, + mc_level: float = -1 / 512, + octree_resolution: int = None, + mc_mode: str = "mc", + ) -> List[Latent2MeshOutput]: + """ + Args: + latents: [bs, num_latents, dim] + bounds: + octree_depth: + num_chunks: + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + """ + outputs = [] + geometric_func = partial(self.query_geometry, latents=latents) + # 2. decode geometry + device = latents.device + if mc_mode == "dmc" and not hasattr(self, "diffdmc"): + from diso import DiffDMC + + self.diffdmc = DiffDMC(dtype=torch.float32).to(device) + mesh_v_f, has_surface = extract_geometry_fast( + geometric_func=geometric_func, + device=device, + batch_size=len(latents), + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=False, + mc_level=mc_level, + octree_resolution=octree_resolution, + diffdmc=self.diffdmc if mc_mode == "dmc" else None, + mc_mode=mc_mode, + ) + # 3. decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + out = Latent2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + outputs.append(out) + return outputs diff --git a/XPart/partgen/models/autoencoders/surface_extractors.py b/XPart/partgen/models/autoencoders/surface_extractors.py new file mode 100755 index 0000000000000000000000000000000000000000..8282aa2dc4e7b2b18b5605edf5dea7f32880e696 --- /dev/null +++ b/XPart/partgen/models/autoencoders/surface_extractors.py @@ -0,0 +1,164 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
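Before the imports and extractor classes below, a minimal sketch of how this module is typically driven: an extractor is looked up in the SurfaceExtractors registry defined at the bottom of the file and called on a batch of grid logits, yielding one Latent2MeshOutput (or None on failure) per batch entry. The sphere SDF grid and resolution here are stand-ins, not values used anywhere in this diff, and the "dmc" entry additionally requires the diso package.

import torch

# Stand-in scalar field: positive inside a sphere of radius 0.5, sampled on a
# (res + 1)^3 lattice, which matches the grid shape _compute_box_stat derives
# for octree_resolution == res.
res = 64
xs = torch.linspace(-1.0, 1.0, res + 1)
grid = torch.stack(torch.meshgrid(xs, xs, xs, indexing="ij"), dim=-1)
grid_logits = (0.5 - grid.norm(dim=-1)).unsqueeze(0)  # [batch, res+1, res+1, res+1]

extractor = SurfaceExtractors["mc"]()  # MCSurfaceExtractor; "dmc" needs `pip install diso`
outputs = extractor(grid_logits, mc_level=0.0, bounds=1.0, octree_resolution=res)
mesh_out = outputs[0]  # Latent2MeshOutput with .mesh_v / .mesh_f, or None if extraction failed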
+ +from typing import Union, Tuple, List + +import numpy as np +import torch +from skimage import measure + + +class Latent2MeshOutput: + def __init__(self, mesh_v=None, mesh_f=None): + self.mesh_v = mesh_v + self.mesh_f = mesh_f + + +def center_vertices(vertices): + """Translate the vertices so that bounding box is centered at zero.""" + vert_min = vertices.min(dim=0)[0] + vert_max = vertices.max(dim=0)[0] + vert_center = 0.5 * (vert_min + vert_max) + return vertices - vert_center + + +class SurfaceExtractor: + def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int): + """ + Compute grid size, bounding box minimum coordinates, and bounding box size based on input + bounds and resolution. + + Args: + bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single + float representing half side length. + If float, bounds are assumed symmetric around zero in all axes. + Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax]. + octree_resolution (int): Resolution of the octree grid. + + Returns: + grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1. + bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin). + bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.). + """ + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1] + return grid_size, bbox_min, bbox_size + + def run(self, *args, **kwargs): + """ + Abstract method to extract surface mesh from grid logits. + + This method should be implemented by subclasses. + + Raises: + NotImplementedError: Always, since this is an abstract method. + """ + return NotImplementedError + + def __call__(self, grid_logits, **kwargs): + """ + Process a batch of grid logits to extract surface meshes. + + Args: + grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...). + **kwargs: Additional keyword arguments passed to the `run` method. + + Returns: + List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch. + If extraction fails for a grid, None is appended at that position. + """ + outputs = [] + for i in range(grid_logits.shape[0]): + try: + vertices, faces = self.run(grid_logits[i], **kwargs) + vertices = vertices.astype(np.float32) + faces = np.ascontiguousarray(faces) + outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces)) + + except Exception: + import traceback + traceback.print_exc() + outputs.append(None) + + return outputs + + +class MCSurfaceExtractor(SurfaceExtractor): + def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs): + """ + Extract surface mesh using the Marching Cubes algorithm. + + Args: + grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field. + mc_level (float): The level (iso-value) at which to extract the surface. + bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length. + octree_resolution (int): Resolution of the octree grid. + **kwargs: Additional keyword arguments (ignored). + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple containing: + - vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding + box coordinates. 
+ - faces (np.ndarray): Extracted mesh faces (triangles). + """ + vertices, faces, normals, _ = measure.marching_cubes(grid_logit.cpu().numpy(), + mc_level, + method="lewiner") + grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution) + vertices = vertices / grid_size * bbox_size + bbox_min + return vertices, faces + + +class DMCSurfaceExtractor(SurfaceExtractor): + def run(self, grid_logit, *, octree_resolution, **kwargs): + """ + Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm. + + Args: + grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field. + octree_resolution (int): Resolution of the octree grid. + **kwargs: Additional keyword arguments (ignored). + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple containing: + - vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy. + - faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order. + + Raises: + ImportError: If the 'diso' package is not installed. + """ + device = grid_logit.device + if not hasattr(self, 'dmc'): + try: + from diso import DiffDMC + self.dmc = DiffDMC(dtype=torch.float32).to(device) + except: + raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'") + sdf = -grid_logit / octree_resolution + sdf = sdf.to(torch.float32).contiguous() + verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True) + verts = center_vertices(verts) + vertices = verts.detach().cpu().numpy() + faces = faces.detach().cpu().numpy()[:, ::-1] + return vertices, faces + + +SurfaceExtractors = { + 'mc': MCSurfaceExtractor, + 'dmc': DMCSurfaceExtractor, +} diff --git a/XPart/partgen/models/autoencoders/volume_decoders.py b/XPart/partgen/models/autoencoders/volume_decoders.py new file mode 100755 index 0000000000000000000000000000000000000000..8c9dc3a5d7ecd3acdec5d418b6be5a92a2dba731 --- /dev/null +++ b/XPart/partgen/models/autoencoders/volume_decoders.py @@ -0,0 +1,107 @@ +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
+ +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. + + +from typing import Union, Tuple, List, Callable + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +from tqdm import tqdm + +from .attention_blocks import CrossAttentionDecoder +from ...utils.misc import logger +from ...utils.mesh_utils import ( + extract_near_surface_volume_fn, + generate_dense_grid_points, +) + + +class VanillaVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + octree_resolution: int = None, + enable_pbar: bool = True, + **kwargs, + ): + + """ + Perform volume decoding with a vanilla decoder + Args: + latents (torch.FloatTensor): Latent vectors to decode. + geo_decoder (Callable): The geometry decoder function. + bounds (Union[Tuple[float], List[float], float]): Bounding box for the volume. + num_chunks (int): Number of chunks to process at a time. + octree_resolution (int): Resolution of the octree for sampling points. + enable_pbar (bool): Whether to enable progress bar. + Returns: + grid_logits (torch.FloatTensor): Decoded 3D volume logits. + """ + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=octree_resolution, + indexing="ij", + ) + xyz_samples = ( + torch.from_numpy(xyz_samples) + .to(device, dtype=dtype) + .contiguous() + .reshape(-1, 3) + ) + + # 2. latents to 3d volume + batch_logits = [] + for start in tqdm( + range(0, xyz_samples.shape[0], num_chunks), + desc=f"Volume Decoding", + disable=not enable_pbar, + ): + chunk_queries = xyz_samples[start : start + num_chunks, :] + chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=chunk_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1) + grid_logits = grid_logits.view((batch_size, *grid_size)).float() + + return grid_logits diff --git a/XPart/partgen/models/conditioner/condioner_release.py b/XPart/partgen/models/conditioner/condioner_release.py new file mode 100755 index 0000000000000000000000000000000000000000..edb569cc2ae75cdeb67c8758a79e15f2f9368587 --- /dev/null +++ b/XPart/partgen/models/conditioner/condioner_release.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from .part_encoders import PartEncoder +from ..autoencoders import VolumeDecoderShapeVAE +from ...utils.misc import ( + instantiate_from_config, + instantiate_non_trainable_model, +) +from .sonata_extractor import SonataFeatureExtractor +from .part_encoders import PartEncoder + + +def debug_sonata_feat(points, feats): + from sklearn.decomposition import PCA + import numpy as np + import trimesh + import os + + point_num = points.shape[0] + feat_save = feats.float().detach().cpu().numpy() + data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True) + pca = PCA(n_components=3) + data_reduced = pca.fit_transform(data_scaled) + data_reduced = (data_reduced - data_reduced.min()) / ( + data_reduced.max() - data_reduced.min() + ) + colors_255 = (data_reduced * 255).astype(np.uint8) + colors_255 = np.concatenate( + [colors_255, np.ones((point_num, 1), dtype=np.uint8) * 255], axis=-1 + ) + pc_save = trimesh.points.PointCloud(points, colors=colors_255) + return pc_save + # pc_save.export(os.path.join("debug", "point_pca.glb")) + + +class Conditioner(torch.nn.Module): + + def __init__( + self, + use_image=False, + use_geo=True, + use_obj=True, + use_seg_feat=False, + geo_cfg=None, + obj_encoder_cfg=None, + seg_feat_cfg=None, + **kwargs + ): + super().__init__() + self.use_image = use_image + self.use_obj = use_obj + self.use_geo = use_geo + self.use_seg_feat = use_seg_feat + self.geo_cfg = geo_cfg + self.obj_encoder_cfg = obj_encoder_cfg + self.seg_feat_cfg = seg_feat_cfg + if use_geo and geo_cfg is not None: + self.geo_encoder: PartEncoder = instantiate_from_config(geo_cfg) + if hasattr(geo_cfg, "output_dim"): + self.geo_out_proj = torch.nn.Linear(1024 + 512, geo_cfg.output_dim) + + if use_obj and obj_encoder_cfg is not None: + self.obj_encoder: VolumeDecoderShapeVAE = instantiate_non_trainable_model( + obj_encoder_cfg + ) + if hasattr(obj_encoder_cfg, "output_dim"): + self.obj_out_proj = torch.nn.Linear( + 1024 + 512, obj_encoder_cfg.output_dim + ) + if use_seg_feat and seg_feat_cfg is not None: + self.seg_feat_encoder: SonataFeatureExtractor = ( + instantiate_non_trainable_model(seg_feat_cfg) + ) + if hasattr(seg_feat_cfg, "output_dim"): + self.seg_feat_outproj = torch.nn.Linear(512, seg_feat_cfg.output_dim) + + def forward(self, part_surface_inbbox, object_surface): + bz = part_surface_inbbox.shape[0] + context = {} + # geo_cond + if self.use_geo: + context["geo_cond"], local_pc_infos = self.geo_encoder( + part_surface_inbbox, + object_surface, + 
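+                # the returned local point coordinates (local_pc_infos) are reused below to
+                # gather per-point Sonata features for the geo condition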
return_local_pc_info=True, + ) + # obj cond + if self.use_obj: + with torch.no_grad(): + context["obj_cond"], global_pc_infos = self.obj_encoder.encode_shape( + object_surface, return_pc_info=True + ) + + # seg feat cond + if self.use_seg_feat: + # TODO: batchsize must be One + num_parts = part_surface_inbbox.shape[0] + with torch.autocast(device_type="cuda", dtype=torch.float32): + # encode sonata feature + # with torch.cuda.amp.autocast(enabled=False): + with torch.no_grad(): + point, normal = ( + object_surface[:1, ..., :3].float(), + object_surface[:1, ..., 3:6].float(), + ) + point_feat = self.seg_feat_encoder(point, normal) + # local feat + if self.use_obj: + nearest_global_matches = torch.argmin( + torch.cdist(global_pc_infos[0], object_surface[..., :3]), dim=-1 + ) + # global feat + global_point_feats = point_feat.expand(num_parts, -1, -1).gather( + 1, + nearest_global_matches.unsqueeze(-1).expand( + -1, -1, point_feat.size(-1) + ), + ) + context["obj_cond"] = torch.concat( + [context["obj_cond"], global_point_feats], dim=-1 + ).to(dtype=self.obj_out_proj.weight.dtype) + if hasattr(self, "obj_out_proj"): + context["obj_cond"] = self.obj_out_proj( + context["obj_cond"] + ) # .float() + if self.use_geo: + nearest_local_matches = torch.argmin( + torch.cdist(local_pc_infos[0], object_surface[..., :3]), dim=-1 + ) + local_point_feats = point_feat.expand(num_parts, -1, -1).gather( + 1, + nearest_local_matches.unsqueeze(-1).expand( + -1, -1, point_feat.size(-1) + ), + ) + context["geo_cond"] = torch.concat( + [context["geo_cond"], local_point_feats], + dim=-1, + ).to(dtype=self.geo_out_proj.weight.dtype) + if hasattr(self, "geo_out_proj"): + context["geo_cond"] = self.geo_out_proj( + context["geo_cond"] + ) # .float() + return context diff --git a/XPart/partgen/models/conditioner/part_encoders.py b/XPart/partgen/models/conditioner/part_encoders.py new file mode 100755 index 0000000000000000000000000000000000000000..cd41761059141407801fd144c653ac15596f409f --- /dev/null +++ b/XPart/partgen/models/conditioner/part_encoders.py @@ -0,0 +1,89 @@ +import torch.nn as nn +from ...utils.misc import ( + instantiate_from_config, + instantiate_non_trainable_model, +) +from ..autoencoders.model import ( + VolumeDecoderShapeVAE, +) + + +class PartEncoder(nn.Module): + def __init__( + self, + use_local=True, + local_global_feat_dim=None, + local_geo_cfg=None, + local_feat_type="latents", + num_tokens_cond=2048, + ): + super().__init__() + self.local_global_feat_dim = local_global_feat_dim + self.local_feat_type = local_feat_type + self.num_tokens_cond = num_tokens_cond + # local + self.use_local = use_local + if use_local: + if local_geo_cfg is None: + raise ValueError( + "local_geo_cfg must be provided when use_local is True" + ) + assert ( + "ShapeVAE" in local_geo_cfg.get("target").split(".")[-1] + ), "local_geo_cfg must be a ShapeVAE config" + self.local_encoder: VolumeDecoderShapeVAE = instantiate_from_config( + local_geo_cfg + ) + if self.local_global_feat_dim is not None: + self.local_out_layer = nn.Linear( + ( + local_geo_cfg.params.embed_dim + if self.local_feat_type == "latents" + else local_geo_cfg.params.width + ), + self.local_global_feat_dim, + bias=True, + ) + + def forward(self, part_surface_inbbox, object_surface, return_local_pc_info=False): + """ + Args: + aabb: (B, 2, 3) tensor representing the axis-aligned bounding box + object_surface: (B, N, 3) tensor representing the surface points of the object + Returns: + local_features: (B, num_tokens_cond, C) tensor of local features + 
global_features: (B,num_tokens_cond, C) tensor of global features + """ + # random selection if more than num_tokens_cond points + if self.use_local: + # with torch.autocast( + # device_type=part_surface_inbbox.device.type, + # dtype=torch.float16, + # ): + # with torch.no_grad(): + if self.local_feat_type == "latents": + local_features, local_pc_infos = self.local_encoder.encode( + part_surface_inbbox, sample_posterior=True, return_pc_info=True + ) # (B, num_tokens_cond, C) + elif self.local_feat_type == "latents_shape": + local_features, local_pc_infos = self.local_encoder.encode_shape( + part_surface_inbbox, return_pc_info=True + ) # (B, num_tokens_cond, C) + elif self.local_feat_type == "miche-point-query-structural-vae": + local_features, local_pc_infos = self.local_encoder.encode( + part_surface_inbbox, sample_posterior=True, return_pc_info=True + ) + local_features = self.local_encoder(local_features) + else: + raise ValueError( + f"local_feat_type {self.local_feat_type} not supported" + ) + # ouput layer + geo_features = ( + self.local_out_layer(local_features) + if hasattr(self, "local_out_layer") + else local_features + ) + if return_local_pc_info: + return geo_features, local_pc_infos + return geo_features diff --git a/XPart/partgen/models/conditioner/sonata_extractor.py b/XPart/partgen/models/conditioner/sonata_extractor.py new file mode 100755 index 0000000000000000000000000000000000000000..e09cc00a2a8f7928e5bb75caa16b35d361b75bf8 --- /dev/null +++ b/XPart/partgen/models/conditioner/sonata_extractor.py @@ -0,0 +1,315 @@ +import torch +import torch.nn as nn +from .. import sonata + +from typing import Dict, Union, Optional +from pathlib import Path + + +class SonataFeatureExtractor(nn.Module): + """ + Feature extractor using Sonata backbone with MLP projection. + Supports batch processing and gradient computation. + """ + + def __init__( + self, + ckpt_path: Optional[str] = "", + ): + super().__init__() + + # Load Sonata model + self.sonata = sonata.load_by_config( + str(Path(__file__).parent.parent.parent / "config" / "sonata.json") + ) + + # Store original dtype for later reference + # self._original_dtype = next(self.parameters()).dtype + + # Define MLP projection head (same as in train-sonata.py) + self.mlp = nn.Sequential( + nn.Linear(1232, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.GELU(), + nn.Linear(512, 512), + ) + + # Define transform + self.transform = sonata.transform.default() + + # Load checkpoint if provided + if ckpt_path: + self.load_checkpoint(ckpt_path) + + def load_checkpoint(self, checkpoint_path: str): + """Load model weights from checkpoint.""" + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + # Extract state dict from Lightning checkpoint + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + # Remove 'model.' prefix if present from Lightning + state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} + else: + state_dict = checkpoint + + # Debug: Show all keys in checkpoint + print("\n=== Checkpoint Keys ===") + print(f"Total keys in checkpoint: {len(state_dict)}") + print("\nSample keys:") + for i, key in enumerate(list(state_dict.keys())[:10]): + print(f" {key}") + if len(state_dict) > 10: + print(f" ... 
and {len(state_dict) - 10} more keys") + + # Load only the relevant weights + sonata_dict = { + k.replace("sonata.", ""): v + for k, v in state_dict.items() + if k.startswith("sonata.") + } + mlp_dict = { + k.replace("mlp.", ""): v + for k, v in state_dict.items() + if k.startswith("mlp.") + } + + print(f"\nFound {len(sonata_dict)} Sonata keys") + print(f"Found {len(mlp_dict)} MLP keys") + + # Load Sonata weights and show missing/unexpected keys + if sonata_dict: + print("\n=== Loading Sonata Weights ===") + result = self.sonata.load_state_dict(sonata_dict, strict=False) + if result.missing_keys: + print(f"\nMissing keys ({len(result.missing_keys)}):") + for key in result.missing_keys[:20]: # Show first 20 + print(f" - {key}") + if len(result.missing_keys) > 20: + print(f" ... and {len(result.missing_keys) - 20} more") + else: + print("No missing keys!") + + if result.unexpected_keys: + print(f"\nUnexpected keys ({len(result.unexpected_keys)}):") + for key in result.unexpected_keys[:20]: # Show first 20 + print(f" - {key}") + if len(result.unexpected_keys) > 20: + print(f" ... and {len(result.unexpected_keys) - 20} more") + else: + print("No unexpected keys!") + + # Load MLP weights + if mlp_dict: + print("\n=== Loading MLP Weights ===") + result = self.mlp.load_state_dict(mlp_dict, strict=False) + if result.missing_keys: + print(f"\nMissing keys: {result.missing_keys}") + if result.unexpected_keys: + print(f"Unexpected keys: {result.unexpected_keys}") + print("MLP weights loaded successfully!") + + print(f"\n✓ Loaded checkpoint from {checkpoint_path}") + + def prepare_batch_data( + self, points: torch.Tensor, normals: Optional[torch.Tensor] = None + ) -> Dict: + """ + Prepare batch data for Sonata model. + + Args: + points: [B, N, 3] or [N, 3] tensor of point coordinates + normals: [B, N, 3] or [N, 3] tensor of normals (optional) + + Returns: + Dictionary formatted for Sonata input + """ + # Handle single batch case + if points.dim() == 2: + points = points.unsqueeze(0) + if normals is not None: + normals = normals.unsqueeze(0) + # print('Sonata points shape: ', points.shape) + B, N, _ = points.shape + + # Prepare batch indices + batch_idx = torch.arange(B).view(-1, 1).repeat(1, N).reshape(-1) + + # Flatten points for Sonata format + coord = points.reshape(B * N, 3) + + if normals is not None: + normal = normals.reshape(B * N, 3) + else: + # Generate dummy normals if not provided + normal = torch.ones_like(coord) + + # Generate dummy colors + color = torch.ones_like(coord) + + # Function to convert tensor to numpy array, handling BFloat16 + def to_numpy(tensor): + # First convert to CPU if needed + if tensor.is_cuda: + tensor = tensor.cpu() + # Convert BFloat16 or other unsupported dtypes to float32 + if tensor.dtype not in [ + torch.float32, + torch.float64, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + torch.int16, + ]: + tensor = tensor.to(torch.float32) + # Then convert to numpy + return tensor.numpy() + + # Create data dict + data_dict = { + "coord": to_numpy(coord), + "normal": to_numpy(normal), + "color": to_numpy(color), + "batch": to_numpy(batch_idx), + } + + # Apply transform + data_dict = self.transform(data_dict) + + return data_dict, B, N + + def forward( + self, points: torch.Tensor, normals: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Extract features from point clouds. 
+ + Args: + points: [B, N, 3] or [N, 3] tensor of point coordinates + normals: [B, N, 3] or [N, 3] tensor of normals (optional) + + Returns: + features: [B, N, 512] or [N, 512] tensor of features + """ + # Store original shape + original_shape = points.shape + single_batch = points.dim() == 2 + + # Prepare data for Sonata + data_dict, B, N = self.prepare_batch_data(points, normals) + + # Move to GPU if needed and convert to appropriate dtype + device = points.device + dtype = points.dtype + + # Make sure the entire model is in the correct dtype + # if dtype != self._original_dtype: + # self.to(dtype) + # self._original_dtype = dtype + + for key in data_dict.keys(): + if isinstance(data_dict[key], torch.Tensor): + # Convert tensors to the right device and dtype if they're floating point + if data_dict[key].is_floating_point(): + data_dict[key] = data_dict[key].to(device=device, dtype=dtype) + else: + # For integer tensors, just move to device without changing dtype + data_dict[key] = data_dict[key].to(device) + + # Extract Sonata features + point = self.sonata(data_dict) + + # Handle pooling layers (same as in train-sonata.py) + while "pooling_parent" in point.keys(): + assert "pooling_inverse" in point.keys() + parent = point.pop("pooling_parent") + inverse = point.pop("pooling_inverse") + parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1) + point = parent + + # Get features and apply MLP + feat = point.feat # [M, 1232] + feat = self.mlp(feat) # [M, 512] + + # Map back to original points + feat = feat[point.inverse] # [B*N, 512] + + # Reshape to batch format + feat = feat.reshape(B, -1, feat.shape[-1]) # [B, N, 512] + + # Return in original format + if single_batch: + feat = feat.squeeze(0) # [N, 512] + + return feat + + def extract_features_batch( + self, + points_list: list, + normals_list: Optional[list] = None, + batch_size: int = 8, + ) -> list: + """ + Extract features for multiple point clouds in batches. 
+ + Args: + points_list: List of [N_i, 3] tensors + normals_list: List of [N_i, 3] tensors (optional) + batch_size: Batch size for processing + + Returns: + List of [N_i, 512] feature tensors + """ + features_list = [] + + # Process in batches + for i in range(0, len(points_list), batch_size): + batch_points = points_list[i : i + batch_size] + batch_normals = normals_list[i : i + batch_size] if normals_list else None + + # Find max points in batch + max_n = max(p.shape[0] for p in batch_points) + + # Pad to same size + padded_points = [] + masks = [] + for points in batch_points: + n = points.shape[0] + if n < max_n: + padding = torch.zeros(max_n - n, 3, device=points.device) + points = torch.cat([points, padding], dim=0) + padded_points.append(points) + mask = torch.zeros(max_n, dtype=torch.bool, device=points.device) + mask[:n] = True + masks.append(mask) + + # Stack batch + batch_tensor = torch.stack(padded_points) # [B, max_n, 3] + + # Handle normals similarly if provided + if batch_normals: + padded_normals = [] + for j, normals in enumerate(batch_normals): + n = normals.shape[0] + if n < max_n: + padding = torch.ones(max_n - n, 3, device=normals.device) + normals = torch.cat([normals, padding], dim=0) + padded_normals.append(normals) + normals_tensor = torch.stack(padded_normals) + else: + normals_tensor = None + + # Extract features + with torch.cuda.amp.autocast(enabled=True): + batch_features = self.forward( + batch_tensor, normals_tensor + ) # [B, max_n, 512] + + # Unpad and add to results + for j, (feat, mask) in enumerate(zip(batch_features, masks)): + features_list.append(feat[mask]) + + return features_list + diff --git a/XPart/partgen/models/diffusion/schedulers.py b/XPart/partgen/models/diffusion/schedulers.py new file mode 100755 index 0000000000000000000000000000000000000000..25d6adf44ca4e4cff283025586b378b9027a157a --- /dev/null +++ b/XPart/partgen/models/diffusion/schedulers.py @@ -0,0 +1,329 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT +# except for the third-party components listed below. +# Hunyuan 3D does not impose any additional limitations beyond what is outlined +# in the repsective licenses of these third-party components. +# Users must comply with all terms and conditions of original licenses of these third-party +# components and must ensure that the usage of the third party components adheres to +# all relevant laws and regulations. + +# For avoidance of doubts, Hunyuan 3D means the large language models and +# their software and algorithms, including trained model weights, parameters (including +# optimizer states), machine-learning model code, inference-enabling code, training-enabling code, +# fine-tuning enabling code and other elements of the foregoing made publicly available +# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. 
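+
+# Usage sketch (assumed typical denoising loop, not taken verbatim from this repo):
+#     scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0)
+#     scheduler.set_timesteps(num_inference_steps=50, device="cuda")
+#     for t in scheduler.timesteps:
+#         velocity = model(latents, t)          # model predicts velocity (flow matching)
+#         latents = scheduler.step(velocity, t, latents).prev_sample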
+ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, logging + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler. Except our timesteps are reversed + + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + ): + timesteps = np.linspace( + 1, num_train_timesteps, num_train_timesteps, dtype=np.float32 + ).copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. 
+ timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timestep + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have a pass a value for `mu` when `use_dynamic_shifting` is set" + " to be `True`" + ) + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), + self._sigma_to_t(self.sigma_min), + num_inference_steps, + ) + + sigmas = timesteps / self.config.num_train_timesteps + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as" + " timesteps to `EulerDiscreteScheduler.step()` is not supported. Make" + " sure to pass one of the `scheduler.timesteps` as a timestep.", + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + diff --git a/XPart/partgen/models/diffusion/transport/__init__.py b/XPart/partgen/models/diffusion/transport/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..71164eaf5c447f9fc7bf3bc6ab4e8d6e0c67a05c --- /dev/null +++ b/XPart/partgen/models/diffusion/transport/__init__.py @@ -0,0 +1,97 @@ +# This file includes code derived from the SiT project (https://github.com/willisma/SiT), +# which is licensed under the MIT License. +# +# MIT License +# +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from .transport import Transport, ModelType, WeightType, PathType, Sampler + + +def create_transport( + path_type='Linear', + prediction="velocity", + loss_weight=None, + train_eps=None, + sample_eps=None, + train_sample_type="uniform", + mean = 0.0, + std = 1.0, + shift_scale = 1.0, +): + """function for creating Transport object + **Note**: model prediction defaults to velocity + Args: + - path_type: type of path to use; default to linear + - learn_score: set model prediction to score + - learn_noise: set model prediction to noise + - velocity_weighted: weight loss by velocity weight + - likelihood_weighted: weight loss by likelihood weight + - train_eps: small epsilon for avoiding instability during training + - sample_eps: small epsilon for avoiding instability during sampling + """ + + if prediction == "noise": + model_type = ModelType.NOISE + elif prediction == "score": + model_type = ModelType.SCORE + else: + model_type = ModelType.VELOCITY + + if loss_weight == "velocity": + loss_type = WeightType.VELOCITY + elif loss_weight == "likelihood": + loss_type = WeightType.LIKELIHOOD + else: + loss_type = WeightType.NONE + + path_choice = { + "Linear": PathType.LINEAR, + "GVP": PathType.GVP, + "VP": PathType.VP, + } + + path_type = path_choice[path_type] + + if (path_type in [PathType.VP]): + train_eps = 1e-5 if train_eps is None else train_eps + sample_eps = 1e-3 if train_eps is None else sample_eps + elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): + train_eps = 1e-3 if train_eps is None else train_eps + sample_eps = 1e-3 if train_eps is None else sample_eps + else: # velocity & [GVP, LINEAR] is stable everywhere + train_eps = 0 + sample_eps = 0 + + # create flow state + state = Transport( + model_type=model_type, + path_type=path_type, + loss_type=loss_type, + train_eps=train_eps, + sample_eps=sample_eps, + train_sample_type=train_sample_type, + mean=mean, + std=std, + shift_scale =shift_scale, + ) + + return state diff --git a/XPart/partgen/models/diffusion/transport/integrators.py b/XPart/partgen/models/diffusion/transport/integrators.py new file mode 100755 index 0000000000000000000000000000000000000000..e2e011492ce8d7acc296df9bfe57d45ae01893e7 --- /dev/null +++ b/XPart/partgen/models/diffusion/transport/integrators.py @@ -0,0 +1,142 @@ +# This file includes code derived from the SiT project (https://github.com/willisma/SiT), +# which is licensed under the MIT License. 
+# +# MIT License +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import numpy as np +import torch as th +import torch.nn as nn +from torchdiffeq import odeint +from functools import partial +from tqdm import tqdm + +class sde: + """SDE solver class""" + def __init__( + self, + drift, + diffusion, + *, + t0, + t1, + num_steps, + sampler_type, + ): + assert t0 < t1, "SDE sampler has to be in forward time" + + self.num_timesteps = num_steps + self.t = th.linspace(t0, t1, num_steps) + self.dt = self.t[1] - self.t[0] + self.drift = drift + self.diffusion = diffusion + self.sampler_type = sampler_type + + def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + t = th.ones(x.size(0)).to(x) * t + dw = w_cur * th.sqrt(self.dt) + drift = self.drift(x, t, model, **model_kwargs) + diffusion = self.diffusion(x, t) + mean_x = x + drift * self.dt + x = mean_x + th.sqrt(2 * diffusion) * dw + return x, mean_x + + def __Heun_step(self, x, _, t, model, **model_kwargs): + w_cur = th.randn(x.size()).to(x) + dw = w_cur * th.sqrt(self.dt) + t_cur = th.ones(x.size(0)).to(x) * t + diffusion = self.diffusion(x, t_cur) + xhat = x + th.sqrt(2 * diffusion) * dw + K1 = self.drift(xhat, t_cur, model, **model_kwargs) + xp = xhat + self.dt * K1 + K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) + return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step + + def __forward_fn(self): + """TODO: generalize here by adding all private functions ending with steps to it""" + sampler_dict = { + "Euler": self.__Euler_Maruyama_step, + "Heun": self.__Heun_step, + } + + try: + sampler = sampler_dict[self.sampler_type] + except: + raise NotImplementedError("Smapler type not implemented.") + + return sampler + + def sample(self, init, model, **model_kwargs): + """forward loop of sde""" + x = init + mean_x = init + samples = [] + sampler = self.__forward_fn() + for ti in self.t[:-1]: + with th.no_grad(): + x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) + samples.append(x) + + return samples + +class ode: + """ODE solver class""" + def __init__( + self, + drift, + *, + t0, + t1, + sampler_type, + num_steps, + atol, + rtol, + ): + assert t0 < t1, "ODE sampler has to be in forward time" + + self.drift = drift + self.t = th.linspace(t0, t1, num_steps) + self.atol = atol + self.rtol = rtol + self.sampler_type = sampler_type + + def sample(self, x, model, 
**model_kwargs): + + device = x[0].device if isinstance(x, tuple) else x.device + def _fn(t, x): + t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t + model_output = self.drift(x, t, model, **model_kwargs) + return model_output + + t = self.t.to(device) + atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] + rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] + samples = odeint( + _fn, + x, + t, + method=self.sampler_type, + atol=atol, + rtol=rtol + ) + return samples diff --git a/XPart/partgen/models/diffusion/transport/path.py b/XPart/partgen/models/diffusion/transport/path.py new file mode 100755 index 0000000000000000000000000000000000000000..9c867929e9c570a5465c6bfd337cce27b0447041 --- /dev/null +++ b/XPart/partgen/models/diffusion/transport/path.py @@ -0,0 +1,220 @@ +# This file includes code derived from the SiT project (https://github.com/willisma/SiT), +# which is licensed under the MIT License. +# +# MIT License +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
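+
+# Note (added for clarity): for the linear plan (ICPlan) below, alpha_t = t and sigma_t = 1 - t,
+# so compute_mu_t gives x_t = t * x1 + (1 - t) * x0 and compute_ut gives u_t = x1 - x0.
+# The GVP and VP plans reuse the same interface with different (alpha_t, sigma_t) schedules.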
+ +import torch as th +import numpy as np +from functools import partial + +def expand_t_like_x(t, x): + """Function to reshape time t to broadcastable dimension of x + Args: + t: [batch_dim,], time vector + x: [batch_dim,...], data point + """ + dims = [1] * (len(x.size()) - 1) + t = t.view(t.size(0), *dims) + return t + + +#################### Coupling Plans #################### + +class ICPlan: + """Linear Coupling Plan""" + def __init__(self, sigma=0.0): + self.sigma = sigma + + def compute_alpha_t(self, t): + """Compute the data coefficient along the path""" + return t, 1 + + def compute_sigma_t(self, t): + """Compute the noise coefficient along the path""" + return 1 - t, -1 + + def compute_d_alpha_alpha_ratio_t(self, t): + """Compute the ratio between d_alpha and alpha""" + return 1 / t + + def compute_drift(self, x, t): + """We always output sde according to score parametrization; """ + t = expand_t_like_x(t, x) + alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + drift = alpha_ratio * x + diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t + + return -drift, diffusion + + def compute_diffusion(self, x, t, form="constant", norm=1.0): + """Compute the diffusion term of the SDE + Args: + x: [batch_dim, ...], data point + t: [batch_dim,], time vector + form: str, form of the diffusion term + norm: float, norm of the diffusion term + """ + t = expand_t_like_x(t, x) + choices = { + "constant": norm, + "SBDM": norm * self.compute_drift(x, t)[1], + "sigma": norm * self.compute_sigma_t(t)[0], + "linear": norm * (1 - t), + "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, + "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2, + } + + try: + diffusion = choices[form] + except KeyError: + raise NotImplementedError(f"Diffusion form {form} not implemented") + + return diffusion + + def get_score_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to score + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t + score = (reverse_alpha_ratio * velocity - mean) / var + return score + + def get_noise_from_velocity(self, velocity, x, t): + """Wrapper function: transfrom velocity prediction model to denoiser + Args: + velocity: [batch_dim, ...] shaped tensor; velocity model output + x: [batch_dim, ...] shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + alpha_t, d_alpha_t = self.compute_alpha_t(t) + sigma_t, d_sigma_t = self.compute_sigma_t(t) + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = reverse_alpha_ratio * d_sigma_t - sigma_t + noise = (reverse_alpha_ratio * velocity - mean) / var + return noise + + def get_velocity_from_score(self, score, x, t): + """Wrapper function: transfrom score prediction model to velocity + Args: + score: [batch_dim, ...] shaped tensor; score model output + x: [batch_dim, ...] 
shaped tensor; x_t data point + t: [batch_dim,] time tensor + """ + t = expand_t_like_x(t, x) + drift, var = self.compute_drift(x, t) + velocity = var * score - drift + return velocity + + def compute_mu_t(self, t, x0, x1): + """Compute the mean of time-dependent density p_t""" + t = expand_t_like_x(t, x1) + alpha_t, _ = self.compute_alpha_t(t) + sigma_t, _ = self.compute_sigma_t(t) + # t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1 + return alpha_t * x1 + sigma_t * x0 + + def compute_xt(self, t, x0, x1): + """Sample xt from time-dependent density p_t; rng is required""" + xt = self.compute_mu_t(t, x0, x1) + return xt + + def compute_ut(self, t, x0, x1, xt): + """Compute the vector field corresponding to p_t""" + t = expand_t_like_x(t, x1) + _, d_alpha_t = self.compute_alpha_t(t) + _, d_sigma_t = self.compute_sigma_t(t) + return d_alpha_t * x1 + d_sigma_t * x0 + + def plan(self, t, x0, x1): + xt = self.compute_xt(t, x0, x1) + ut = self.compute_ut(t, x0, x1, xt) + return t, xt, ut + + +class VPCPlan(ICPlan): + """class for VP path flow matching""" + + def __init__(self, sigma_min=0.1, sigma_max=20.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \ + (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min + self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \ + (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min + + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = self.log_mean_coeff(t) + alpha_t = th.exp(alpha_t) + d_alpha_t = alpha_t * self.d_log_mean_coeff(t) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + p_sigma_t = 2 * self.log_mean_coeff(t) + sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) + d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return self.d_log_mean_coeff(t) + + def compute_drift(self, x, t): + """Compute the drift term of the SDE""" + t = expand_t_like_x(t, x) + beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) + return -0.5 * beta_t * x, beta_t / 2 + + +class GVPCPlan(ICPlan): + def __init__(self, sigma=0.0): + super().__init__(sigma) + + def compute_alpha_t(self, t): + """Compute coefficient of x1""" + alpha_t = th.sin(t * np.pi / 2) + d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) + return alpha_t, d_alpha_t + + def compute_sigma_t(self, t): + """Compute coefficient of x0""" + sigma_t = th.cos(t * np.pi / 2) + d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) + return sigma_t, d_sigma_t + + def compute_d_alpha_alpha_ratio_t(self, t): + """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" + return np.pi / (2 * th.tan(t * np.pi / 2)) diff --git a/XPart/partgen/models/diffusion/transport/transport.py b/XPart/partgen/models/diffusion/transport/transport.py new file mode 100755 index 0000000000000000000000000000000000000000..ea566fd1ca6d845097939f8c43e279d9000982ba --- /dev/null +++ b/XPart/partgen/models/diffusion/transport/transport.py @@ -0,0 +1,506 @@ +# This file includes code derived from the SiT project (https://github.com/willisma/SiT), +# which is licensed under the MIT License. +# +# MIT License +# +# Copyright (c) Meta Platforms, Inc. and affiliates. 
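+
+# Note (added for clarity): with the default velocity prediction and linear path,
+# Transport.training_losses samples t and x0 ~ N(0, I), builds x_t = t * x1 + (1 - t) * x0,
+# and regresses the model output against the target velocity u_t = x1 - x0 with an MSE loss.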
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch as th +import numpy as np +import logging + +import enum + +from . import path +from .utils import EasyDict, log_state, mean_flat +from .integrators import ode, sde + + +class ModelType(enum.Enum): + """ + Which type of output the model predicts. + """ + + NOISE = enum.auto() # the model predicts epsilon + SCORE = enum.auto() # the model predicts \nabla \log p(x) + VELOCITY = enum.auto() # the model predicts v(x) + + +class PathType(enum.Enum): + """ + Which type of path to use. + """ + + LINEAR = enum.auto() + GVP = enum.auto() + VP = enum.auto() + + +class WeightType(enum.Enum): + """ + Which type of weighting to use. + """ + + NONE = enum.auto() + VELOCITY = enum.auto() + LIKELIHOOD = enum.auto() + + +class Transport: + + def __init__( + self, + *, + model_type, + path_type, + loss_type, + train_eps, + sample_eps, + train_sample_type="uniform", + **kwargs, + ): + path_options = { + PathType.LINEAR: path.ICPlan, + PathType.GVP: path.GVPCPlan, + PathType.VP: path.VPCPlan, + } + + self.loss_type = loss_type + self.model_type = model_type + self.path_sampler = path_options[path_type]() + self.train_eps = train_eps + self.sample_eps = sample_eps + self.train_sample_type = train_sample_type + if self.train_sample_type == "logit_normal": + self.mean = kwargs["mean"] + self.std = kwargs["std"] + self.shift_scale = kwargs["shift_scale"] + print(f"using logit normal sample, shift scale is {self.shift_scale}") + + def prior_logp(self, z): + """ + Standard multivariate normal prior + Assume z is batched + """ + shape = th.tensor(z.size()) + N = th.prod(shape[1:]) + _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0 + return th.vmap(_fn)(z) + + def check_interval( + self, + train_eps, + sample_eps, + *, + diffusion_form="SBDM", + sde=False, + reverse=False, + eval=False, + last_step_size=0.0, + ): + t0 = 0 + t1 = 1 + eps = train_eps if not eval else sample_eps + if type(self.path_sampler) in [path.VPCPlan]: + + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and ( + self.model_type != ModelType.VELOCITY or sde + ): # avoid numerical issue by taking a first semi-implicit step + + t0 = ( + eps + if (diffusion_form == "SBDM" and sde) + or self.model_type != ModelType.VELOCITY + else 0 + ) + t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size + + if reverse: + t0, t1 = 1 - t0, 1 - t1 + + return t0, t1 + + 
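+    # sample() draws the training triplet (t, x0, x1): t is uniform on [t0, t1] or
+    # logit-normal (optionally resolution-shifted via shift_scale), and x0 is Gaussian noise
+    # with the same shape as x1.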
def sample(self, x1): + """Sampling x0 & t based on shape of x1 (if needed) + Args: + x1 - data point; [batch, *dim] + """ + + x0 = th.randn_like(x1) + if self.train_sample_type == "uniform": + t0, t1 = self.check_interval(self.train_eps, self.sample_eps) + t = th.rand((x1.shape[0],)) * (t1 - t0) + t0 + t = t.to(x1) + elif self.train_sample_type == "logit_normal": + t = th.randn((x1.shape[0],)) * self.std + self.mean + t = t.to(x1) + t = 1 / (1 + th.exp(-t)) + + t = ( + np.sqrt(self.shift_scale) + * t + / (1 + (np.sqrt(self.shift_scale) - 1) * t) + ) + + return t, x0, x1 + + def training_losses(self, model, x1, model_kwargs=None): + """Loss for training the score model + Args: + - model: backbone model; could be score, noise, or velocity + - x1: datapoint + - model_kwargs: additional arguments for the model + """ + if model_kwargs == None: + model_kwargs = {} + + t, x0, x1 = self.sample(x1) + t, xt, ut = self.path_sampler.plan(t, x0, x1) + model_output = model(xt, t, **model_kwargs) + B, *_, C = xt.shape + assert model_output.size() == (B, *xt.size()[1:-1], C) + + terms = {} + terms["pred"] = model_output + if self.model_type == ModelType.VELOCITY: + terms["loss"] = mean_flat(((model_output - ut) ** 2)) + else: + _, drift_var = self.path_sampler.compute_drift(xt, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) + if self.loss_type in [WeightType.VELOCITY]: + weight = (drift_var / sigma_t) ** 2 + elif self.loss_type in [WeightType.LIKELIHOOD]: + weight = drift_var / (sigma_t**2) + elif self.loss_type in [WeightType.NONE]: + weight = 1 + else: + raise NotImplementedError() + + if self.model_type == ModelType.NOISE: + terms["loss"] = mean_flat(weight * ((model_output - x0) ** 2)) + else: + terms["loss"] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) + + return terms + + def get_drift(self): + """member function for obtaining the drift of the probability flow ODE""" + + def score_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + model_output = model(x, t, **model_kwargs) + return -drift_mean + drift_var * model_output # by change of variable + + def noise_ode(x, t, model, **model_kwargs): + drift_mean, drift_var = self.path_sampler.compute_drift(x, t) + sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) + model_output = model(x, t, **model_kwargs) + score = model_output / -sigma_t + return -drift_mean + drift_var * score + + def velocity_ode(x, t, model, **model_kwargs): + model_output = model(x, t, **model_kwargs) + return model_output + + if self.model_type == ModelType.NOISE: + drift_fn = noise_ode + elif self.model_type == ModelType.SCORE: + drift_fn = score_ode + else: + drift_fn = velocity_ode + + def body_fn(x, t, model, **model_kwargs): + model_output = drift_fn(x, t, model, **model_kwargs) + assert ( + model_output.shape == x.shape + ), "Output shape from ODE solver must match input shape" + return model_output + + return body_fn + + def get_score( + self, + ): + """member function for obtaining score of + x_t = alpha_t * x + sigma_t * eps""" + if self.model_type == ModelType.NOISE: + score_fn = ( + lambda x, t, model, **kwargs: model(x, t, **kwargs) + / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] + ) + elif self.model_type == ModelType.SCORE: + score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs) + elif self.model_type == ModelType.VELOCITY: + score_fn = ( + lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity( + 
model(x, t, **kwargs), x, t + ) + ) + else: + raise NotImplementedError() + + return score_fn + + +class Sampler: + """Sampler class for the transport model""" + + def __init__( + self, + transport, + ): + """Constructor for a general sampler; supporting different sampling methods + Args: + - transport: an tranport object specify model prediction & interpolant type + """ + + self.transport = transport + self.drift = self.transport.get_drift() + self.score = self.transport.get_score() + + def __get_sde_diffusion_and_drift( + self, + *, + diffusion_form="SBDM", + diffusion_norm=1.0, + ): + + def diffusion_fn(x, t): + diffusion = self.transport.path_sampler.compute_diffusion( + x, t, form=diffusion_form, norm=diffusion_norm + ) + return diffusion + + sde_drift = lambda x, t, model, **kwargs: self.drift( + x, t, model, **kwargs + ) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs) + + sde_diffusion = diffusion_fn + + return sde_drift, sde_diffusion + + def __get_last_step( + self, + sde_drift, + *, + last_step, + last_step_size, + ): + """Get the last step function of the SDE solver""" + + if last_step is None: + last_step_fn = lambda x, t, model, **model_kwargs: x + elif last_step == "Mean": + last_step_fn = ( + lambda x, t, model, **model_kwargs: x + + sde_drift(x, t, model, **model_kwargs) * last_step_size + ) + elif last_step == "Tweedie": + alpha = ( + self.transport.path_sampler.compute_alpha_t + ) # simple aliasing; the original name was too long + sigma = self.transport.path_sampler.compute_sigma_t + last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + ( + sigma(t)[0][0] ** 2 + ) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) + elif last_step == "Euler": + last_step_fn = ( + lambda x, t, model, **model_kwargs: x + + self.drift(x, t, model, **model_kwargs) * last_step_size + ) + else: + raise NotImplementedError() + + return last_step_fn + + def sample_sde( + self, + *, + sampling_method="Euler", + diffusion_form="SBDM", + diffusion_norm=1.0, + last_step="Mean", + last_step_size=0.04, + num_steps=250, + ): + """returns a sampling function with given SDE settings + Args: + - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama + - diffusion_form: function form of diffusion coefficient; default to be matching SBDM + - diffusion_norm: function magnitude of diffusion coefficient; default to 1 + - last_step: type of the last step; default to identity + - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] + - num_steps: total integration step of SDE + """ + + if last_step is None: + last_step_size = 0.0 + + sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( + diffusion_form=diffusion_form, + diffusion_norm=diffusion_norm, + ) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + diffusion_form=diffusion_form, + sde=True, + eval=True, + reverse=False, + last_step_size=last_step_size, + ) + + _sde = sde( + sde_drift, + sde_diffusion, + t0=t0, + t1=t1, + num_steps=num_steps, + sampler_type=sampling_method, + ) + + last_step_fn = self.__get_last_step( + sde_drift, last_step=last_step, last_step_size=last_step_size + ) + + def _sample(init, model, **model_kwargs): + xs = _sde.sample(init, model, **model_kwargs) + ts = th.ones(init.size(0), device=init.device) * t1 + x = last_step_fn(xs[-1], ts, model, **model_kwargs) + xs.append(x) + + assert len(xs) == num_steps, "Samples does not match the number of steps" + + return xs + + 
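A usage sketch showing how the Transport and Sampler classes defined in this file fit together; the velocity network `net(x, t)` (returning a tensor shaped like `x`) is a hypothetical placeholder, everything else is defined above (illustration only, not part of the patch).

import torch as th

transport = Transport(
    model_type=ModelType.VELOCITY,    # net predicts v(x, t)
    path_type=PathType.LINEAR,
    loss_type=WeightType.NONE,
    train_eps=0.0,
    sample_eps=0.0,
)
sampler = Sampler(transport)
sample_fn = sampler.sample_sde(sampling_method="Euler", num_steps=250, last_step="Mean")
# xs = sample_fn(th.randn(1, 512, 64), net)   # trajectory of intermediate states; xs[-1] is the sample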
return _sample + + def sample_ode( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + reverse=False, + ): + """returns a sampling function with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + - reverse: whether solving the ODE in reverse (data to noise); default to False + """ + if reverse: + drift = lambda x, t, model, **kwargs: self.drift( + x, th.ones_like(t) * (1 - t), model, **kwargs + ) + else: + drift = self.drift + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=reverse, + last_step_size=0.0, + ) + + _ode = ode( + drift=drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + return _ode.sample + + + def sample_ode_likelihood( + self, + *, + sampling_method="dopri5", + num_steps=50, + atol=1e-6, + rtol=1e-3, + ): + """returns a sampling function for calculating likelihood with given ODE settings + Args: + - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 + - num_steps: + - fixed solver (Euler, Heun): the actual number of integration steps performed + - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation + - atol: absolute error tolerance for the solver + - rtol: relative error tolerance for the solver + """ + + def _likelihood_drift(x, t, model, **model_kwargs): + x, _ = x + eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 + t = th.ones_like(t) * (1 - t) + with th.enable_grad(): + x.requires_grad = True + grad = th.autograd.grad( + th.sum(self.drift(x, t, model, **model_kwargs) * eps), x + )[0] + logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) + drift = self.drift(x, t, model, **model_kwargs) + return (-drift, logp_grad) + + t0, t1 = self.transport.check_interval( + self.transport.train_eps, + self.transport.sample_eps, + sde=False, + eval=True, + reverse=False, + last_step_size=0.0, + ) + + _ode = ode( + drift=_likelihood_drift, + t0=t0, + t1=t1, + sampler_type=sampling_method, + num_steps=num_steps, + atol=atol, + rtol=rtol, + ) + + def _sample_fn(x, model, **model_kwargs): + init_logp = th.zeros(x.size(0)).to(x) + input = (x, init_logp) + drift, delta_logp = _ode.sample(input, model, **model_kwargs) + drift, delta_logp = drift[-1], delta_logp[-1] + prior_logp = self.transport.prior_logp(drift) + logp = prior_logp - delta_logp + return logp, drift + + return _sample_fn diff --git a/XPart/partgen/models/diffusion/transport/utils.py b/XPart/partgen/models/diffusion/transport/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..5830aab015f9186c3e02d462af7580cfb223cfae --- /dev/null +++ b/XPart/partgen/models/diffusion/transport/utils.py @@ -0,0 +1,54 @@ +# This file includes code derived from the SiT project (https://github.com/willisma/SiT), +# which is licensed under the MIT License. +# +# MIT License +# +# Copyright (c) Meta Platforms, Inc. and affiliates. 
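The likelihood path above estimates the divergence of the drift with a Rademacher (Hutchinson) probe: E_eps[eps^T (dv/dx) eps] = div(v). A standalone sketch on a linear field, where the divergence is exactly trace(A) (illustration only, not part of the patch):

import torch as th

A = th.randn(8, 8)
v = lambda x: x @ A.T                                    # v(x) = A x, so div(v) = trace(A)

x = th.randn(4096, 8, requires_grad=True)
eps = th.randint(0, 2, x.shape, dtype=th.float) * 2 - 1  # Rademacher +/-1 probe, as in _likelihood_drift
grad = th.autograd.grad(th.sum(v(x) * eps), x)[0]        # rows are (dv/dx)^T eps per sample
est = th.sum(grad * eps, dim=-1).mean()                  # eps^T (dv/dx) eps, averaged over probes
print(est.item(), th.trace(A).item())                    # the two agree up to Monte-Carlo error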
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch as th + +class EasyDict: + + def __init__(self, sub_dict): + for k, v in sub_dict.items(): + setattr(self, k, v) + + def __getitem__(self, key): + return getattr(self, key) + +def mean_flat(x): + """ + Take the mean over all non-batch dimensions. + """ + return th.mean(x, dim=list(range(1, len(x.size())))) + +def log_state(state): + result = [] + + sorted_state = dict(sorted(state.items())) + for key, value in sorted_state.items(): + # Check if the value is an instance of a class + if " None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + # print(bsz, seq_len, h) + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, self.weight, None) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() + aux_loss = aux_loss * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class MoEBlock(nn.Module): + def __init__( + self, + dim, + num_experts=8, + moe_top_k=2, + activation_fn="gelu", + dropout=0.0, + final_dropout=False, + ff_inner_dim=None, + 
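The gate above routes each token to its top-k experts with softmax scores; a standalone sketch of that routing, without the auxiliary load-balancing loss (illustration only, not part of the patch):

import torch
import torch.nn.functional as F

num_experts, top_k = 8, 2
tokens = torch.randn(5, 64)                                   # 5 tokens with hidden size 64
gate_w = torch.randn(num_experts, 64)                         # the MoEGate projection weight

scores = F.linear(tokens, gate_w).softmax(dim=-1)             # [5, 8] routing probabilities
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1)   # two experts per token
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)  # norm_topk_prob
# topk_idx selects which experts run; topk_weight mixes their outputs, as in MoEBlock.forward below.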
ff_bias=True, + ): + super().__init__() + self.moe_top_k = moe_top_k + self.experts = nn.ModuleList([ + FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + for i in range(num_experts) + ]) + self.gate = MoEGate( + embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k + ) + + self.shared_experts = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + def initialize_weight(self): + pass + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0) + y = torch.empty_like(hidden_states, dtype=hidden_states.dtype) + for i, expert in enumerate(self.experts): + tmp = expert(hidden_states[flat_topk_idx == i]) + y[flat_topk_idx == i] = tmp.to(hidden_states.dtype) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer( + hidden_states, flat_topk_idx, topk_weight.view(-1, 1) + ).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.moe_top_k + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i - 1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_( + 0, + exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), + expert_out, + reduce="sum", + ) + return expert_cache diff --git a/XPart/partgen/models/partformer_dit.py b/XPart/partgen/models/partformer_dit.py new file mode 100755 index 0000000000000000000000000000000000000000..3fef8403f5c66a101954f39bf6c9f7459f96771d --- /dev/null +++ b/XPart/partgen/models/partformer_dit.py @@ -0,0 +1,756 @@ +# Newest version: add local&global context (cross-attn), and local&global attn (self-attn) +import math + +import torch.nn.functional as F + +import torch.nn as nn +import torch +from typing import Optional +from einops import rearrange +from .moe_layers import MoEBlock +import numpy as np + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + return np.concatenate([emb_sin, emb_cos], axis=1) + + +class Timesteps(nn.Module): + def __init__( + self, + num_channels: int, + 
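get_1d_sincos_pos_embed_from_grid above packs D/2 sine features followed by D/2 cosine features per position; a quick shape check restating the formula (illustration only, not part of the patch):

import numpy as np

embed_dim, num_pos = 8, 5
pos = np.arange(num_pos, dtype=np.float64)
omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
out = np.einsum("m,d->md", pos, omega)                        # (5, 4) phases: position x frequency
emb = np.concatenate([np.sin(out), np.cos(out)], axis=1)      # (5, 8), same layout as the function above
assert emb.shape == (num_pos, embed_dim)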
downscale_freq_shift: float = 0.0, + scale: int = 1, + max_period: int = 10000, + ): + super().__init__() + self.num_channels = num_channels + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + self.max_period = max_period + + def forward(self, timesteps): + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + embedding_dim = self.num_channels + half_dim = embedding_dim // 2 + exponent = -math.log(self.max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - self.downscale_freq_shift) + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + emb = self.scale * emb + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + frequency_embedding_size=256, + cond_proj_dim=None, + out_size=None, + ): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(hidden_size, frequency_embedding_size, bias=True), + nn.GELU(), + nn.Linear(frequency_embedding_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear( + cond_proj_dim, frequency_embedding_size, bias=False + ) + + self.time_embed = Timesteps(hidden_size) + + def forward(self, t, condition): + + t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype) + + # t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) + if condition is not None: + t_freq = t_freq + self.cond_proj(condition) + + t = self.mlp(t_freq) + t = t.unsqueeze(dim=1) + return t + + +class MLP(nn.Module): + def __init__(self, *, width: int): + super().__init__() + self.width = width + self.fc1 = nn.Linear(width, width * 4) + self.fc2 = nn.Linear(width * 4, width) + self.gelu = nn.GELU() + + def forward(self, x): + return self.fc2(self.gelu(self.fc1(x))) + + +class CrossAttention(nn.Module): + def __init__( + self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + norm_layer=nn.LayerNorm, + with_decoupled_ca=False, + decoupled_ca_dim=16, + decoupled_ca_weight=1.0, + **kwargs, + ): + super().__init__() + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + assert ( + self.head_dim % 8 == 0 and self.head_dim <= 128 + ), "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim**-0.5 + + self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias) + self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias) + self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias) + + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.out_proj = nn.Linear(qdim, qdim, bias=True) + + self.with_dca = with_decoupled_ca + if self.with_dca: + self.kv_proj_dca = nn.Linear(kdim, 2 * qdim, bias=qkv_bias) + self.k_norm_dca = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.dca_dim = decoupled_ca_dim + 
self.dca_weight = decoupled_ca_weight + # zero init + nn.init.zeros_(self.out_proj.weight) + nn.init.zeros_(self.out_proj.bias) + + def forward(self, x, y): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim) + y: torch.Tensor + (batch, seqlen2, hidden_dim2) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // 2), RoPE for image + """ + b, s1, c = x.shape # [b, s1, D] + + if self.with_dca: + token_len = y.shape[1] + context_dca = y[:, -self.dca_dim :, :] + kv_dca = self.kv_proj_dca(context_dca).view( + b, self.dca_dim, 2, self.num_heads, self.head_dim + ) + k_dca, v_dca = kv_dca.unbind(dim=2) # [b, s, h, d] + k_dca = self.k_norm_dca(k_dca) + y = y[:, : (token_len - self.dca_dim), :] + + _, s2, c = y.shape # [b, s2, 1024] + q = self.to_q(x) + k = self.to_k(y) + v = self.to_v(y) + + kv = torch.cat((k, v), dim=-1) + split_size = kv.shape[-1] // self.num_heads // 2 + kv = kv.view(1, -1, self.num_heads, split_size * 2) + k, v = torch.split(kv, split_size, dim=-1) + + q = q.view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d] + k = k.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d] + v = v.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d] + + q = self.q_norm(q) + k = self.k_norm(k) + + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=False, enable_mem_efficient=True + ): + q, k, v = map( + lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads), + (q, k, v), + ) + context = ( + F.scaled_dot_product_attention(q, k, v) + .transpose(1, 2) + .reshape(b, s1, -1) + ) + + if self.with_dca: + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=False, enable_mem_efficient=True + ): + k_dca, v_dca = map( + lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads), + (k_dca, v_dca), + ) + context_dca = ( + F.scaled_dot_product_attention(q, k_dca, v_dca) + .transpose(1, 2) + .reshape(b, s1, -1) + ) + + context = context + self.dca_weight * context_dca + + out = self.out_proj(context) # context.reshape - B, L1, -1 + + return out + + +class Attention(nn.Module): + """ + We rename some layer names to align with flash attention + """ + + def __init__( + self, + dim, + num_heads, + qkv_bias=True, + qk_norm=False, + norm_layer=nn.LayerNorm, + use_global_processor=False, + ): + super().__init__() + self.use_global_processor = use_global_processor + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, "dim should be divisible by num_heads" + self.head_dim = self.dim // num_heads + # This assertion is aligned with flash attention + assert ( + self.head_dim % 8 == 0 and self.head_dim <= 128 + ), "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim**-0.5 + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(dim, dim, bias=qkv_bias) + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) + if qk_norm + else nn.Identity() + ) + self.out_proj = nn.Linear(dim, dim) + + # set processor + self.processor = LocalGlobalProcessor(use_global=use_global_processor) + + def forward(self, x): + return self.processor(self, x) + + +class AttentionPool(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + 
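CrossAttention.forward above reduces to: project q from x and k, v from y, split into heads, and let scaled_dot_product_attention do the rest. A minimal standalone sketch of that core (illustration only, not part of the patch):

import torch
import torch.nn.functional as F
from einops import rearrange

b, s1, s2, h, d = 2, 16, 32, 4, 32
q = torch.randn(b, s1, h, d)                        # queries from x
k = torch.randn(b, s2, h, d)                        # keys from y
v = torch.randn(b, s2, h, d)                        # values from y
q, k, v = (rearrange(t, "b n h d -> b h n d") for t in (q, k, v))
ctx = F.scaled_dot_product_attention(q, k, v)       # [b, h, s1, d]
out = ctx.transpose(1, 2).reshape(b, s1, h * d)     # back to [b, s1, hidden_dim]; the module then applies out_proj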
super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x, attention_mask=None): + x = x.permute(1, 0, 2) # NLC -> LNC + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2) + global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0) + x = torch.cat([global_emb[None,], x], dim=0) + + else: + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class LocalGlobalProcessor: + def __init__(self, use_global=False): + self.use_global = use_global + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + ): + """ + hidden_states: [B, L, C] + """ + if self.use_global: + B_old, N_old, C_old = hidden_states.shape + hidden_states = hidden_states.reshape(1, -1, C_old) + B, N, C = hidden_states.shape + + q = attn.to_q(hidden_states) + k = attn.to_k(hidden_states) + v = attn.to_v(hidden_states) + + qkv = torch.cat((q, k, v), dim=-1) + split_size = qkv.shape[-1] // attn.num_heads // 3 + qkv = qkv.view(1, -1, attn.num_heads, split_size * 3) + q, k, v = torch.split(qkv, split_size, dim=-1) + + q = q.reshape(B, N, attn.num_heads, attn.head_dim).transpose( + 1, 2 + ) # [b, h, s, d] + k = k.reshape(B, N, attn.num_heads, attn.head_dim).transpose( + 1, 2 + ) # [b, h, s, d] + v = v.reshape(B, N, attn.num_heads, attn.head_dim).transpose(1, 2) + + q = attn.q_norm(q) # [b, h, s, d] + k = attn.k_norm(k) # [b, h, s, d] + + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=False, enable_mem_efficient=True + ): + hidden_states = F.scaled_dot_product_attention(q, k, v) + hidden_states = hidden_states.transpose(1, 2).reshape(B, N, -1) + + hidden_states = attn.out_proj(hidden_states) + if self.use_global: + hidden_states = hidden_states.reshape(B_old, N_old, -1) + return hidden_states + + +class PartFormerDitBlock(nn.Module): + + def __init__( + self, + hidden_size, + num_heads, + use_self_attention: bool = True, + use_cross_attention: bool = False, + use_cross_attention_2: bool = False, + encoder_hidden_dim=1024, # cross-attn encoder_hidden_states dim + encoder_hidden2_dim=1024, # cross-attn 2 encoder_hidden_states dim + # cross_attn2_weight=0.0, + qkv_bias=True, + qk_norm=False, + norm_layer=nn.LayerNorm, + qk_norm_layer=nn.RMSNorm, + with_decoupled_ca=False, + decoupled_ca_dim=16, + decoupled_ca_weight=1.0, + skip_connection=False, + timested_modulate=False, + c_emb_size=0, # time embedding size + use_moe: bool = False, + num_experts: int = 8, + moe_top_k: int = 2, + ): + super().__init__() + # self.cross_attn2_weight = 
cross_attn2_weight + use_ele_affine = True + # ========================= Self-Attention ========================= + self.use_self_attention = use_self_attention + if self.use_self_attention: + self.norm1 = norm_layer( + hidden_size, elementwise_affine=use_ele_affine, eps=1e-6 + ) + self.attn1 = Attention( + hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + norm_layer=qk_norm_layer, + ) + + # ========================= Add ========================= + # Simply use add like SDXL. + self.timested_modulate = timested_modulate + if self.timested_modulate: + self.default_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(c_emb_size, hidden_size, bias=True) + ) + # ========================= Cross-Attention ========================= + self.use_cross_attention = use_cross_attention + if self.use_cross_attention: + self.norm2 = norm_layer( + hidden_size, elementwise_affine=use_ele_affine, eps=1e-6 + ) + self.attn2 = CrossAttention( + hidden_size, + encoder_hidden_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + norm_layer=qk_norm_layer, + with_decoupled_ca=False, + ) + self.use_cross_attention_2 = use_cross_attention_2 + if self.use_cross_attention_2: + self.norm2_2 = norm_layer( + hidden_size, elementwise_affine=use_ele_affine, eps=1e-6 + ) + self.attn2_2 = CrossAttention( + hidden_size, + encoder_hidden2_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + norm_layer=qk_norm_layer, + with_decoupled_ca=with_decoupled_ca, + decoupled_ca_dim=decoupled_ca_dim, + decoupled_ca_weight=decoupled_ca_weight, + ) + # ========================= FFN ========================= + self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6) + self.use_moe = use_moe + if self.use_moe: + print("using moe") + self.moe = MoEBlock( + hidden_size, + num_experts=num_experts, + moe_top_k=moe_top_k, + dropout=0.0, + activation_fn="gelu", + final_dropout=False, + ff_inner_dim=int(hidden_size * 4.0), + ff_bias=True, + ) + else: + self.mlp = MLP(width=hidden_size) + # ========================= skip FFN ========================= + if skip_connection: + self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6) + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) + else: + self.skip_linear = None + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + skip_value: torch.Tensor = None, + ): + # skip connection + if self.skip_linear is not None: + cat = torch.cat([skip_value, hidden_states], dim=-1) + hidden_states = self.skip_linear(cat) + hidden_states = self.skip_norm(hidden_states) + # local global attn (self-attn) + if self.timested_modulate: + shift_msa = self.default_modulation(temb).unsqueeze(dim=1) + hidden_states = hidden_states + shift_msa + if self.use_self_attention: + attn_output = self.attn1(self.norm1(hidden_states)) + hidden_states = hidden_states + attn_output + # image cross attn + if self.use_cross_attention: + original_cross_out = self.attn2( + self.norm2(hidden_states), + encoder_hidden_states, + ) + # added local-global cross attn + # 2. 
Cross-Attention + if self.use_cross_attention_2: + cross_out_2 = self.attn2_2( + self.norm2_2(hidden_states), + encoder_hidden_states_2, + ) + hidden_states = ( + hidden_states + + (original_cross_out if self.use_cross_attention else 0) + + (cross_out_2 if self.use_cross_attention_2 else 0) + ) + + # FFN Layer + mlp_inputs = self.norm3(hidden_states) + + if self.use_moe: + hidden_states = hidden_states + self.moe(mlp_inputs) + else: + hidden_states = hidden_states + self.mlp(mlp_inputs) + + return hidden_states + + +class FinalLayer(nn.Module): + """ + The final layer of HunYuanDiT. + """ + + def __init__(self, final_hidden_size, out_channels): + super().__init__() + self.final_hidden_size = final_hidden_size + self.norm_final = nn.LayerNorm( + final_hidden_size, elementwise_affine=True, eps=1e-6 + ) + self.linear = nn.Linear(final_hidden_size, out_channels, bias=True) + + def forward(self, x): + x = self.norm_final(x) + x = x[:, 1:] + x = self.linear(x) + return x + + +class PartFormerDITPlain(nn.Module): + + def __init__( + self, + input_size=1024, + in_channels=4, + hidden_size=1024, + use_self_attention=True, + use_cross_attention=True, + use_cross_attention_2=True, + encoder_hidden_dim=1024, # cross-attn encoder_hidden_states dim + encoder_hidden2_dim=1024, # cross-attn 2 encoder_hidden_states dim + depth=24, + num_heads=16, + qk_norm=False, + qkv_bias=True, + norm_type="layer", + qk_norm_type="rms", + with_decoupled_ca=False, + decoupled_ca_dim=16, + decoupled_ca_weight=1.0, + use_pos_emb=False, + # use_attention_pooling=True, + guidance_cond_proj_dim=None, + num_moe_layers: int = 6, + num_experts: int = 8, + moe_top_k: int = 2, + **kwargs, + ): + super().__init__() + + self.input_size = input_size + self.depth = depth + self.in_channels = in_channels + self.out_channels = in_channels + self.num_heads = num_heads + + self.hidden_size = hidden_size + self.norm = nn.LayerNorm if norm_type == "layer" else nn.RMSNorm + self.qk_norm = nn.RMSNorm if qk_norm_type == "rms" else nn.LayerNorm + # embedding + self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder( + hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim + ) + # Will use fixed sin-cos embedding: + self.use_pos_emb = use_pos_emb + if self.use_pos_emb: + self.register_buffer("pos_embed", torch.zeros(1, input_size, hidden_size)) + pos = np.arange(self.input_size, dtype=np.float32) + pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], pos) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # self.use_attention_pooling = use_attention_pooling + # if use_attention_pooling: + + # self.pooler = AttentionPool( + # self.text_len, encoder_hidden_dim, num_heads=8, output_dim=1024 + # ) + # self.extra_embedder = nn.Sequential( + # nn.Linear(1024, hidden_size * 4), + # nn.SiLU(), + # nn.Linear(hidden_size * 4, hidden_size, bias=True), + # ) + # for part embedding + self.use_bbox_cond = kwargs.get("use_bbox_cond", False) + if self.use_bbox_cond: + self.bbox_conditioner = BboxEmbedder( + out_size=hidden_size, + num_freqs=kwargs.get("num_freqs", 8), + ) + self.use_part_embed = kwargs.get("use_part_embed", False) + if self.use_part_embed: + self.valid_num = kwargs.get("valid_num", 50) + self.part_embed = nn.Parameter(torch.randn(self.valid_num, hidden_size)) + # zero init part_embed + self.part_embed.data.zero_() + # transformer blocks + self.blocks = nn.ModuleList([ + PartFormerDitBlock( + hidden_size, + num_heads, + 
use_self_attention=use_self_attention, + use_cross_attention=use_cross_attention, + use_cross_attention_2=use_cross_attention_2, + encoder_hidden_dim=encoder_hidden_dim, # cross-attn encoder_hidden_states dim + encoder_hidden2_dim=encoder_hidden2_dim, # cross-attn 2 encoder_hidden_states dim + # cross_attn2_weight=cross_attn2_weight, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + norm_layer=self.norm, + qk_norm_layer=self.qk_norm, + with_decoupled_ca=with_decoupled_ca, + decoupled_ca_dim=decoupled_ca_dim, + decoupled_ca_weight=decoupled_ca_weight, + skip_connection=layer > depth // 2, + use_moe=True if depth - layer <= num_moe_layers else False, + num_experts=num_experts, + moe_top_k=moe_top_k, + ) + for layer in range(depth) + ]) + # set local-global processor + for layer, block in enumerate(self.blocks): + if hasattr(block, "attn1") and (layer + 1) % 2 == 0: + block.attn1.processor = LocalGlobalProcessor(use_global=True) + + self.depth = depth + + self.final_layer = FinalLayer(hidden_size, self.out_channels) + + def forward(self, x, t, contexts: dict, **kwargs): + """ + + x: [B, N, C] + t: [B] + contexts: dict + image_context: [B, K*ni, C] + geo_context: [B, K*ng, C] or [B, K*ng, C*2] + aabb: [B, K, 2, 3] + num_tokens: [B, N] + + N = K * num_tokens + + For parts pretrain : K = 1 + """ + # prepare input + aabb: torch.Tensor = kwargs.get("aabb", None) + # image_context = contexts.get("image_un_cond", None) + object_context = contexts.get("obj_cond", None) + geo_context = contexts.get("geo_cond", None) + num_tokens: torch.Tensor = kwargs.get("num_tokens", None) + # timeembedding and input projection + t = self.t_embedder(t, condition=kwargs.get("guidance_cond")) + x = self.x_embedder(x) + + if self.use_pos_emb: + pos_embed = self.pos_embed.to(x.dtype) + x = x + pos_embed + + # c is time embedding (adding pooling context or not) + # if self.use_attention_pooling: + # # TODO: attention_pooling for all contexts + # extra_vec = self.pooler(image_context, None) + # c = t + self.extra_embedder(extra_vec) # [B, D] + # else: + # c = t + c = t + # bounding box + if self.use_bbox_cond: + center_extent = torch.cat( + [torch.mean(aabb, dim=-2), aabb[..., 1, :] - aabb[..., 0, :]], dim=-1 + ) + bbox_embeds = self.bbox_conditioner(center_extent) + # TODO: now only support batch_size=1 + bbox_embeds = torch.repeat_interleave( + bbox_embeds, repeats=num_tokens[0], dim=1 + ) + x = x + bbox_embeds + # part id embedding + if self.use_part_embed: + num_parts = aabb.shape[1] + random_idx = torch.randperm(self.valid_num)[:num_parts] + part_embeds = self.part_embed[random_idx].unsqueeze(1) + # import pdb + + # pdb.set_trace() + x = x + part_embeds + x = torch.cat([c, x], dim=1) + skip_value_list = [] + for layer, block in enumerate(self.blocks): + skip_value = None if layer <= self.depth // 2 else skip_value_list.pop() + x = block( + hidden_states=x, + # encoder_hidden_states=image_context, + encoder_hidden_states=object_context, + encoder_hidden_states_2=geo_context, + temb=c, + skip_value=skip_value, + ) + if layer < self.depth // 2: + skip_value_list.append(x) + + x = self.final_layer(x) + return x diff --git a/XPart/partgen/models/sonata/__init__.py b/XPart/partgen/models/sonata/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..ba273918b811a926b395b4cac40463e22b2c1081 --- /dev/null +++ b/XPart/partgen/models/sonata/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
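PartFormerDITPlain.forward above wires skips U-ViT style: blocks in the first half push their outputs onto a stack, blocks strictly past the midpoint pop them, so late blocks are paired LIFO with early ones (the very first block's output is pushed but never consumed). A standalone sketch of that bookkeeping with a toy depth (illustration only, not part of the patch):

depth = 6
skip_stack, pairing = [], {}
for layer in range(depth):
    skip = None if layer <= depth // 2 else skip_stack.pop()   # same condition as in forward()
    pairing[layer] = skip
    if layer < depth // 2:
        skip_stack.append(layer)                               # stand-in for this block's output tensor
print(pairing)        # {0: None, 1: None, 2: None, 3: None, 4: 2, 5: 1}
print(skip_stack)     # [0] -- block 0's output is never popped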
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .model import load, load_by_config + +from . import model +from . import module +from . import structure +from . import data +from . import transform +from . import utils +from . import registry + +__all__ = [ + "load", + "load_by_config", + "model", + "module", + "structure", + "transform", + "registry", + "utils", +] diff --git a/XPart/partgen/models/sonata/data.py b/XPart/partgen/models/sonata/data.py new file mode 100755 index 0000000000000000000000000000000000000000..43ad271965d99f3ae5fa9453feddb7bfb6ae3356 --- /dev/null +++ b/XPart/partgen/models/sonata/data.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import numpy as np +import torch +from collections.abc import Mapping, Sequence +from huggingface_hub import hf_hub_download + + +DATAS = ["sample1", "sample1_high_res", "sample1_dino"] + + +def load( + name: str = "sonata", + download_root: str = None, +): + if name in DATAS: + print(f"Loading data from HuggingFace: {name} ...") + data_path = hf_hub_download( + repo_id="pointcept/demo", + filename=f"{name}.npz", + repo_type="dataset", + revision="main", + local_dir=download_root or os.path.expanduser("~/.cache/sonata/data"), + ) + elif os.path.isfile(name): + print(f"Loading data in local path: {name} ...") + data_path = name + else: + raise RuntimeError(f"Data {name} not found; available models = {DATAS}") + return dict(np.load(data_path)) + + +from torch.utils.data.dataloader import default_collate + + +def collate_fn(batch): + """ + collate function for point cloud which support dict and list, + 'coord' is necessary to determine 'offset' + """ + if not isinstance(batch, Sequence): + raise TypeError(f"{batch.dtype} is not supported.") + + if isinstance(batch[0], torch.Tensor): + return torch.cat(list(batch)) + elif isinstance(batch[0], str): + # str is also a kind of Sequence, judgement should before Sequence + return list(batch) + elif isinstance(batch[0], Sequence): + for data in batch: + data.append(torch.tensor([data[0].shape[0]])) + batch = [collate_fn(samples) for samples in zip(*batch)] + batch[-1] = torch.cumsum(batch[-1], dim=0).int() + return batch + elif isinstance(batch[0], Mapping): + batch = { + key: ( + collate_fn([d[key] for d in batch]) + if "offset" not in key + # offset -> bincount -> concat bincount-> concat offset + else torch.cumsum( + collate_fn([d[key].diff(prepend=torch.tensor([0])) for d in batch]), + dim=0, + ) + ) 
+ for key in batch[0] + } + return batch + else: + return default_collate(batch) diff --git a/XPart/partgen/models/sonata/model.py b/XPart/partgen/models/sonata/model.py new file mode 100755 index 0000000000000000000000000000000000000000..9219844f2ab40fdfbb0829bf5687867d9a858b36 --- /dev/null +++ b/XPart/partgen/models/sonata/model.py @@ -0,0 +1,874 @@ +""" +Point Transformer - V3 Mode2 - Sonata +Pointcept detached version + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from packaging import version +from huggingface_hub import hf_hub_download, PyTorchModelHubMixin +from addict import Dict +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +import spconv.pytorch as spconv +import torch_scatter +from timm.layers import DropPath +import json + +try: + import flash_attn +except ImportError: + flash_attn = None + +from .structure import Point +from .module import PointSequential, PointModule +from .utils import offset2bincount + +MODELS = [ + "sonata", + "sonata_small", + "sonata_linear_prob_head_sc", +] + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class RPE(torch.nn.Module): + def __init__(self, patch_size, num_heads): + super().__init__() + self.patch_size = patch_size + self.num_heads = num_heads + self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2) + self.rpe_num = 2 * self.pos_bnd + 1 + self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads)) + torch.nn.init.trunc_normal_(self.rpe_table, std=0.02) + + def forward(self, coord): + idx = ( + coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd + + self.pos_bnd # relative position to positive index + + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride + ) + out = self.rpe_table.index_select(0, idx.reshape(-1)) + out = out.view(idx.shape + (-1,)).sum(3) + out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K) + return out + + +class SerializedAttention(PointModule): + def __init__( + self, + channels, + num_heads, + patch_size, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + order_index=0, + enable_rpe=False, + enable_flash=True, + upcast_attention=True, + upcast_softmax=True, + ): + super().__init__() + assert channels % num_heads == 0 + self.channels = channels + self.num_heads = num_heads + self.scale = qk_scale or (channels // num_heads) ** -0.5 + self.order_index = order_index + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.enable_rpe = enable_rpe + self.enable_flash = enable_flash + if enable_flash: + assert ( + enable_rpe is False + ), "Set enable_rpe to False when enable Flash Attention" + assert ( + upcast_attention is False + ), "Set upcast_attention to False when enable Flash Attention" + assert ( + upcast_softmax is False + ), "Set upcast_softmax to False when enable Flash Attention" + assert flash_attn is not None, "Make sure flash_attn is installed." 
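RPE.forward above turns per-axis relative grid offsets into rows of a single (3 * rpe_num, num_heads) table: clamp to [-pos_bnd, pos_bnd], shift to non-negative, then stride the x/y/z axes into disjoint index ranges. A standalone sketch of the index arithmetic (illustration only, not part of the patch):

import torch

patch_size = 48
pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)      # 11 for patch_size=48
rpe_num = 2 * pos_bnd + 1                           # 23 bins per axis

coord = torch.randint(-20, 20, (1, 4, 4, 3))        # relative (x, y, z) offsets within a patch
idx = (
    coord.clamp(-pos_bnd, pos_bnd)                  # clip long-range offsets
    + pos_bnd                                       # shift to 0 .. 2*pos_bnd
    + torch.arange(3) * rpe_num                     # x, y, z use separate slices of the table
)
assert idx.min() >= 0 and idx.max() < 3 * rpe_num   # valid row indices into rpe_table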
+ self.patch_size = patch_size + self.attn_drop = attn_drop + else: + # when disable flash attention, we still don't want to use mask + # consequently, patch size will auto set to the + # min number of patch_size_max and number of points + self.patch_size_max = patch_size + self.patch_size = 0 + self.attn_drop = torch.nn.Dropout(attn_drop) + + self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias) + self.proj = torch.nn.Linear(channels, channels) + self.proj_drop = torch.nn.Dropout(proj_drop) + self.softmax = torch.nn.Softmax(dim=-1) + self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None + + @torch.no_grad() + def get_rel_pos(self, point, order): + K = self.patch_size + rel_pos_key = f"rel_pos_{self.order_index}" + if rel_pos_key not in point.keys(): + grid_coord = point.grid_coord[order] + grid_coord = grid_coord.reshape(-1, K, 3) + point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1) + return point[rel_pos_key] + + @torch.no_grad() + def get_padding_and_inverse(self, point): + pad_key = "pad" + unpad_key = "unpad" + cu_seqlens_key = "cu_seqlens_key" + if ( + pad_key not in point.keys() + or unpad_key not in point.keys() + or cu_seqlens_key not in point.keys() + ): + offset = point.offset + bincount = offset2bincount(offset) + bincount_pad = ( + torch.div( + bincount + self.patch_size - 1, + self.patch_size, + rounding_mode="trunc", + ) + * self.patch_size + ) + # only pad point when num of points larger than patch_size + mask_pad = bincount > self.patch_size + bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad + _offset = nn.functional.pad(offset, (1, 0)) + _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0)) + pad = torch.arange(_offset_pad[-1], device=offset.device) + unpad = torch.arange(_offset[-1], device=offset.device) + cu_seqlens = [] + for i in range(len(offset)): + unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i] + if bincount[i] != bincount_pad[i]: + pad[ + _offset_pad[i + 1] + - self.patch_size + + (bincount[i] % self.patch_size) : _offset_pad[i + 1] + ] = pad[ + _offset_pad[i + 1] + - 2 * self.patch_size + + (bincount[i] % self.patch_size) : _offset_pad[i + 1] + - self.patch_size + ] + pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i] + cu_seqlens.append( + torch.arange( + _offset_pad[i], + _offset_pad[i + 1], + step=self.patch_size, + dtype=torch.int32, + device=offset.device, + ) + ) + point[pad_key] = pad + point[unpad_key] = unpad + point[cu_seqlens_key] = nn.functional.pad( + torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1] + ) + return point[pad_key], point[unpad_key], point[cu_seqlens_key] + + def forward(self, point): + if not self.enable_flash: + self.patch_size = min( + offset2bincount(point.offset).min().tolist(), self.patch_size_max + ) + + H = self.num_heads + K = self.patch_size + C = self.channels + + pad, unpad, cu_seqlens = self.get_padding_and_inverse(point) + + order = point.serialized_order[self.order_index][pad] + inverse = unpad[point.serialized_inverse[self.order_index]] + + # padding and reshape feat and batch for serialized point patch + qkv = self.qkv(point.feat)[order] + + if not self.enable_flash: + # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C') + q, k, v = ( + qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0) + ) + # attn + if self.upcast_attention: + q = q.float() + k = k.float() + attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K) + if self.enable_rpe: + attn = attn + 
self.rpe(self.get_rel_pos(point, order)) + if self.upcast_softmax: + attn = attn.float() + attn = self.softmax(attn) + attn = self.attn_drop(attn).to(qkv.dtype) + feat = (attn @ v).transpose(1, 2).reshape(-1, C) + else: + feat = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv.half().reshape(-1, 3, H, C // H), + cu_seqlens, + max_seqlen=self.patch_size, + dropout_p=self.attn_drop if self.training else 0, + softmax_scale=self.scale, + ).reshape(-1, C) + feat = feat.to(qkv.dtype) + feat = feat[inverse] + + # ffn + feat = self.proj(feat) + feat = self.proj_drop(feat) + point.feat = feat + return point + + +class MLP(nn.Module): + def __init__( + self, + in_channels, + hidden_channels=None, + out_channels=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_channels, out_channels) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Block(PointModule): + def __init__( + self, + channels, + num_heads, + patch_size=48, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.0, + layer_scale=None, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + pre_norm=True, + order_index=0, + cpe_indice_key=None, + enable_rpe=False, + enable_flash=True, + upcast_attention=True, + upcast_softmax=True, + ): + super().__init__() + self.channels = channels + self.pre_norm = pre_norm + + self.cpe = PointSequential( + spconv.SubMConv3d( + channels, + channels, + kernel_size=3, + bias=True, + indice_key=cpe_indice_key, + ), + nn.Linear(channels, channels), + norm_layer(channels), + ) + + self.norm1 = PointSequential(norm_layer(channels)) + self.ls1 = PointSequential( + LayerScale(channels, init_values=layer_scale) + if layer_scale is not None + else nn.Identity() + ) + self.attn = SerializedAttention( + channels=channels, + patch_size=patch_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + order_index=order_index, + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + ) + self.norm2 = PointSequential(norm_layer(channels)) + self.ls2 = PointSequential( + LayerScale(channels, init_values=layer_scale) + if layer_scale is not None + else nn.Identity() + ) + self.mlp = PointSequential( + MLP( + in_channels=channels, + hidden_channels=int(channels * mlp_ratio), + out_channels=channels, + act_layer=act_layer, + drop=proj_drop, + ) + ) + self.drop_path = PointSequential( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + + def forward(self, point: Point): + shortcut = point.feat + point = self.cpe(point) + point.feat = shortcut + point.feat + shortcut = point.feat + if self.pre_norm: + point = self.norm1(point) + point = self.drop_path(self.ls1(self.attn(point))) + point.feat = shortcut + point.feat + if not self.pre_norm: + point = self.norm1(point) + + shortcut = point.feat + if self.pre_norm: + point = self.norm2(point) + point = self.drop_path(self.ls2(self.mlp(point))) + point.feat = shortcut + point.feat + if not self.pre_norm: + point = self.norm2(point) + point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat) + return point + + +class GridPooling(PointModule): + def 
__init__( + self, + in_channels, + out_channels, + stride=2, + norm_layer=None, + act_layer=None, + reduce="max", + shuffle_orders=True, + traceable=True, # record parent and cluster + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.stride = stride + assert reduce in ["sum", "mean", "min", "max"] + self.reduce = reduce + self.shuffle_orders = shuffle_orders + self.traceable = traceable + + self.proj = nn.Linear(in_channels, out_channels) + if norm_layer is not None: + self.norm = PointSequential(norm_layer(out_channels)) + if act_layer is not None: + self.act = PointSequential(act_layer()) + + def forward(self, point: Point): + if "grid_coord" in point.keys(): + grid_coord = point.grid_coord + elif {"coord", "grid_size"}.issubset(point.keys()): + grid_coord = torch.div( + point.coord - point.coord.min(0)[0], + point.grid_size, + rounding_mode="trunc", + ).int() + else: + raise AssertionError( + "[gird_coord] or [coord, grid_size] should be include in the Point" + ) + grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc") + grid_coord = grid_coord | point.batch.view(-1, 1) << 48 + grid_coord, cluster, counts = torch.unique( + grid_coord, + sorted=True, + return_inverse=True, + return_counts=True, + dim=0, + ) + grid_coord = grid_coord & ((1 << 48) - 1) + # indices of point sorted by cluster, for torch_scatter.segment_csr + _, indices = torch.sort(cluster) + # index pointer for sorted point, for torch_scatter.segment_csr + idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) + # head_indices of each cluster, for reduce attr e.g. code, batch + head_indices = indices[idx_ptr[:-1]] + point_dict = Dict( + feat=torch_scatter.segment_csr( + self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce + ), + coord=torch_scatter.segment_csr( + point.coord[indices], idx_ptr, reduce="mean" + ), + grid_coord=grid_coord, + batch=point.batch[head_indices], + ) + if "origin_coord" in point.keys(): + point_dict["origin_coord"] = torch_scatter.segment_csr( + point.origin_coord[indices], idx_ptr, reduce="mean" + ) + if "condition" in point.keys(): + point_dict["condition"] = point.condition + if "context" in point.keys(): + point_dict["context"] = point.context + if "name" in point.keys(): + point_dict["name"] = point.name + if "split" in point.keys(): + point_dict["split"] = point.split + if "color" in point.keys(): + point_dict["color"] = torch_scatter.segment_csr( + point.color[indices], idx_ptr, reduce="mean" + ) + if "grid_size" in point.keys(): + point_dict["grid_size"] = point.grid_size * self.stride + + if self.traceable: + point_dict["pooling_inverse"] = cluster + point_dict["pooling_parent"] = point + order = point.order + point = Point(point_dict) + if self.norm is not None: + point = self.norm(point) + if self.act is not None: + point = self.act(point) + point.serialization(order=order, shuffle_orders=self.shuffle_orders) + point.sparsify() + return point + + +class GridUnpooling(PointModule): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + norm_layer=None, + act_layer=None, + traceable=False, # record parent and cluster + ): + super().__init__() + self.proj = PointSequential(nn.Linear(in_channels, out_channels)) + self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels)) + + if norm_layer is not None: + self.proj.add(norm_layer(out_channels)) + self.proj_skip.add(norm_layer(out_channels)) + + if act_layer is not None: + self.proj.add(act_layer()) + 
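GridPooling.forward above pools points into voxels: quantize coordinates to a coarser grid, find unique voxel keys with torch.unique, then reduce the features of each voxel with torch_scatter.segment_csr. A toy 2-D sketch of that pattern (assumes torch_scatter is installed; illustration only, not part of the patch):

import torch
import torch_scatter

coord = torch.tensor([[0, 0], [1, 0], [4, 4], [5, 5]])    # 2-D toy coordinates
feat = torch.arange(4, dtype=torch.float).view(-1, 1)

grid = torch.div(coord, 4, rounding_mode="trunc")          # stride-4 voxel index per point
_, cluster, counts = torch.unique(grid, return_inverse=True, return_counts=True, dim=0)
_, indices = torch.sort(cluster)                           # group points of the same voxel together
idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
pooled = torch_scatter.segment_csr(feat[indices], idx_ptr, reduce="max")
print(pooled)   # one feature per occupied voxel: [[1.], [3.]]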
self.proj_skip.add(act_layer()) + + self.traceable = traceable + + def forward(self, point): + assert "pooling_parent" in point.keys() + assert "pooling_inverse" in point.keys() + parent = point.pop("pooling_parent") + inverse = point.pooling_inverse + feat = point.feat + + parent = self.proj_skip(parent) + parent.feat = parent.feat + self.proj(point).feat[inverse] + parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat) + + if self.traceable: + point.feat = feat + parent["unpooling_parent"] = point + return parent + + +class Embedding(PointModule): + def __init__( + self, + in_channels, + embed_channels, + norm_layer=None, + act_layer=None, + mask_token=False, + ): + super().__init__() + self.in_channels = in_channels + self.embed_channels = embed_channels + + self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels)) + if norm_layer is not None: + self.stem.add(norm_layer(embed_channels), name="norm") + if act_layer is not None: + self.stem.add(act_layer(), name="act") + + if mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, embed_channels)) + else: + self.mask_token = None + + def forward(self, point: Point): + point = self.stem(point) + if "mask" in point.keys(): + point.feat = torch.where( + point.mask.unsqueeze(-1), + self.mask_token.to(point.feat.dtype), + point.feat, + ) + return point + + +class PointTransformerV3(PointModule, PyTorchModelHubMixin): + def __init__( + self, + in_channels=6, + order=("z", "z-trans"), + stride=(2, 2, 2, 2), + enc_depths=(3, 3, 3, 12, 3), + enc_channels=(48, 96, 192, 384, 512), + enc_num_head=(3, 6, 12, 24, 32), + enc_patch_size=(1024, 1024, 1024, 1024, 1024), + dec_depths=(3, 3, 3, 3), + dec_channels=(96, 96, 192, 384), + dec_num_head=(6, 6, 12, 32), + dec_patch_size=(1024, 1024, 1024, 1024), + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + drop_path=0.3, + layer_scale=None, + pre_norm=True, + shuffle_orders=True, + enable_rpe=False, + enable_flash=True, + upcast_attention=False, + upcast_softmax=False, + traceable=False, + mask_token=False, + enc_mode=False, + freeze_encoder=False, + ): + super().__init__() + self.num_stages = len(enc_depths) + self.order = [order] if isinstance(order, str) else order + self.enc_mode = enc_mode + self.shuffle_orders = shuffle_orders + self.freeze_encoder = freeze_encoder + + assert self.num_stages == len(stride) + 1 + assert self.num_stages == len(enc_depths) + assert self.num_stages == len(enc_channels) + assert self.num_stages == len(enc_num_head) + assert self.num_stages == len(enc_patch_size) + assert self.enc_mode or self.num_stages == len(dec_depths) + 1 + assert self.enc_mode or self.num_stages == len(dec_channels) + 1 + assert self.enc_mode or self.num_stages == len(dec_num_head) + 1 + assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1 + + print(f"flash attention: {enable_flash}") + + # normalization layer + ln_layer = nn.LayerNorm + # activation layers + act_layer = nn.GELU + + self.embedding = Embedding( + in_channels=in_channels, + embed_channels=enc_channels[0], + norm_layer=ln_layer, + act_layer=act_layer, + mask_token=mask_token, + ) + + # encoder + enc_drop_path = [ + x.item() for x in torch.linspace(0, drop_path, sum(enc_depths)) + ] + self.enc = PointSequential() + for s in range(self.num_stages): + enc_drop_path_ = enc_drop_path[ + sum(enc_depths[:s]) : sum(enc_depths[: s + 1]) + ] + enc = PointSequential() + if s > 0: + enc.add( + GridPooling( + in_channels=enc_channels[s - 1], + 
out_channels=enc_channels[s], + stride=stride[s - 1], + norm_layer=ln_layer, + act_layer=act_layer, + ), + name="down", + ) + for i in range(enc_depths[s]): + enc.add( + Block( + channels=enc_channels[s], + num_heads=enc_num_head[s], + patch_size=enc_patch_size[s], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + drop_path=enc_drop_path_[i], + layer_scale=layer_scale, + norm_layer=ln_layer, + act_layer=act_layer, + pre_norm=pre_norm, + order_index=i % len(self.order), + cpe_indice_key=f"stage{s}", + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + ), + name=f"block{i}", + ) + if len(enc) != 0: + self.enc.add(module=enc, name=f"enc{s}") + + # decoder + if not self.enc_mode: + dec_drop_path = [ + x.item() for x in torch.linspace(0, drop_path, sum(dec_depths)) + ] + self.dec = PointSequential() + dec_channels = list(dec_channels) + [enc_channels[-1]] + for s in reversed(range(self.num_stages - 1)): + dec_drop_path_ = dec_drop_path[ + sum(dec_depths[:s]) : sum(dec_depths[: s + 1]) + ] + dec_drop_path_.reverse() + dec = PointSequential() + dec.add( + GridUnpooling( + in_channels=dec_channels[s + 1], + skip_channels=enc_channels[s], + out_channels=dec_channels[s], + norm_layer=ln_layer, + act_layer=act_layer, + traceable=traceable, + ), + name="up", + ) + for i in range(dec_depths[s]): + dec.add( + Block( + channels=dec_channels[s], + num_heads=dec_num_head[s], + patch_size=dec_patch_size[s], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + drop_path=dec_drop_path_[i], + layer_scale=layer_scale, + norm_layer=ln_layer, + act_layer=act_layer, + pre_norm=pre_norm, + order_index=i % len(self.order), + cpe_indice_key=f"stage{s}", + enable_rpe=enable_rpe, + enable_flash=enable_flash, + upcast_attention=upcast_attention, + upcast_softmax=upcast_softmax, + ), + name=f"block{i}", + ) + self.dec.add(module=dec, name=f"dec{s}") + if self.freeze_encoder: + for p in self.embedding.parameters(): + p.requires_grad = False + for p in self.enc.parameters(): + p.requires_grad = False + self.apply(self._init_weights) + + @staticmethod + def _init_weights(module): + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, spconv.SubMConv3d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, data_dict): + point = Point(data_dict) + point = self.embedding(point) + + point.serialization(order=self.order, shuffle_orders=self.shuffle_orders) + point.sparsify() + + point = self.enc(point) + if not self.enc_mode: + point = self.dec(point) + return point + + +def load( + name: str = "sonata", + repo_id="facebook/sonata", + download_root: str = None, + custom_config: dict = None, + ckpt_only: bool = False, +): + if name in MODELS: + print(f"Loading checkpoint from HuggingFace: {name} ...") + ckpt_path = hf_hub_download( + repo_id=repo_id, + filename=f"{name}.pth", + repo_type="model", + revision="main", + local_dir=download_root or os.path.expanduser("~/.cache/sonata/ckpt"), + ) + elif os.path.isfile(name): + print(f"Loading checkpoint in local path: {name} ...") + ckpt_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {MODELS}") + + if version.parse(torch.__version__) >= version.parse("2.4"): + ckpt = 
torch.load(ckpt_path, map_location="cpu", weights_only=True) + else: + ckpt = torch.load(ckpt_path, map_location="cpu") + if custom_config is not None: + for key, value in custom_config.items(): + ckpt["config"][key] = value + + if ckpt_only: + return ckpt + + # 关闭flash attention + # ckpt["config"]['enable_flash'] = False + + model = PointTransformerV3(**ckpt["config"]) + model.load_state_dict(ckpt["state_dict"]) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}") + return model + + +def load_by_config(config_path: str): + with open(config_path, "r") as f: + config = json.load(f) + model = PointTransformerV3(**config) + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}") + return model diff --git a/XPart/partgen/models/sonata/module.py b/XPart/partgen/models/sonata/module.py new file mode 100755 index 0000000000000000000000000000000000000000..7fe25c28c82d82276c15a9dffcc127512c108e26 --- /dev/null +++ b/XPart/partgen/models/sonata/module.py @@ -0,0 +1,107 @@ +""" +Point Modules +Pointcept detached version + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import sys +import torch.nn as nn +import spconv.pytorch as spconv +from collections import OrderedDict + +from .structure import Point + + +class PointModule(nn.Module): + r"""PointModule + placeholder, all module subclass from this will take Point in PointSequential. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class PointSequential(PointModule): + r"""A sequential container. + Modules will be added to it in the order they are passed in the constructor. + Alternatively, an ordered dict of modules can also be passed in. 
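+
+    A minimal usage sketch (illustrative; channel sizes are assumptions): plain ``nn``
+    modules are applied to ``point.feat``, while ``PointModule`` and spconv modules
+    receive the ``Point`` / sparse tensor directly, as handled in ``forward`` below::
+
+        stem = PointSequential(nn.Linear(6, 48), nn.LayerNorm(48), nn.GELU())
+        point = stem(point)  # illustrative sizes; point.feat is mapped 6 -> 48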
+ """ + + def __init__(self, *args, **kwargs): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + for name, module in kwargs.items(): + if sys.version_info < (3, 6): + raise ValueError("kwargs only supported in py36+") + if name in self._modules: + raise ValueError("name exists.") + self.add_module(name, module) + + def __getitem__(self, idx): + if not (-len(self) <= idx < len(self)): + raise IndexError("index {} is out of range".format(idx)) + if idx < 0: + idx += len(self) + it = iter(self._modules.values()) + for i in range(idx): + next(it) + return next(it) + + def __len__(self): + return len(self._modules) + + def add(self, module, name=None): + if name is None: + name = str(len(self._modules)) + if name in self._modules: + raise KeyError("name exists") + self.add_module(name, module) + + def forward(self, input): + for k, module in self._modules.items(): + # Point module + if isinstance(module, PointModule): + input = module(input) + # Spconv module + elif spconv.modules.is_spconv_module(module): + if isinstance(input, Point): + input.sparse_conv_feat = module(input.sparse_conv_feat) + input.feat = input.sparse_conv_feat.features + else: + input = module(input) + # PyTorch module + else: + if isinstance(input, Point): + input.feat = module(input.feat) + if "sparse_conv_feat" in input.keys(): + input.sparse_conv_feat = input.sparse_conv_feat.replace_feature( + input.feat + ) + elif isinstance(input, spconv.SparseConvTensor): + if input.indices.shape[0] != 0: + input = input.replace_feature(module(input.features)) + else: + input = module(input) + return input diff --git a/XPart/partgen/models/sonata/registry.py b/XPart/partgen/models/sonata/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..f91d1b804bdc2f919abdc7df94da6f3b5db804fb --- /dev/null +++ b/XPart/partgen/models/sonata/registry.py @@ -0,0 +1,340 @@ +# @lint-ignore-every LICENSELINT +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import warnings +from functools import partial +from collections import abc + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def build_from_cfg(cfg, registry, default_args=None): + """Build a module from configs dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. 
+ """ + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be a dict, but got {type(cfg)}") + if "type" not in cfg: + if default_args is None or "type" not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f"but got {cfg}\n{default_args}" + ) + if not isinstance(registry, Registry): + raise TypeError( + "registry must be an mmcv.Registry object, " f"but got {type(registry)}" + ) + if not (isinstance(default_args, dict) or default_args is None): + raise TypeError( + "default_args must be a dict or None, " f"but got {type(default_args)}" + ) + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop("type") + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError(f"{obj_type} is not in the {registry.name} registry") + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f"{obj_cls.__name__}: {e}") + + +class Registry: + """A registry to map strings to classes. + + Registered object could be built from registry. + Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = ( + self.__class__.__name__ + f"(name={self._name}, " + f"items={self._module_dict})" + ) + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. 
+ + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split(".") + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find(".") + if split_index != -1: + return key[:split_index], key[split_index + 1 :] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert ( + registry.scope not in self.children + ), f"scope {registry.scope} exists in {self.name} registry" + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError("module must be a class, " f"but got {type(module_class)}") + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f"{name} is already registered " f"in {self.name}") + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + "The old API of register_module(module, force=False) " + "is deprecated and will be removed, please use the new API " + "register_module(name=None, force=False, module=None) instead." 
+ ) + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. + if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + "name must be either of None, an instance of str or a sequence" + f" of str, but got {type(name)}" + ) + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module(module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/XPart/partgen/models/sonata/serialization/__init__.py b/XPart/partgen/models/sonata/serialization/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..46e6c72e9657bb6ccdfd90f3387973ee9b658eb7 --- /dev/null +++ b/XPart/partgen/models/sonata/serialization/__init__.py @@ -0,0 +1,9 @@ +#init.py +from .default import ( + encode, + decode, + z_order_encode, + z_order_decode, + hilbert_encode, + hilbert_decode, +) diff --git a/XPart/partgen/models/sonata/serialization/default.py b/XPart/partgen/models/sonata/serialization/default.py new file mode 100755 index 0000000000000000000000000000000000000000..73346305d321770706796f7bec302a10c038540d --- /dev/null +++ b/XPart/partgen/models/sonata/serialization/default.py @@ -0,0 +1,82 @@ +""" +Serialization Encoding +Pointcept detached version + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
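+
+# A minimal usage sketch (illustrative; variable names are assumptions): given integer
+# grid coordinates of shape (N, 3) and a per-point batch index, `encode` packs the
+# batch id above a space-filling-curve code in a single int64 key, and callers
+# typically argsort the keys to obtain a serialized point order:
+#
+#   code = encode(grid_coord, batch, depth=16, order="hilbert")
+#   order = torch.argsort(code)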
+ + +import torch +from .z_order import xyz2key as z_order_encode_ +from .z_order import key2xyz as z_order_decode_ +from .hilbert import encode as hilbert_encode_ +from .hilbert import decode as hilbert_decode_ + + +@torch.inference_mode() +def encode(grid_coord, batch=None, depth=16, order="z"): + assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} + if order == "z": + code = z_order_encode(grid_coord, depth=depth) + elif order == "z-trans": + code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) + elif order == "hilbert": + code = hilbert_encode(grid_coord, depth=depth) + elif order == "hilbert-trans": + code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) + else: + raise NotImplementedError + if batch is not None: + batch = batch.long() + code = batch << depth * 3 | code + return code + + +@torch.inference_mode() +def decode(code, depth=16, order="z"): + assert order in {"z", "hilbert"} + batch = code >> depth * 3 + code = code & ((1 << depth * 3) - 1) + if order == "z": + grid_coord = z_order_decode(code, depth=depth) + elif order == "hilbert": + grid_coord = hilbert_decode(code, depth=depth) + else: + raise NotImplementedError + return grid_coord, batch + + +def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): + x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() + # we block the support to batch, maintain batched code in Point class + code = z_order_encode_(x, y, z, b=None, depth=depth) + return code + + +def z_order_decode(code: torch.Tensor, depth): + x, y, z = z_order_decode_(code, depth=depth) + grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) + return grid_coord + + +def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): + return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) + + +def hilbert_decode(code: torch.Tensor, depth: int = 16): + return hilbert_decode_(code, num_dims=3, num_bits=depth) diff --git a/XPart/partgen/models/sonata/serialization/hilbert.py b/XPart/partgen/models/sonata/serialization/hilbert.py new file mode 100755 index 0000000000000000000000000000000000000000..a6f0de1006b35a469ca87936c707702371064349 --- /dev/null +++ b/XPart/partgen/models/sonata/serialization/hilbert.py @@ -0,0 +1,318 @@ +""" +Hilbert Order +Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu +Please cite our work if the code is helpful to you. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + + +def right_shift(binary, k=1, axis=-1): + """Right shift an array of binary values. + + Parameters: + ----------- + binary: An ndarray of binary values. + + k: The number of bits to shift. Default 1. + + axis: The axis along which to shift. Default -1. + + Returns: + -------- + Returns an ndarray with zero prepended and the ends truncated, along + whatever axis was specified.""" + + # If we're shifting the whole thing, just return zeros. 
+ if binary.shape[axis] <= k: + return torch.zeros_like(binary) + + # Determine the padding pattern. + # padding = [(0,0)] * len(binary.shape) + # padding[axis] = (k,0) + + # Determine the slicing pattern to eliminate just the last one. + slicing = [slice(None)] * len(binary.shape) + slicing[axis] = slice(None, -k) + shifted = torch.nn.functional.pad( + binary[tuple(slicing)], (k, 0), mode="constant", value=0 + ) + + return shifted + + +def binary2gray(binary, axis=-1): + """Convert an array of binary values into Gray codes. + + This uses the classic X ^ (X >> 1) trick to compute the Gray code. + + Parameters: + ----------- + binary: An ndarray of binary values. + + axis: The axis along which to compute the gray code. Default=-1. + + Returns: + -------- + Returns an ndarray of Gray codes. + """ + shifted = right_shift(binary, axis=axis) + + # Do the X ^ (X >> 1) trick. + gray = torch.logical_xor(binary, shifted) + + return gray + + +def gray2binary(gray, axis=-1): + """Convert an array of Gray codes back into binary values. + + Parameters: + ----------- + gray: An ndarray of gray codes. + + axis: The axis along which to perform Gray decoding. Default=-1. + + Returns: + -------- + Returns an ndarray of binary values. + """ + + # Loop the log2(bits) number of times necessary, with shift and xor. + shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1) + while shift > 0: + gray = torch.logical_xor(gray, right_shift(gray, shift)) + shift = torch.div(shift, 2, rounding_mode="floor") + return gray + + +def encode(locs, num_dims, num_bits): + """Decode an array of locations in a hypercube into a Hilbert integer. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + locs - An ndarray of locations in a hypercube of num_dims dimensions, in + which each dimension runs from 0 to 2**num_bits-1. The shape can + be arbitrary, as long as the last dimension of the same has size + num_dims. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of uint64 integers with the same shape as the + input, excluding the last dimension, which needs to be num_dims. + """ + + # Keep around the original shape for later. + orig_shape = locs.shape + bitpack_mask = 1 << torch.arange(0, 8).to(locs.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + if orig_shape[-1] != num_dims: + raise ValueError( + """ + The shape of locs was surprising in that the last dimension was of size + %d, but num_dims=%d. These need to be equal. + """ + % (orig_shape[-1], num_dims) + ) + + if num_dims * num_bits > 63: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a int64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits, num_dims * num_bits) + ) + + # Treat the location integers as 64-bit unsigned and then split them up into + # a sequence of uint8s. Preserve the association by dimension. + locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) + + # Now turn these into bits and truncate to num_bits. 
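+    # (Each byte is tested against bitpack_mask_rev, most-significant bit first, which
+    # yields a (N, num_dims, num_bits) tensor of 0/1 values for the Skilling loop below.)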
+ gray = ( + locs_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[..., -num_bits:] + ) + + # Run the decoding process the other way. + # Iterate forwards through the bits. + for bit in range(0, num_bits): + # Iterate forwards through the dimensions. + for dim in range(0, num_dims): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Now flatten out. + gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims)) + + # Convert Gray back to binary. + hh_bin = gray2binary(gray) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits * num_dims + padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0) + + # Convert binary values into uint8s. + hh_uint8 = ( + (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask) + .sum(2) + .squeeze() + .type(torch.uint8) + ) + + # Convert uint8s into uint64s. + hh_uint64 = hh_uint8.view(torch.int64).squeeze() + + return hh_uint64 + + +def decode(hilberts, num_dims, num_bits): + """Decode an array of Hilbert integers into locations in a hypercube. + + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + + Params: + ------- + hilberts - An ndarray of Hilbert integers. Must be an integer dtype and + cannot have fewer bits than num_dims * num_bits. + + num_dims - The dimensionality of the hypercube. Integer. + + num_bits - The number of bits for each dimension. Integer. + + Returns: + -------- + The output is an ndarray of unsigned integers with the same shape as hilberts + but with an additional dimension of size num_dims. + """ + + if num_dims * num_bits > 64: + raise ValueError( + """ + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a uint64. Are you sure you need that many points on your Hilbert + curve? + """ + % (num_dims, num_bits) + ) + + # Handle the case where we got handed a naked integer. + hilberts = torch.atleast_1d(hilberts) + + # Keep around the shape for later. + orig_shape = hilberts.shape + bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device) + bitpack_mask_rev = bitpack_mask.flip(-1) + + # Treat each of the hilberts as a s equence of eight uint8. + # This treats all of the inputs as uint64 and makes things uniform. + hh_uint8 = ( + hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1) + ) + + # Turn these lists of uints into lists of bits and then truncate to the size + # we actually need for using Skilling's procedure. + hh_bits = ( + hh_uint8.unsqueeze(-1) + .bitwise_and(bitpack_mask_rev) + .ne(0) + .byte() + .flatten(-2, -1)[:, -num_dims * num_bits :] + ) + + # Take the sequence of bits and Gray-code it. + gray = binary2gray(hh_bits) + + # There has got to be a better way to do this. 
+ # I could index them differently, but the eventual packbits likes it this way. + gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2) + + # Iterate backwards through the bits. + for bit in range(num_bits - 1, -1, -1): + # Iterate backwards through the dimensions. + for dim in range(num_dims - 1, -1, -1): + # Identify which ones have this bit active. + mask = gray[:, dim, bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:, 0, bit + 1 :] = torch.logical_xor( + gray[:, 0, bit + 1 :], mask[:, None] + ) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = torch.logical_and( + torch.logical_not(mask[:, None]), + torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), + ) + gray[:, dim, bit + 1 :] = torch.logical_xor( + gray[:, dim, bit + 1 :], to_flip + ) + gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits + padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0) + + # Now chop these up into blocks of 8. + locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8)) + + # Take those blocks and turn them unto uint8s. + # from IPython import embed; embed() + locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8) + + # Finally, treat these as uint64s. + flat_locs = locs_uint8.view(torch.int64) + + # Return them in the expected shape. + return flat_locs.reshape((*orig_shape, num_dims)) diff --git a/XPart/partgen/models/sonata/serialization/z_order.py b/XPart/partgen/models/sonata/serialization/z_order.py new file mode 100755 index 0000000000000000000000000000000000000000..f45f6a39219ba0d6c945c2e588649c3376410e15 --- /dev/null +++ b/XPart/partgen/models/sonata/serialization/z_order.py @@ -0,0 +1,145 @@ +# @lint-ignore-every LICENSELINT +# -------------------------------------------------------- +# Octree-based Sparse Convolutional Neural Networks +# Copyright (c) 2022 Peng-Shuai Wang +# Licensed under The MIT License [see LICENSE for details] +# Written by Peng-Shuai Wang +# -------------------------------------------------------- +# -------------------------------------------------------- +# Octree-based Sparse Convolutional Neural Networks +# Copyright (c) 2022 Peng-Shuai Wang +# Licensed under The MIT License [see LICENSE for details] +# Written by Peng-Shuai Wang +# -------------------------------------------------------- +# -------------------------------------------------------- +# Octree-based Sparse Convolutional Neural Networks +# Copyright (c) 2022 Peng-Shuai Wang +# Licensed under The MIT License [see LICENSE for details] +# Written by Peng-Shuai Wang +# -------------------------------------------------------- +# -------------------------------------------------------- +# Octree-based Sparse Convolutional Neural Networks +# Copyright (c) 2022 Peng-Shuai Wang +# Licensed under The MIT License [see LICENSE for details] +# Written by Peng-Shuai Wang +# -------------------------------------------------------- + +import torch +from typing import Optional, Union + + +class KeyLUT: + def __init__(self): + r256 = torch.arange(256, dtype=torch.int64) + r512 = torch.arange(512, dtype=torch.int64) + zero = torch.zeros(256, dtype=torch.int64) + device = torch.device("cpu") + + self._encode = { + device: ( + self.xyz2key(r256, zero, zero, 8), + self.xyz2key(zero, r256, zero, 8), + self.xyz2key(zero, zero, r256, 8), + ) + } + self._decode = {device: self.key2xyz(r512, 9)} + + def encode_lut(self, 
device=torch.device("cpu")): + if device not in self._encode: + cpu = torch.device("cpu") + self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) + return self._encode[device] + + def decode_lut(self, device=torch.device("cpu")): + if device not in self._decode: + cpu = torch.device("cpu") + self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) + return self._decode[device] + + def xyz2key(self, x, y, z, depth): + key = torch.zeros_like(x) + for i in range(depth): + mask = 1 << i + key = ( + key + | ((x & mask) << (2 * i + 2)) + | ((y & mask) << (2 * i + 1)) + | ((z & mask) << (2 * i + 0)) + ) + return key + + def key2xyz(self, key, depth): + x = torch.zeros_like(key) + y = torch.zeros_like(key) + z = torch.zeros_like(key) + for i in range(depth): + x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) + y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) + z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) + return x, y, z + + +_key_lut = KeyLUT() + + +def xyz2key( + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + b: Optional[Union[torch.Tensor, int]] = None, + depth: int = 16, +): + """Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys + based on pre-computed look up tables. The speed of this function is much + faster than the method based on for-loop. + + Args: + x (torch.Tensor): The x coordinate. + y (torch.Tensor): The y coordinate. + z (torch.Tensor): The z coordinate. + b (torch.Tensor or int): The batch index of the coordinates, and should be + smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of + :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + EX, EY, EZ = _key_lut.encode_lut(x.device) + x, y, z = x.long(), y.long(), z.long() + + mask = 255 if depth > 8 else (1 << depth) - 1 + key = EX[x & mask] | EY[y & mask] | EZ[z & mask] + if depth > 8: + mask = (1 << (depth - 8)) - 1 + key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] + key = key16 << 24 | key + + if b is not None: + b = b.long() + key = b << 48 | key + + return key + + +def key2xyz(key: torch.Tensor, depth: int = 16): + r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates + and the batch index based on pre-computed look up tables. + + Args: + key (torch.Tensor): The shuffled key. + depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). + """ + + DX, DY, DZ = _key_lut.decode_lut(key.device) + x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) + + b = key >> 48 + key = key & ((1 << 48) - 1) + + n = (depth + 2) // 3 + for i in range(n): + k = key >> (i * 9) & 511 + x = x | (DX[k] << (i * 3)) + y = y | (DY[k] << (i * 3)) + z = z | (DZ[k] << (i * 3)) + + return x, y, z, b diff --git a/XPart/partgen/models/sonata/structure.py b/XPart/partgen/models/sonata/structure.py new file mode 100755 index 0000000000000000000000000000000000000000..a462955758b620577881be5c5fd5f8f48e7c290e --- /dev/null +++ b/XPart/partgen/models/sonata/structure.py @@ -0,0 +1,159 @@ +""" +Data structure for 3D point cloud + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import spconv.pytorch as spconv +from addict import Dict + +from .serialization import encode +from .utils import offset2batch, batch2offset + + +class Point(Dict): + """ + Point Structure of Pointcept + + A Point (point cloud) in Pointcept is a dictionary that contains various properties of + a batched point cloud. The property with the following names have a specific definition + as follows: + + - "coord": original coordinate of point cloud; + - "grid_coord": grid coordinate for specific grid size (related to GridSampling); + Point also support the following optional attributes: + - "offset": if not exist, initialized as batch size is 1; + - "batch": if not exist, initialized as batch size is 1; + - "feat": feature of point cloud, default input of model; + - "grid_size": Grid size of point cloud (related to GridSampling); + (related to Serialization) + - "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range; + - "serialized_code": a list of serialization codes; + - "serialized_order": a list of serialization order determined by code; + - "serialized_inverse": a list of inverse mapping determined by code; + (related to Sparsify: SpConv) + - "sparse_shape": Sparse shape for Sparse Conv Tensor; + - "sparse_conv_feat": SparseConvTensor init with information provide by Point; + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # If one of "offset" or "batch" do not exist, generate by the existing one + if "batch" not in self.keys() and "offset" in self.keys(): + self["batch"] = offset2batch(self.offset) + elif "offset" not in self.keys() and "batch" in self.keys(): + self["offset"] = batch2offset(self.batch) + + def serialization(self, order="z", depth=None, shuffle_orders=False): + """ + Point Cloud Serialization + + relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"] + """ + self["order"] = order + assert "batch" in self.keys() + if "grid_coord" not in self.keys(): + # if you don't want to operate GridSampling in data augmentation, + # please add the following augmentation into your pipeline: + # dict(type="Copy", keys_dict={"grid_size": 0.01}), + # (adjust `grid_size` to what your want) + assert {"grid_size", "coord"}.issubset(self.keys()) + + self["grid_coord"] = torch.div( + self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" + ).int() + + if depth is None: + # Adaptive measure the depth of serialization cube (length = 2 ^ depth) + depth = int(self.grid_coord.max() + 1).bit_length() + self["serialized_depth"] = depth + # Maximum bit length for serialization code is 63 (int64) + assert depth * 3 + len(self.offset).bit_length() <= 63 + # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position. + # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3 + # cube with a grid size of 0.01 meter. We consider it is enough for the current stage. + # We can unlock the limitation by optimizing the z-order encoding function if necessary. 
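+        # With depth capped at 16, the three coordinates occupy at most 48 bits of the
+        # int64 key; the batch index is packed into the remaining high bits (see encode).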
+ assert depth <= 16 + + # The serialization codes are arranged as following structures: + # [Order1 ([n]), + # Order2 ([n]), + # ... + # OrderN ([n])] (k, n) + code = [ + encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order + ] + code = torch.stack(code) + order = torch.argsort(code) + inverse = torch.zeros_like(order).scatter_( + dim=1, + index=order, + src=torch.arange(0, code.shape[1], device=order.device).repeat( + code.shape[0], 1 + ), + ) + + if shuffle_orders: + perm = torch.randperm(code.shape[0]) + code = code[perm] + order = order[perm] + inverse = inverse[perm] + + self["serialized_code"] = code + self["serialized_order"] = order + self["serialized_inverse"] = inverse + + def sparsify(self, pad=96): + """ + Point Cloud Serialization + + Point cloud is sparse, here we use "sparsify" to specifically refer to + preparing "spconv.SparseConvTensor" for SpConv. + + relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"] + + pad: padding sparse for sparse shape. + """ + assert {"feat", "batch"}.issubset(self.keys()) + if "grid_coord" not in self.keys(): + # if you don't want to operate GridSampling in data augmentation, + # please add the following augmentation into your pipeline: + # dict(type="Copy", keys_dict={"grid_size": 0.01}), + # (adjust `grid_size` to what your want) + assert {"grid_size", "coord"}.issubset(self.keys()) + self["grid_coord"] = torch.div( + self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" + ).int() + if "sparse_shape" in self.keys(): + sparse_shape = self.sparse_shape + else: + sparse_shape = torch.add( + torch.max(self.grid_coord, dim=0).values, pad + ).tolist() + sparse_conv_feat = spconv.SparseConvTensor( + features=self.feat, + indices=torch.cat( + [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1 + ).contiguous(), + spatial_shape=sparse_shape, + batch_size=self.batch[-1].tolist() + 1, + ) + self["sparse_shape"] = sparse_shape + self["sparse_conv_feat"] = sparse_conv_feat diff --git a/XPart/partgen/models/sonata/transform.py b/XPart/partgen/models/sonata/transform.py new file mode 100755 index 0000000000000000000000000000000000000000..e9d4c6612dabddec50536983a7ff7eea185a815c --- /dev/null +++ b/XPart/partgen/models/sonata/transform.py @@ -0,0 +1,1330 @@ +""" +3D point cloud augmentation + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
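+
+# A minimal usage sketch (illustrative; the parameter values are assumptions): the
+# transforms below register themselves into the TRANSFORMS registry and are chained on
+# a data_dict of numpy arrays, either directly or via TRANSFORMS.build(dict(type=...)):
+#
+#   aug = [CenterShift(apply_z=True), RandomRotate(angle=[-1, 1], axis="z", p=0.5), ToTensor()]
+#   for t in aug:
+#       data_dict = t(data_dict)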
+
+
+
+import random
+import numbers
+import scipy
+import scipy.ndimage
+import scipy.interpolate
+import scipy.stats
+import numpy as np
+import torch
+import copy
+from collections.abc import Sequence, Mapping
+
+from .registry import Registry
+
+TRANSFORMS = Registry("transforms")
+
+
+def index_operator(data_dict, index, duplicate=False):
+    # index selection operator for keys in "index_valid_keys"
+    # custom these keys by "Update" transform in config
+    if "index_valid_keys" not in data_dict:
+        data_dict["index_valid_keys"] = [
+            "coord",
+            "color",
+            "normal",
+            "strength",
+            "segment",
+            "instance",
+        ]
+    if not duplicate:
+        for key in data_dict["index_valid_keys"]:
+            if key in data_dict:
+                data_dict[key] = data_dict[key][index]
+        return data_dict
+    else:
+        data_dict_ = dict()
+        for key in data_dict.keys():
+            if key in data_dict["index_valid_keys"]:
+                data_dict_[key] = data_dict[key][index]
+            else:
+                data_dict_[key] = data_dict[key]
+        return data_dict_
+
+
+@TRANSFORMS.register_module()
+class Collect(object):
+    def __init__(self, keys, offset_keys_dict=None, **kwargs):
+        """
+        e.g. 
Collect(keys=[coord], feat_keys=[coord, color]) + """ + if offset_keys_dict is None: + offset_keys_dict = dict(offset="coord") + self.keys = keys + self.offset_keys = offset_keys_dict + self.kwargs = kwargs + + def __call__(self, data_dict): + data = dict() + if isinstance(self.keys, str): + self.keys = [self.keys] + for key in self.keys: + data[key] = data_dict[key] + for key, value in self.offset_keys.items(): + data[key] = torch.tensor([data_dict[value].shape[0]]) + for name, keys in self.kwargs.items(): + name = name.replace("_keys", "") + assert isinstance(keys, Sequence) + data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1) + return data + + +@TRANSFORMS.register_module() +class Copy(object): + def __init__(self, keys_dict=None): + if keys_dict is None: + keys_dict = dict(coord="origin_coord", segment="origin_segment") + self.keys_dict = keys_dict + + def __call__(self, data_dict): + for key, value in self.keys_dict.items(): + if isinstance(data_dict[key], np.ndarray): + data_dict[value] = data_dict[key].copy() + elif isinstance(data_dict[key], torch.Tensor): + data_dict[value] = data_dict[key].clone().detach() + else: + data_dict[value] = copy.deepcopy(data_dict[key]) + return data_dict + + +@TRANSFORMS.register_module() +class Update(object): + def __init__(self, keys_dict=None): + if keys_dict is None: + keys_dict = dict() + self.keys_dict = keys_dict + + def __call__(self, data_dict): + for key, value in self.keys_dict.items(): + data_dict[key] = value + return data_dict + + +@TRANSFORMS.register_module() +class ToTensor(object): + def __call__(self, data): + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, str): + # note that str is also a kind of sequence, judgement should before sequence + return data + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool): + return torch.from_numpy(data) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer): + return torch.from_numpy(data).long() + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating): + return torch.from_numpy(data).float() + elif isinstance(data, Mapping): + result = {sub_key: self(item) for sub_key, item in data.items()} + return result + elif isinstance(data, Sequence): + result = [self(item) for item in data] + return result + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +@TRANSFORMS.register_module() +class NormalizeColor(object): + def __call__(self, data_dict): + if "color" in data_dict.keys(): + data_dict["color"] = data_dict["color"] / 255 + return data_dict + + +@TRANSFORMS.register_module() +class NormalizeCoord(object): + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + # modified from pointnet2 + centroid = np.mean(data_dict["coord"], axis=0) + data_dict["coord"] -= centroid + m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, axis=1))) + data_dict["coord"] = data_dict["coord"] / m + return data_dict + + +@TRANSFORMS.register_module() +class PositiveShift(object): + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + coord_min = np.min(data_dict["coord"], 0) + data_dict["coord"] -= coord_min + return data_dict + + +@TRANSFORMS.register_module() +class CenterShift(object): + def __init__(self, apply_z=True): + self.apply_z = apply_z + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + x_min, 
y_min, z_min = data_dict["coord"].min(axis=0) + x_max, y_max, _ = data_dict["coord"].max(axis=0) + if self.apply_z: + shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min] + else: + shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0] + data_dict["coord"] -= shift + return data_dict + + +@TRANSFORMS.register_module() +class RandomShift(object): + def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))): + self.shift = shift + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1]) + shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1]) + shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1]) + data_dict["coord"] += [shift_x, shift_y, shift_z] + return data_dict + + +@TRANSFORMS.register_module() +class PointClip(object): + def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)): + self.point_cloud_range = point_cloud_range + + def __call__(self, data_dict): + if "coord" in data_dict.keys(): + data_dict["coord"] = np.clip( + data_dict["coord"], + a_min=self.point_cloud_range[:3], + a_max=self.point_cloud_range[3:], + ) + return data_dict + + +@TRANSFORMS.register_module() +class RandomDropout(object): + def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5): + """ + upright_axis: axis index among x,y,z, i.e. 2 for z + """ + self.dropout_ratio = dropout_ratio + self.dropout_application_ratio = dropout_application_ratio + + def __call__(self, data_dict): + if random.random() < self.dropout_application_ratio: + n = len(data_dict["coord"]) + idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False) + if "sampled_index" in data_dict: + # for ScanNet data efficient, we need to make sure labeled point is sampled. + idx = np.unique(np.append(idx, data_dict["sampled_index"])) + mask = np.zeros_like(data_dict["segment"]).astype(bool) + mask[data_dict["sampled_index"]] = True + data_dict["sampled_index"] = np.where(mask[idx])[0] + data_dict = index_operator(data_dict, idx) + return data_dict + + +@TRANSFORMS.register_module() +class RandomRotate(object): + def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5): + self.angle = [-1, 1] if angle is None else angle + self.axis = axis + self.always_apply = always_apply + self.p = p if not self.always_apply else 1 + self.center = center + + def __call__(self, data_dict): + if random.random() > self.p: + return data_dict + angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi + rot_cos, rot_sin = np.cos(angle), np.sin(angle) + if self.axis == "x": + rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]]) + elif self.axis == "y": + rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]]) + elif self.axis == "z": + rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]) + else: + raise NotImplementedError + if "coord" in data_dict.keys(): + if self.center is None: + x_min, y_min, z_min = data_dict["coord"].min(axis=0) + x_max, y_max, z_max = data_dict["coord"].max(axis=0) + center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2] + else: + center = self.center + data_dict["coord"] -= center + data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t)) + data_dict["coord"] += center + if "normal" in data_dict.keys(): + data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t)) + return data_dict + + +@TRANSFORMS.register_module() +class RandomRotateTargetAngle(object): + def __init__( + 
self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75
+    ):
+        self.angle = angle
+        self.axis = axis
+        self.always_apply = always_apply
+        self.p = p if not self.always_apply else 1
+        self.center = center
+
+    def __call__(self, data_dict):
+        if random.random() > self.p:
+            return data_dict
+        angle = np.random.choice(self.angle) * np.pi
+        rot_cos, rot_sin = np.cos(angle), np.sin(angle)
+        if self.axis == "x":
+            rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
+        elif self.axis == "y":
+            rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
+        elif self.axis == "z":
+            rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
+        else:
+            raise NotImplementedError
+        if "coord" in data_dict.keys():
+            if self.center is None:
+                x_min, y_min, z_min = data_dict["coord"].min(axis=0)
+                x_max, y_max, z_max = data_dict["coord"].max(axis=0)
+                center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
+            else:
+                center = self.center
+            data_dict["coord"] -= center
+            data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
+            data_dict["coord"] += center
+        if "normal" in data_dict.keys():
+            data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
+        return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomScale(object):
+    def __init__(self, scale=None, anisotropic=False):
+        self.scale = scale if scale is not None else [0.95, 1.05]
+        self.anisotropic = anisotropic
+
+    def __call__(self, data_dict):
+        if "coord" in data_dict.keys():
+            scale = np.random.uniform(
+                self.scale[0], self.scale[1], 3 if self.anisotropic else 1
+            )
+            data_dict["coord"] *= scale
+        return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomFlip(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, data_dict):
+        if np.random.rand() < self.p:
+            if "coord" in data_dict.keys():
+                data_dict["coord"][:, 0] = -data_dict["coord"][:, 0]
+            if "normal" in data_dict.keys():
+                data_dict["normal"][:, 0] = -data_dict["normal"][:, 0]
+        if np.random.rand() < self.p:
+            if "coord" in data_dict.keys():
+                data_dict["coord"][:, 1] = -data_dict["coord"][:, 1]
+            if "normal" in data_dict.keys():
+                data_dict["normal"][:, 1] = -data_dict["normal"][:, 1]
+        return data_dict
+
+
+@TRANSFORMS.register_module()
+class RandomJitter(object):
+    def __init__(self, sigma=0.01, clip=0.05):
+        assert clip > 0
+        self.sigma = sigma
+        self.clip = clip
+
+    def __call__(self, data_dict):
+        if "coord" in data_dict.keys():
+            jitter = np.clip(
+                self.sigma * np.random.randn(data_dict["coord"].shape[0], 3),
+                -self.clip,
+                self.clip,
+            )
+            data_dict["coord"] += jitter
+        return data_dict
+
+
+@TRANSFORMS.register_module()
+class ClipGaussianJitter(object):
+    def __init__(self, scalar=0.02, store_jitter=False):
+        self.scalar = scalar
+        # 1-D zero-mean vector expected by np.random.multivariate_normal below
+        self.mean = np.zeros(3)
+        self.cov = np.identity(3)
+        self.quantile = 1.96
+        self.store_jitter = store_jitter
+
+    def __call__(self, data_dict):
+        if "coord" in data_dict.keys():
+            jitter = np.random.multivariate_normal(
+                self.mean, self.cov, data_dict["coord"].shape[0]
+            )
+            jitter = self.scalar * np.clip(jitter / 1.96, -1, 1)
+            data_dict["coord"] += jitter
+            if self.store_jitter:
+                data_dict["jitter"] = jitter
+        return data_dict
+
+
+@TRANSFORMS.register_module()
+class ChromaticAutoContrast(object):
+    def __init__(self, p=0.2, blend_factor=None):
+        self.p = p
+        self.blend_factor = blend_factor
+
+    def __call__(self, data_dict):
+        if "color" in data_dict.keys() and np.random.rand() < self.p:
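+            # Illustrative example with assumed values: for a channel with lo=50 and
+            # hi=150, scale = 255 / 100 = 2.55, so a color value of 100 is stretched to
+            # (100 - 50) * 2.55 = 127.5; with blend_factor=0.5 the blended result is
+            # 0.5 * 100 + 0.5 * 127.5 = 113.75.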
lo = np.min(data_dict["color"], 0, keepdims=True) + hi = np.max(data_dict["color"], 0, keepdims=True) + scale = 255 / (hi - lo) + contrast_feat = (data_dict["color"][:, :3] - lo) * scale + blend_factor = ( + np.random.rand() if self.blend_factor is None else self.blend_factor + ) + data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][ + :, :3 + ] + blend_factor * contrast_feat + return data_dict + + +@TRANSFORMS.register_module() +class ChromaticTranslation(object): + def __init__(self, p=0.95, ratio=0.05): + self.p = p + self.ratio = ratio + + def __call__(self, data_dict): + if "color" in data_dict.keys() and np.random.rand() < self.p: + tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio + data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255) + return data_dict + + +@TRANSFORMS.register_module() +class ChromaticJitter(object): + def __init__(self, p=0.95, std=0.005): + self.p = p + self.std = std + + def __call__(self, data_dict): + if "color" in data_dict.keys() and np.random.rand() < self.p: + noise = np.random.randn(data_dict["color"].shape[0], 3) + noise *= self.std * 255 + data_dict["color"][:, :3] = np.clip( + noise + data_dict["color"][:, :3], 0, 255 + ) + return data_dict + + +@TRANSFORMS.register_module() +class RandomColorGrayScale(object): + def __init__(self, p): + self.p = p + + @staticmethod + def rgb_to_grayscale(color, num_output_channels=1): + if color.shape[-1] < 3: + raise TypeError( + "Input color should have at least 3 dimensions, but found {}".format( + color.shape[-1] + ) + ) + + if num_output_channels not in (1, 3): + raise ValueError("num_output_channels should be either 1 or 3") + + r, g, b = color[..., 0], color[..., 1], color[..., 2] + gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype) + gray = np.expand_dims(gray, axis=-1) + + if num_output_channels == 3: + gray = np.broadcast_to(gray, color.shape) + + return gray + + def __call__(self, data_dict): + if np.random.rand() < self.p: + data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3) + return data_dict + + +@TRANSFORMS.register_module() +class RandomColorJitter(object): + """ + Random Color Jitter for 3D point cloud (refer torchvision) + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95): + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input( + hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False + ) + self.p = p + + @staticmethod + def _check_input( + value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True + ): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError( + "If {} is a single number, it must be non negative.".format(name) + ) + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError( + "{} should be a single number or a list/tuple with length 2.".format( + name + ) + ) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) 
for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def blend(color1, color2, ratio): + ratio = float(ratio) + bound = 255.0 + return ( + (ratio * color1 + (1.0 - ratio) * color2) + .clip(0, bound) + .astype(color1.dtype) + ) + + @staticmethod + def rgb2hsv(rgb): + r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] + maxc = np.max(rgb, axis=-1) + minc = np.min(rgb, axis=-1) + eqc = maxc == minc + cr = maxc - minc + s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc)) + cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc) + rc = (maxc - r) / cr_divisor + gc = (maxc - g) / cr_divisor + bc = (maxc - b) / cr_divisor + + hr = (maxc == r) * (bc - gc) + hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) + h = hr + hg + hb + h = (h / 6.0 + 1.0) % 1.0 + return np.stack((h, s, maxc), axis=-1) + + @staticmethod + def hsv2rgb(hsv): + h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] + i = np.floor(h * 6.0) + f = (h * 6.0) - i + i = i.astype(np.int32) + + p = np.clip((v * (1.0 - s)), 0.0, 1.0) + q = np.clip((v * (1.0 - s * f)), 0.0, 1.0) + t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + mask = np.expand_dims(i, axis=-1) == np.arange(6) + + a1 = np.stack((v, q, p, p, t, v), axis=-1) + a2 = np.stack((t, v, v, q, p, p), axis=-1) + a3 = np.stack((p, p, t, v, v, q), axis=-1) + a4 = np.stack((a1, a2, a3), axis=-1) + + return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4) + + def adjust_brightness(self, color, brightness_factor): + if brightness_factor < 0: + raise ValueError( + "brightness_factor ({}) is not non-negative.".format(brightness_factor) + ) + + return self.blend(color, np.zeros_like(color), brightness_factor) + + def adjust_contrast(self, color, contrast_factor): + if contrast_factor < 0: + raise ValueError( + "contrast_factor ({}) is not non-negative.".format(contrast_factor) + ) + mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color)) + return self.blend(color, mean, contrast_factor) + + def adjust_saturation(self, color, saturation_factor): + if saturation_factor < 0: + raise ValueError( + "saturation_factor ({}) is not non-negative.".format(saturation_factor) + ) + gray = RandomColorGrayScale.rgb_to_grayscale(color) + return self.blend(color, gray, saturation_factor) + + def adjust_hue(self, color, hue_factor): + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError( + "hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor) + ) + orig_dtype = color.dtype + hsv = self.rgb2hsv(color / 255.0) + h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] + h = (h + hue_factor) % 1.0 + hsv = np.stack((h, s, v), axis=-1) + color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype) + return color_hue_adj + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + fn_idx = torch.randperm(4) + b = ( + None + if brightness is None + else np.random.uniform(brightness[0], brightness[1]) + ) + c = None if contrast is None else np.random.uniform(contrast[0], contrast[1]) + s = ( + None + if saturation is None + else np.random.uniform(saturation[0], saturation[1]) + ) + h = None if hue is None else np.random.uniform(hue[0], hue[1]) + return fn_idx, b, c, s, h + + def __call__(self, data_dict): + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + + for fn_id in fn_idx: + if ( + fn_id == 0 + and brightness_factor is not None + and 
np.random.rand() < self.p + ): + data_dict["color"] = self.adjust_brightness( + data_dict["color"], brightness_factor + ) + elif ( + fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p + ): + data_dict["color"] = self.adjust_contrast( + data_dict["color"], contrast_factor + ) + elif ( + fn_id == 2 + and saturation_factor is not None + and np.random.rand() < self.p + ): + data_dict["color"] = self.adjust_saturation( + data_dict["color"], saturation_factor + ) + elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p: + data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor) + return data_dict + + +@TRANSFORMS.register_module() +class HueSaturationTranslation(object): + @staticmethod + def rgb_to_hsv(rgb): + # Translated from source of colorsys.rgb_to_hsv + # r,g,b should be a numpy arrays with values between 0 and 255 + # rgb_to_hsv returns an array of floats between 0.0 and 1.0. + rgb = rgb.astype("float") + hsv = np.zeros_like(rgb) + # in case an RGBA array was passed, just copy the A channel + hsv[..., 3:] = rgb[..., 3:] + r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] + maxc = np.max(rgb[..., :3], axis=-1) + minc = np.min(rgb[..., :3], axis=-1) + hsv[..., 2] = maxc + mask = maxc != minc + hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask] + rc = np.zeros_like(r) + gc = np.zeros_like(g) + bc = np.zeros_like(b) + rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask] + gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask] + bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask] + hsv[..., 0] = np.select( + [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc + ) + hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0 + return hsv + + @staticmethod + def hsv_to_rgb(hsv): + # Translated from source of colorsys.hsv_to_rgb + # h,s should be a numpy arrays with values between 0.0 and 1.0 + # v should be a numpy array with values between 0.0 and 255.0 + # hsv_to_rgb returns an array of uints between 0 and 255. 
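+        # Hedged round-trip check with assumed values: pure red rgb=(255, 0, 0) maps to
+        # hsv=(0.0, 1.0, 255.0) in rgb_to_hsv above; converting back here gives i=0,
+        # f=0, p=0, t=0, so no np.select condition matches and the defaults
+        # (v, t, p) = (255, 0, 0) are returned.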
+ rgb = np.empty_like(hsv) + rgb[..., 3:] = hsv[..., 3:] + h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] + i = (h * 6.0).astype("uint8") + f = (h * 6.0) - i + p = v * (1.0 - s) + q = v * (1.0 - s * f) + t = v * (1.0 - s * (1.0 - f)) + i = i % 6 + conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5] + rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v) + rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t) + rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p) + return rgb.astype("uint8") + + def __init__(self, hue_max=0.5, saturation_max=0.2): + self.hue_max = hue_max + self.saturation_max = saturation_max + + def __call__(self, data_dict): + if "color" in data_dict.keys(): + # Assume color[:, :3] is rgb + hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3]) + hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max + sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max + hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1) + hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1) + data_dict["color"][:, :3] = np.clip( + HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255 + ) + return data_dict + + +@TRANSFORMS.register_module() +class RandomColorDrop(object): + def __init__(self, p=0.2, color_augment=0.0): + self.p = p + self.color_augment = color_augment + + def __call__(self, data_dict): + if "color" in data_dict.keys() and np.random.rand() < self.p: + data_dict["color"] *= self.color_augment + return data_dict + + def __repr__(self): + return "RandomColorDrop(color_augment: {}, p: {})".format( + self.color_augment, self.p + ) + + +@TRANSFORMS.register_module() +class ElasticDistortion(object): + def __init__(self, distortion_params=None): + self.distortion_params = ( + [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params + ) + + @staticmethod + def elastic_distortion(coords, granularity, magnitude): + """ + Apply elastic distortion on sparse coordinate space. + pointcloud: numpy array of (number of points, at least 3 spatial dims) + granularity: size of the noise grid (in same scale[m/cm] as the voxel grid) + magnitude: noise multiplier + """ + blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3 + blury = np.ones((1, 3, 1, 1)).astype("float32") / 3 + blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3 + coords_min = coords.min(0) + + # Create Gaussian noise tensor of the size given by granularity. + noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3 + noise = np.random.randn(*noise_dim, 3).astype(np.float32) + + # Smoothing. + for _ in range(2): + noise = scipy.ndimage.filters.convolve( + noise, blurx, mode="constant", cval=0 + ) + noise = scipy.ndimage.filters.convolve( + noise, blury, mode="constant", cval=0 + ) + noise = scipy.ndimage.filters.convolve( + noise, blurz, mode="constant", cval=0 + ) + + # Trilinear interpolate noise filters for each spatial dimensions. 
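+        # Hedged example of the grid below: with granularity=0.25 and an axis extent of
+        # 1.0, noise_dim = 1.0 // 0.25 + 3 = 7 nodes per axis, so the noise grid spans
+        # [coords_min - 0.25, coords_min + 0.25 * 5] and each point is displaced by
+        # magnitude * interpolated noise.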
+ ax = [ + np.linspace(d_min, d_max, d) + for d_min, d_max, d in zip( + coords_min - granularity, + coords_min + granularity * (noise_dim - 2), + noise_dim, + ) + ] + interp = scipy.interpolate.RegularGridInterpolator( + ax, noise, bounds_error=False, fill_value=0 + ) + coords += interp(coords) * magnitude + return coords + + def __call__(self, data_dict): + if "coord" in data_dict.keys() and self.distortion_params is not None: + if random.random() < 0.95: + for granularity, magnitude in self.distortion_params: + data_dict["coord"] = self.elastic_distortion( + data_dict["coord"], granularity, magnitude + ) + return data_dict + + +@TRANSFORMS.register_module() +class GridSample(object): + def __init__( + self, + grid_size=0.05, + hash_type="fnv", + mode="train", + return_inverse=False, + return_grid_coord=False, + return_min_coord=False, + return_displacement=False, + project_displacement=False, + ): + self.grid_size = grid_size + self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec + assert mode in ["train", "test"] + self.mode = mode + self.return_inverse = return_inverse + self.return_grid_coord = return_grid_coord + self.return_min_coord = return_min_coord + self.return_displacement = return_displacement + self.project_displacement = project_displacement + + def __call__(self, data_dict): + assert "coord" in data_dict.keys() + scaled_coord = data_dict["coord"] / np.array(self.grid_size) + grid_coord = np.floor(scaled_coord).astype(int) + min_coord = grid_coord.min(0) + grid_coord -= min_coord + scaled_coord -= min_coord + min_coord = min_coord * np.array(self.grid_size) + key = self.hash(grid_coord) + idx_sort = np.argsort(key) + key_sort = key[idx_sort] + _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True) + if self.mode == "train": # train mode + idx_select = ( + np.cumsum(np.insert(count, 0, 0)[0:-1]) + + np.random.randint(0, count.max(), count.size) % count + ) + idx_unique = idx_sort[idx_select] + if "sampled_index" in data_dict: + # for ScanNet data efficient, we need to make sure labeled point is sampled. 
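+                # Illustrative example with assumed values: if sampled_index=[2, 7] and
+                # the kept indices become idx_unique=[0, 2, 5, 7], the mask below remaps
+                # sampled_index to its new positions [1, 3] within the sampled subset.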
+ idx_unique = np.unique( + np.append(idx_unique, data_dict["sampled_index"]) + ) + mask = np.zeros_like(data_dict["segment"]).astype(bool) + mask[data_dict["sampled_index"]] = True + data_dict["sampled_index"] = np.where(mask[idx_unique])[0] + data_dict = index_operator(data_dict, idx_unique) + if self.return_inverse: + data_dict["inverse"] = np.zeros_like(inverse) + data_dict["inverse"][idx_sort] = inverse + if self.return_grid_coord: + data_dict["grid_coord"] = grid_coord[idx_unique] + data_dict["index_valid_keys"].append("grid_coord") + if self.return_min_coord: + data_dict["min_coord"] = min_coord.reshape([1, 3]) + if self.return_displacement: + displacement = ( + scaled_coord - grid_coord - 0.5 + ) # [0, 1] -> [-0.5, 0.5] displacement to center + if self.project_displacement: + displacement = np.sum( + displacement * data_dict["normal"], axis=-1, keepdims=True + ) + data_dict["displacement"] = displacement[idx_unique] + data_dict["index_valid_keys"].append("displacement") + return data_dict + + elif self.mode == "test": # test mode + data_part_list = [] + for i in range(count.max()): + idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count + idx_part = idx_sort[idx_select] + data_part = index_operator(data_dict, idx_part, duplicate=True) + data_part["index"] = idx_part + if self.return_inverse: + data_part["inverse"] = np.zeros_like(inverse) + data_part["inverse"][idx_sort] = inverse + if self.return_grid_coord: + data_part["grid_coord"] = grid_coord[idx_part] + data_dict["index_valid_keys"].append("grid_coord") + if self.return_min_coord: + data_part["min_coord"] = min_coord.reshape([1, 3]) + if self.return_displacement: + displacement = ( + scaled_coord - grid_coord - 0.5 + ) # [0, 1] -> [-0.5, 0.5] displacement to center + if self.project_displacement: + displacement = np.sum( + displacement * data_dict["normal"], axis=-1, keepdims=True + ) + data_dict["displacement"] = displacement[idx_part] + data_dict["index_valid_keys"].append("displacement") + data_part_list.append(data_part) + return data_part_list + else: + raise NotImplementedError + + @staticmethod + def ravel_hash_vec(arr): + """ + Ravel the coordinates after subtracting the min coordinates. 
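+        Illustrative example: for arr = [[0, 1], [2, 0]], arr_max = [3, 2] and the
+        keys are [0 * 2 + 1, 2 * 2 + 0] = [1, 4].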
+ """ + assert arr.ndim == 2 + arr = arr.copy() + arr -= arr.min(0) + arr = arr.astype(np.uint64, copy=False) + arr_max = arr.max(0).astype(np.uint64) + 1 + + keys = np.zeros(arr.shape[0], dtype=np.uint64) + # Fortran style indexing + for j in range(arr.shape[1] - 1): + keys += arr[:, j] + keys *= arr_max[j + 1] + keys += arr[:, -1] + return keys + + @staticmethod + def fnv_hash_vec(arr): + """ + FNV64-1A + """ + assert arr.ndim == 2 + # Floor first for negative coordinates + arr = arr.copy() + arr = arr.astype(np.uint64, copy=False) + hashed_arr = np.uint64(14695981039346656037) * np.ones( + arr.shape[0], dtype=np.uint64 + ) + for j in range(arr.shape[1]): + hashed_arr *= np.uint64(1099511628211) + hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j]) + return hashed_arr + + +@TRANSFORMS.register_module() +class SphereCrop(object): + def __init__(self, point_max=80000, sample_rate=None, mode="random"): + self.point_max = point_max + self.sample_rate = sample_rate + assert mode in ["random", "center", "all"] + self.mode = mode + + def __call__(self, data_dict): + point_max = ( + int(self.sample_rate * data_dict["coord"].shape[0]) + if self.sample_rate is not None + else self.point_max + ) + + assert "coord" in data_dict.keys() + if data_dict["coord"].shape[0] > point_max: + if self.mode == "random": + center = data_dict["coord"][ + np.random.randint(data_dict["coord"].shape[0]) + ] + elif self.mode == "center": + center = data_dict["coord"][data_dict["coord"].shape[0] // 2] + else: + raise NotImplementedError + idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[ + :point_max + ] + data_dict = index_operator(data_dict, idx_crop) + return data_dict + + +@TRANSFORMS.register_module() +class ShufflePoint(object): + def __call__(self, data_dict): + assert "coord" in data_dict.keys() + shuffle_index = np.arange(data_dict["coord"].shape[0]) + np.random.shuffle(shuffle_index) + data_dict = index_operator(data_dict, shuffle_index) + return data_dict + + +@TRANSFORMS.register_module() +class CropBoundary(object): + def __call__(self, data_dict): + assert "segment" in data_dict + segment = data_dict["segment"].flatten() + mask = (segment != 0) * (segment != 1) + data_dict = index_operator(data_dict, mask) + return data_dict + + +@TRANSFORMS.register_module() +class ContrastiveViewsGenerator(object): + def __init__( + self, + view_keys=("coord", "color", "normal", "origin_coord"), + view_trans_cfg=None, + ): + self.view_keys = view_keys + self.view_trans = Compose(view_trans_cfg) + + def __call__(self, data_dict): + view1_dict = dict() + view2_dict = dict() + for key in self.view_keys: + view1_dict[key] = data_dict[key].copy() + view2_dict[key] = data_dict[key].copy() + view1_dict = self.view_trans(view1_dict) + view2_dict = self.view_trans(view2_dict) + for key, value in view1_dict.items(): + data_dict["view1_" + key] = value + for key, value in view2_dict.items(): + data_dict["view2_" + key] = value + return data_dict + + +@TRANSFORMS.register_module() +class MultiViewGenerator(object): + def __init__( + self, + global_view_num=2, + global_view_scale=(0.4, 1.0), + local_view_num=4, + local_view_scale=(0.1, 0.4), + global_shared_transform=None, + global_transform=None, + local_transform=None, + max_size=65536, + center_height_scale=(0, 1), + shared_global_view=False, + view_keys=("coord", "origin_coord", "color", "normal"), + ): + self.global_view_num = global_view_num + self.global_view_scale = global_view_scale + self.local_view_num = local_view_num + self.local_view_scale = 
local_view_scale + self.global_shared_transform = Compose(global_shared_transform) + self.global_transform = Compose(global_transform) + self.local_transform = Compose(local_transform) + self.max_size = max_size + self.center_height_scale = center_height_scale + self.shared_global_view = shared_global_view + self.view_keys = view_keys + assert "coord" in view_keys + + def get_view(self, point, center, scale): + coord = point["coord"] + max_size = min(self.max_size, coord.shape[0]) + size = int(np.random.uniform(*scale) * max_size) + index = np.argsort(np.sum(np.square(coord - center), axis=-1))[:size] + view = dict(index=index) + for key in point.keys(): + if key in self.view_keys: + view[key] = point[key][index] + + if "index_valid_keys" in point.keys(): + # inherit index_valid_keys from point + view["index_valid_keys"] = point["index_valid_keys"] + return view + + def __call__(self, data_dict): + coord = data_dict["coord"] + point = self.global_shared_transform(copy.deepcopy(data_dict)) + z_min = coord[:, 2].min() + z_max = coord[:, 2].max() + z_min_ = z_min + (z_max - z_min) * self.center_height_scale[0] + z_max_ = z_min + (z_max - z_min) * self.center_height_scale[1] + center_mask = np.logical_and(coord[:, 2] >= z_min_, coord[:, 2] <= z_max_) + # get major global view + major_center = coord[np.random.choice(np.where(center_mask)[0])] + major_view = self.get_view(point, major_center, self.global_view_scale) + major_coord = major_view["coord"] + # get global views: restrict the center of left global view within the major global view + if not self.shared_global_view: + global_views = [ + self.get_view( + point=point, + center=major_coord[np.random.randint(major_coord.shape[0])], + scale=self.global_view_scale, + ) + for _ in range(self.global_view_num - 1) + ] + else: + global_views = [ + {key: value.copy() for key, value in major_view.items()} + for _ in range(self.global_view_num - 1) + ] + + global_views = [major_view] + global_views + + # get local views: restrict the center of local view within the major global view + cover_mask = np.zeros_like(major_view["index"], dtype=bool) + local_views = [] + for i in range(self.local_view_num): + if sum(~cover_mask) == 0: + # reset cover mask if all points are sampled + cover_mask[:] = False + local_view = self.get_view( + point=data_dict, + center=major_coord[np.random.choice(np.where(~cover_mask)[0])], + scale=self.local_view_scale, + ) + local_views.append(local_view) + cover_mask[np.isin(major_view["index"], local_view["index"])] = True + + # augmentation and concat + view_dict = {} + for global_view in global_views: + global_view.pop("index") + global_view = self.global_transform(global_view) + for key in self.view_keys: + if f"global_{key}" in view_dict.keys(): + view_dict[f"global_{key}"].append(global_view[key]) + else: + view_dict[f"global_{key}"] = [global_view[key]] + view_dict["global_offset"] = np.cumsum( + [data.shape[0] for data in view_dict["global_coord"]] + ) + for local_view in local_views: + local_view.pop("index") + local_view = self.local_transform(local_view) + for key in self.view_keys: + if f"local_{key}" in view_dict.keys(): + view_dict[f"local_{key}"].append(local_view[key]) + else: + view_dict[f"local_{key}"] = [local_view[key]] + view_dict["local_offset"] = np.cumsum( + [data.shape[0] for data in view_dict["local_coord"]] + ) + for key in view_dict.keys(): + if "offset" not in key: + view_dict[key] = np.concatenate(view_dict[key], axis=0) + data_dict.update(view_dict) + return data_dict + + 
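+# Note on the MultiViewGenerator output above (hedged summary with assumed sizes): the
+# per-key arrays of all views are concatenated, and "global_offset" / "local_offset"
+# hold cumulative point counts, e.g. two global views with 1024 and 896 points give
+# global_offset = [1024, 1920].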
+@TRANSFORMS.register_module() +class InstanceParser(object): + def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1): + self.segment_ignore_index = segment_ignore_index + self.instance_ignore_index = instance_ignore_index + + def __call__(self, data_dict): + coord = data_dict["coord"] + segment = data_dict["segment"] + instance = data_dict["instance"] + mask = ~np.in1d(segment, self.segment_ignore_index) + # mapping ignored instance to ignore index + instance[~mask] = self.instance_ignore_index + # reorder left instance + unique, inverse = np.unique(instance[mask], return_inverse=True) + instance_num = len(unique) + instance[mask] = inverse + # init instance information + centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index + bbox = np.ones((instance_num, 8)) * self.instance_ignore_index + vacancy = [ + index for index in self.segment_ignore_index if index >= 0 + ] # vacate class index + + for instance_id in range(instance_num): + mask_ = instance == instance_id + coord_ = coord[mask_] + bbox_min = coord_.min(0) + bbox_max = coord_.max(0) + bbox_centroid = coord_.mean(0) + bbox_center = (bbox_max + bbox_min) / 2 + bbox_size = bbox_max - bbox_min + bbox_theta = np.zeros(1, dtype=coord_.dtype) + bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype) + # shift class index to fill vacate class index caused by segment ignore index + bbox_class -= np.greater(bbox_class, vacancy).sum() + + centroid[mask_] = bbox_centroid + bbox[instance_id] = np.concatenate( + [bbox_center, bbox_size, bbox_theta, bbox_class] + ) # 3 + 3 + 1 + 1 = 8 + data_dict["instance"] = instance + data_dict["instance_centroid"] = centroid + data_dict["bbox"] = bbox + return data_dict + + +class Compose(object): + def __init__(self, cfg=None): + self.cfg = cfg if cfg is not None else [] + self.transforms = [] + for t_cfg in self.cfg: + self.transforms.append(TRANSFORMS.build(t_cfg)) + + def __call__(self, data_dict): + for t in self.transforms: + data_dict = t(data_dict) + return data_dict + + +def default(): + config = [ + dict(type="CenterShift", apply_z=True), + dict( + type="GridSample", + # grid_size=0.02, + # grid_size=0.01, + grid_size=0.005, + # grid_size=0.0025, + hash_type="fnv", + mode="train", + return_grid_coord=True, + return_inverse=True, + ), + dict(type="NormalizeColor"), + dict(type="ToTensor"), + dict( + type="Collect", + keys=("coord", "grid_coord", "color", "inverse"), + feat_keys=("coord", "color", "normal"), + ), + ] + return Compose(config) diff --git a/XPart/partgen/models/sonata/utils.py b/XPart/partgen/models/sonata/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..e268d443f57918a96f1bc97aecab2220e86a352e --- /dev/null +++ b/XPart/partgen/models/sonata/utils.py @@ -0,0 +1,75 @@ +""" +General utils + +Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) +Please cite our work if the code is helpful to you. +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
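+# Hedged worked example for the offset helpers defined below: with
+# offset = torch.tensor([3, 5]) describing two point clouds of 3 and 2 points,
+#   offset2bincount(offset) -> tensor([3, 2])
+#   offset2batch(offset)    -> tensor([0, 0, 0, 1, 1])
+#   batch2offset(offset2batch(offset)) -> tensor([3, 5])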
+ + +import os +import random +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from datetime import datetime + + +@torch.no_grad() +def offset2bincount(offset): + return torch.diff( + offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long) + ) + + +@torch.no_grad() +def bincount2offset(bincount): + return torch.cumsum(bincount, dim=0) + + +@torch.no_grad() +def offset2batch(offset): + bincount = offset2bincount(offset) + return torch.arange( + len(bincount), device=offset.device, dtype=torch.long + ).repeat_interleave(bincount) + + +@torch.no_grad() +def batch2offset(batch): + return torch.cumsum(batch.bincount(), dim=0).long() + + +def get_random_seed(): + seed = ( + os.getpid() + + int(datetime.now().strftime("%S%f")) + + int.from_bytes(os.urandom(2), "big") + ) + return seed + + +def set_seed(seed=None): + if seed is None: + seed = get_random_seed() + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + cudnn.benchmark = False + cudnn.deterministic = True + os.environ["PYTHONHASHSEED"] = str(seed) diff --git a/XPart/partgen/partformer_pipeline.py b/XPart/partgen/partformer_pipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..5bfecfbff6dcff1404b5366353214b9a1af213d0 --- /dev/null +++ b/XPart/partgen/partformer_pipeline.py @@ -0,0 +1,734 @@ +import torch +from .utils.misc import logger, synchronize_timer +import inspect +from typing import List, Optional +import trimesh +import numpy as np +from tqdm import tqdm +import copy +from typing import List, Optional, Union +import os +from .utils.mesh_utils import ( + SampleMesh, + load_surface_points, + sample_bbox_points_from_trimesh, + explode_mesh, + fix_mesh, +) +from .utils.misc import ( + init_from_ckpt, + instantiate_from_config, + get_config_from_file, + smart_load_model, +) + +from diffusers.utils.torch_utils import randn_tensor +from pathlib import Path + + +@synchronize_timer("Export to trimesh") +def export_to_trimesh(mesh_output): + if isinstance(mesh_output, list): + outputs = [] + for mesh in mesh_output: + if mesh is None: + outputs.append(None) + else: + mesh.mesh_f = mesh.mesh_f[:, ::-1] + mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f) + mesh_output = fix_mesh(mesh_output) + outputs.append(mesh_output) + return outputs + else: + mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1] + mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f) + mesh_output = fix_mesh(mesh_output) + return mesh_output + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[Union[List[float], np.ndarray]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to" + " set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps`" + " does not support custom timestep schedules. Please check whether you" + " are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps`" + " does not support custom sigmas schedules. Please check whether you" + " are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class TokenAllocMixin: + + def allocate_tokens(self, bboxes, num_latents=512): + return np.array([num_latents] * bboxes.shape[0]) + + +class PartFormerPipeline(TokenAllocMixin): + + def __init__( + self, + vae, + model, + scheduler, + conditioner, + bbox_predictor=None, + verbose=False, + **kwargs, + ): + self.vae = vae + self.model = model + self.scheduler = scheduler + self.conditioner = conditioner + self.kwargs = kwargs + self.bbox_predictor = bbox_predictor + self.verbose = verbose + self.kwargs = kwargs + + @classmethod + @synchronize_timer("Hunyuan3D PartGen Pipeline Model Loading") + def from_single_file( + cls, + ckpt_path=None, + config=None, + device="cuda", + dtype=torch.float32, + use_safetensors=None, + ignore_keys=(), + **kwargs, + ): + # prepare config + if config is None: + config = get_config_from_file( + str( + Path(__file__).parent.parent + / "config" + / "partformer_full_pipeline_512_with_sonata.yaml" + ) + ) + # TODO: + if ckpt_path is None: + ckpt_path = str( + Path(__file__).parent + / "ckpts" + / "partformer_full_pipeline_512_with_sonata.ckpt" + ) + # load ckpt + if use_safetensors: + ckpt_path = ckpt_path.replace(".ckpt", ".safetensors") + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Model file {ckpt_path} not found") + logger.info(f"Loading model from {ckpt_path}") + + if use_safetensors: + # parse safetensors + import safetensors.torch + + safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device="cpu") + ckpt = {} + for key, value in safetensors_ckpt.items(): + model_name = key.split(".")[0] + new_key = key[len(model_name) + 1 :] + if model_name not in ckpt: + ckpt[model_name] = {} + ckpt[model_name][new_key] = value + else: + # ckpt = torch.load(ckpt_path, 
map_location="cpu", weights_only=True) + ckpt = torch.load(ckpt_path, map_location="cpu") + # load model + model = instantiate_from_config(config["model"]) + # model.load_state_dict(ckpt["model"]) + init_from_ckpt(model, ckpt, prefix="model", ignore_keys=ignore_keys) + vae = instantiate_from_config(config["shapevae"]) + # vae.load_state_dict(ckpt["shapevae"], strict=False) + init_from_ckpt(vae, ckpt, prefix="shapevae", ignore_keys=ignore_keys) + if config.get("conditioner", None) is not None: + conditioner = instantiate_from_config(config["conditioner"]) + init_from_ckpt( + conditioner, ckpt, prefix="conditioner", ignore_keys=ignore_keys + ) + else: + conditioner = vae + scheduler = instantiate_from_config(config["scheduler"]) + bbox_predictor = instantiate_from_config(config.get("bbox_predictor", None)) + model_kwargs = dict( + vae=vae, + model=model, + scheduler=scheduler, + conditioner=conditioner, + bbox_predictor=bbox_predictor, # TODO: add bbox predictor + device=device, + dtype=dtype, + ) + model_kwargs.update(kwargs) + return cls(**model_kwargs) + + @classmethod + def from_pretrained( + cls, + config=None, + dtype=torch.float32, + ignore_keys=(), + device="cuda", + **kwargs, + ): + if config is None: + config = get_config_from_file( + str( + Path(__file__).parent.parent + / "config" + / "partformer_full_pipeline_512_with_sonata.yaml" + ) + ) + ckpt_path = smart_load_model( + model_path="tencent/Hunyuan3D-Part", + ) + ckpt = torch.load(os.path.join(ckpt_path, "xpart.pt"), map_location="cpu") + # load model + model = instantiate_from_config(config["model"]) + # model.load_state_dict(ckpt["model"]) + init_from_ckpt(model, ckpt, prefix="model", ignore_keys=ignore_keys) + vae = instantiate_from_config(config["shapevae"]) + # vae.load_state_dict(ckpt["shapevae"], strict=False) + init_from_ckpt(vae, ckpt, prefix="shapevae", ignore_keys=ignore_keys) + if config.get("conditioner", None) is not None: + conditioner = instantiate_from_config(config["conditioner"]) + init_from_ckpt( + conditioner, ckpt, prefix="conditioner", ignore_keys=ignore_keys + ) + else: + conditioner = vae + scheduler = instantiate_from_config(config["scheduler"]) + config["bbox_predictor"]["params"]["ckpt_path"] = os.path.join( + ckpt_path, "p3sam.ckpt" + ) + bbox_predictor = instantiate_from_config(config.get("bbox_predictor", None)) + model_kwargs = dict( + vae=vae, + model=model, + scheduler=scheduler, + conditioner=conditioner, + bbox_predictor=bbox_predictor, # TODO: add bbox predictor + device=device, + dtype=dtype, + ) + model_kwargs.update(kwargs) + return cls(**model_kwargs) + + def compile(self): + self.vae = torch.compile(self.vae) + self.model = torch.compile(self.model) + self.conditioner = torch.compile(self.conditioner) + + def to(self, device=None, dtype=None): + if dtype is not None: + self.dtype = dtype + self.vae.to(dtype=dtype) + self.model.to(dtype=dtype) + self.conditioner.to(dtype=dtype) + if device is not None: + self.device = torch.device(device) + self.vae.to(device) + self.model.to(device) + self.conditioner.to(device) + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def predict_bbox( + self, mesh: trimesh.Trimesh, scale_box=1.0, drop_normal=True, seed=42 + ): + """ + Predict the bounding box of the object surface. + Args: + obj_surface (`torch.Tensor`): [B, N, 3] + Returns: + `torch.Tensor`: [B, K, 2, 3] where K is the number of bounding boxes + """ + if self.bbox_predictor is None: + raise ValueError("bbox_predictor is not set.") + aabb, face_ids, mesh = self.bbox_predictor.predict_aabb( + mesh, post_process=True, seed=seed + ) + # aabb, face_ids, mesh = self.bbox_predictor.predict_aabb(mesh, post_process=False) + aabb = torch.from_numpy(aabb) + return aabb + + def prepare_latents( + self, batch_size, latent_shape, dtype, device, generator, latents=None + ): + # prepare latents for different parts + shape = (batch_size, *latent_shape) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but" + f" requested an effective batch size of {batch_size}. Make sure the" + " batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * getattr(self.scheduler, "init_noise_sigma", 1.0) + return latents + + @synchronize_timer("Encode cond") + def encode_cond( + self, + part_surface_inbbox, + object_surface, + do_classifier_free_guidance, + ): + bsz = object_surface.shape[0] + cond = self.conditioner(part_surface_inbbox, object_surface) + + if do_classifier_free_guidance: + # TODO: do_classifier_free_guidance, un_cond + un_cond = {k: torch.zeros_like(v) for k, v in cond.items()} + + def cat_recursive(a, b): + if isinstance(a, torch.Tensor): + return torch.cat([a, b], dim=0).to(self.dtype) + out = {} + for k in a.keys(): + out[k] = cat_recursive(a[k], b[k]) + return out + + cond = cat_recursive(cond, un_cond) + return cond + + def normalize_mesh(self, mesh): + vertices = mesh.vertices + min_xyz = np.min(vertices, axis=0) + max_xyz = np.max(vertices, axis=0) + center = (min_xyz + max_xyz) / 2.0 + # scale = np.max(np.linalg.norm(vertices - center, axis=1)) + scale = np.max(max_xyz - min_xyz) / 2 / 0.8 + vertices = (vertices - center) / scale + mesh.vertices = vertices + return mesh, center, scale + + def check_inputs( + self, + obj_surface=None, + obj_surface_raw=None, + mesh_path=None, + mesh=None, + aabb=None, + part_surface_inbbox=None, + seed=42, + ): + """ + Check the inputs of the pipeline. 
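+        A geometry source is required: either a precomputed `obj_surface` (together
+        with `aabb` and `part_surface_inbbox`), or one of `mesh_path` / `mesh` /
+        `obj_surface_raw`, from which surface points are sampled and, when `aabb` is
+        not given, part bounding boxes are predicted.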
+ Args: + obj_surface (`torch.Tensor`): [B, N, 3+3+1] + mesh_path (`str`): path to the mesh file + mesh (`trimesh.Trimesh`): mesh object + aabb (`torch.Tensor`): [B, K, 2, 3] + part_surface_inbbox (`torch.Tensor`): [B, K,N, 3+3+1] + """ + if obj_surface is None: + if mesh_path is None and (mesh is None and obj_surface_raw is None): + raise ValueError( + "obj_surface or mesh_path/mesh/obj_surface_raw must be provided." + ) + elif aabb is None or part_surface_inbbox is None: + raise ValueError( + "aabb and part_surface_inbbox must be provided if obj_surface is" + " provided." + ) + else: + assert aabb.shape[0] == part_surface_inbbox.shape[0], "Batch size mismatch." + center = np.zeros(3) + scale = 1.0 + # 1. Load object surface and sample + if obj_surface is None: + if obj_surface_raw is None: + if mesh is not None: + obj_surface_raw = SampleMesh( + mesh.vertices, mesh.faces, -1, seed=seed + ) + elif mesh_path is not None: + mesh = trimesh.load(mesh_path, force="mesh") + mesh, center, scale = self.normalize_mesh(mesh) + print(f"Normalize mesh: {center}, {scale}") + obj_surface_raw = SampleMesh( + mesh.vertices, mesh.faces, -1, seed=seed + ) + else: + raise ValueError("obj_surface or mesh_path/mesh must be provided.") + rng = np.random.default_rng(seed=seed) + obj_surface, _ = load_surface_points( + rng, + obj_surface_raw["random_surface"], + obj_surface_raw["sharp_surface"], + pc_size=81920, + pc_sharpedge_size=0, + return_sharpedge_label=True, + return_normal=True, + ) + obj_surface = obj_surface.unsqueeze(0) + # 2. load aabb + if aabb is None: + aabb = self.predict_bbox(mesh, seed=seed) + print(f"Get bbox from bbox_predictor: {aabb.shape}") + else: + if isinstance(aabb, np.ndarray): + aabb = torch.from_numpy(aabb) + # normalize aabb by mesh scale and center + aabb = aabb.float() + aabb = (aabb - torch.from_numpy(center).float()) / scale + + # 3. 
load part surface in bbox + if part_surface_inbbox is None: + part_surface_inbbox, valid_parts_mask = sample_bbox_points_from_trimesh( + mesh, aabb, num_points=81920, seed=seed + ) + aabb = aabb[valid_parts_mask] + aabb = aabb.unsqueeze(0) + part_surface_inbbox = part_surface_inbbox.unsqueeze(0) + return ( + obj_surface, + aabb, + part_surface_inbbox, + mesh, + center, + scale, + ) + + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + def _export( + self, + latents, + output_type="trimesh", + box_v=1.01, + mc_level=0.0, + num_chunks=20000, + octree_resolution=256, + mc_algo="mc", + enable_pbar=True, + **kwargs, + ): + if not output_type == "latent": + latents = 1.0 / self.vae.scale_factor * latents + latents = self.vae(latents) + outputs = self.vae.latent2mesh_2( + # outputs = self.vae.latents2mesh( + latents, + bounds=box_v, + mc_level=mc_level, + octree_depth=8, + num_chunks=num_chunks, + octree_resolution=octree_resolution, + mc_mode=mc_algo, + # enable_pbar=enable_pbar, + **kwargs, + ) + else: + outputs = latents + + if output_type == "trimesh": + outputs = export_to_trimesh(outputs) + + return outputs + + @torch.no_grad() + @torch.autocast("cuda", dtype=torch.bfloat16) + def __call__( + self, + obj_surface=None, + obj_surface_raw=None, + mesh_path=None, + mesh=None, + aabb=None, + part_surface_inbbox=None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + eta: float = 0.0, + # guidance_scale: float = 7.5, + guidance_scale: float = -1.0, + dual_guidance_scale: float = 10.5, + dual_guidance: bool = True, + generator=None, + seed=42, + # marching cubes + box_v=1.01, + octree_resolution=512, + mc_level=-1 / 512, + num_chunks=400000, + mc_algo="mc", + output_type: Optional[str] = "trimesh", + enable_pbar=True, + **kwargs, + ): + """ + Args: + obj_surface (`torch.Tensor`): [B, N, 3+3+1] + aabb (`torch.Tensor`): [B, K, 2, 3] + part_surface_inbbox (`torch.Tensor`): [B, K,N, 3+3+1] + Returns: + `trimesh.Scene` : single object composed of multiple parts + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + do_classifier_free_guidance = guidance_scale >= 0 and not ( + hasattr(self.model, "guidance_embed") and self.model.guidance_embed is True + ) + # 1. 
Check inputs and predict bbox if not provided + device = self.device + dtype = self.dtype + obj_surface, aabb, part_surface_inbbox, mesh, center, scale = self.check_inputs( + obj_surface, + obj_surface_raw, + mesh_path, + mesh, + aabb, + part_surface_inbbox, + seed=seed, + ) + if self.verbose: + # return gt mesh with bbox + mesh_bbox = trimesh.Scene() + if mesh is not None: + mesh_bbox.add_geometry(mesh) + else: + mesh = trimesh.points.PointCloud( + obj_surface[0, :, :3].float().cpu().numpy() + ) + mesh_bbox.add_geometry(mesh) + for bbox in aabb[0]: + box = trimesh.path.creation.box_outline() + box.vertices *= (bbox[1] - bbox[0]).float().cpu().numpy() + box.vertices += (bbox[0] + bbox[1]).float().cpu().numpy() / 2 + mesh_bbox.add_geometry(box) + # Convert to device and dtype + obj_surface = obj_surface.to(device=device, dtype=dtype) + aabb = aabb.to(device=device, dtype=dtype) + part_surface_inbbox = part_surface_inbbox.to(device=device, dtype=dtype) + batch_size, num_parts, N, dim = part_surface_inbbox.shape + # TODO: support batch size > 1 + assert batch_size == 1, "Batch size > 1 is not supported yet." + # 2. Prepare latent variables + # TODO:allocate tokens for each parts + num_tokens = torch.tensor( + [self.allocate_tokens(x, self.vae.latent_shape[0]) for x in aabb], + device=device, + ) + latent_shape = self.vae.latent_shape + latents = self.prepare_latents( + num_parts, latent_shape, dtype, device, generator + ) + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 3. condition + cond = self.encode_cond( + part_surface_inbbox.reshape(batch_size * num_parts, N, dim), + obj_surface.expand(batch_size * num_parts, -1, -1), + do_classifier_free_guidance, + ) + # 4. guidance_cond for controling sampling + guidance_cond = None + if getattr(self.model, "guidance_cond_proj_dim", None) is not None: + logger.info("Using lcm guidance scale") + guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size) + guidance_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 5. Prepare timesteps + # NOTE: this is slightly different from common usage, we start from 0. + sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + ) + + torch.cuda.empty_cache() + + # 6. 
Denoising loop + with synchronize_timer("Diffusion Sampling"): + for i, t in enumerate( + tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:") + ): + # expand the latents if we are doing classifier free guidance + if do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + aabb = torch.repeat_interleave(aabb, 2, dim=0) + else: + latent_model_input = latents + + # NOTE: we assume model get timesteps ranged from 0 to 1 + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = timestep / self.scheduler.config.num_train_timesteps + noise_pred = self.model( + latent_model_input, + timestep, + cond, + aabb=aabb, + num_tokens=num_tokens, + guidance_cond=guidance_cond, + ) + + if do_classifier_free_guidance: + noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + outputs = self.scheduler.step(noise_pred, t, latents) + latents = outputs.prev_sample + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, outputs) + + # latents2mesh + # part_latents = torch.split(latents, num_tokens[0].tolist(), dim=1) + out = trimesh.Scene() + for i, part_latent in enumerate(latents): + try: + part_mesh = self._export( + latents=part_latent.unsqueeze(0), + output_type=output_type, + box_v=box_v, + mc_level=mc_level, + num_chunks=num_chunks, + octree_resolution=octree_resolution, + mc_algo=mc_algo, + enable_pbar=enable_pbar, + )[0] + out.add_geometry(part_mesh) + random_color = np.random.randint(0, 255, size=3) + part_mesh.visual.face_colors = random_color + except Exception as e: + logger.error(f"Failed to export part {i} with error {e}") + print(f"Denormalize mesh: {center}, {scale}") + for key in out.geometry.keys(): + _v = out.geometry[key].vertices + _v = _v * scale + center + out.geometry[key].vertices = _v + + if self.verbose: + explode_object = explode_mesh(copy.deepcopy(out), explosion_scale=0.2) + # add bbox into out + out_bbox = trimesh.Scene() + out_bbox.add_geometry(out) + for bbox in aabb[0]: + box = trimesh.path.creation.box_outline() + box.vertices *= (bbox[1] - bbox[0]).float().cpu().numpy() + box.vertices += (bbox[0] + bbox[1]).float().cpu().numpy() / 2 + box.vertices = box.vertices * scale + center + out_bbox.add_geometry(box) + return out, (out_bbox, mesh_bbox, explode_object) + else: + # return only the generated mesh + return out, None diff --git a/XPart/partgen/utils/mesh_utils.py b/XPart/partgen/utils/mesh_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..315f122c822607aef3117259c0eb2cf2b42d4140 --- /dev/null +++ b/XPart/partgen/utils/mesh_utils.py @@ -0,0 +1,794 @@ +import numpy as np +import trimesh +import torch +import torch.nn.functional as F +from skimage import measure +from typing import Callable, Tuple, List, Union +from torch import nn +from tqdm import tqdm +from einops import repeat +import traceback +import pymeshlab +import tempfile + + +def random_sample_pointcloud(mesh: trimesh.Trimesh, num=30000, seed=42): + # points, face_idx = mesh.sample(num, return_index=True) + points, face_idx = trimesh.sample.sample_surface(mesh, num, seed=seed) + normals = mesh.face_normals[face_idx] + rng = np.random.default_rng(seed=seed) + index = rng.choice(num, num, replace=False) + return points[index], normals[index] + + +def sharp_sample_pointcloud(mesh: trimesh.Trimesh, 
num=16384): + V = mesh.vertices + N = mesh.face_normals + VN = mesh.vertex_normals + F = mesh.faces + VN2 = np.ones(V.shape[0]) + for i in range(3): + dot = np.stack((VN2[F[:, i]], np.sum(VN[F[:, i]] * N, axis=-1)), axis=-1) + VN2[F[:, i]] = np.min(dot, axis=-1) + + sharp_mask = VN2 < 0.985 + # collect edge + edge_a = np.concatenate((F[:, 0], F[:, 1], F[:, 2])) + edge_b = np.concatenate((F[:, 1], F[:, 2], F[:, 0])) + sharp_edge = sharp_mask[edge_a] * sharp_mask[edge_b] + edge_a = edge_a[sharp_edge > 0] + edge_b = edge_b[sharp_edge > 0] + + sharp_verts_a = V[edge_a] + sharp_verts_b = V[edge_b] + sharp_verts_an = VN[edge_a] + sharp_verts_bn = VN[edge_b] + + weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1) + weights /= np.sum(weights) + + random_number = np.random.rand(num) + w = np.random.rand(num, 1) + index = np.searchsorted(weights.cumsum(), random_number) + samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index] + normals = w * sharp_verts_an[index] + (1 - w) * sharp_verts_bn[index] + return samples, normals + + +def SampleMesh(V, F, origin_num, seed=42): + """Sample a mesh to get random points and normals. + Args: + V (np.ndarray): Vertices of the mesh. + F (np.ndarray): Faces of the mesh. + origin_num (int): Number of original faces to sample from. + Returns: + surface_data (dict): Dictionary containing sampled points and normals. + The dictionary contains: + - "random_surface": Sampled points and normals from the mesh. + - "random_surface_fill": Boolean array indicating whether the points are from the fill region. + - "sharp_surface": Sampled points and normals from the sharp edges of the mesh. + """ + + mesh = trimesh.Trimesh(vertices=V, faces=F[:origin_num]) + mesh_fill = trimesh.Trimesh(vertices=V, faces=F[origin_num:]) + + area = mesh.area + area_fill = mesh_fill.area + sample_num = 499712 // 2 + num_fill = int(sample_num * (area_fill / (area + area_fill))) + num = sample_num - num_fill + # if not mesh.is_watertight: + # raise ValueError + random_surface, random_normal = random_sample_pointcloud(mesh, num=num, seed=seed) + if num_fill == 0: + random_surface_fill, random_normal_fill = np.zeros((0, 3)), np.zeros((0, 3)) + else: + random_surface_fill, random_normal_fill = random_sample_pointcloud( + mesh_fill, num=num_fill, seed=seed + ) + random_sharp_surface, sharp_normal = sharp_sample_pointcloud(mesh, num=sample_num) + + # save_surface + surface = np.concatenate((random_surface, random_normal), axis=1).astype(np.float16) + surface_fill = np.concatenate( + (random_surface_fill, random_normal_fill), axis=1 + ).astype(np.float16) + sharp_surface = np.concatenate((random_sharp_surface, sharp_normal), axis=1).astype( + np.float16 + ) + + a, b = np.ones(num), np.zeros(num_fill) + + surface_data = { + "random_surface": np.concatenate((surface, surface_fill), axis=0), + "random_surface_fill": np.concatenate((a, b)).astype(bool), + "sharp_surface": sharp_surface, + } + + return surface_data + + +def load_surface_points( + rng, + random_surface, + sharpedge_surface, + pc_size, + pc_sharpedge_size, + return_sharpedge_label=True, + return_normal=True, +): + """ + sample surface points based on pc_size and pc_sharpedge_size + Args: + rng: Random number generator + random_surface: Array of random surface points + sharpedge_surface: Array of sharp edge surface points + Returns: + surface: Array of surface points and normals + geo_points: Array of geo points + """ + + surface_normal = [] + if pc_size > 0: + ind = rng.choice(random_surface.shape[0], pc_size, 
replace=False) + random_surface = random_surface[ind] + if return_sharpedge_label: + sharpedge_label = np.zeros((pc_size, 1)) + random_surface = np.concatenate((random_surface, sharpedge_label), axis=1) + surface_normal.append(random_surface) + + if pc_sharpedge_size > 0: + ind_sharpedge = rng.choice( + sharpedge_surface.shape[0], pc_sharpedge_size, replace=False + ) + sharpedge_surface = sharpedge_surface[ind_sharpedge] + if return_sharpedge_label: + sharpedge_label = np.ones((pc_sharpedge_size, 1)) + sharpedge_surface = np.concatenate( + (sharpedge_surface, sharpedge_label), axis=1 + ) + surface_normal.append(sharpedge_surface) + + surface_normal = np.concatenate(surface_normal, axis=0) + surface_normal = torch.FloatTensor(surface_normal) + surface = surface_normal[:, 0:3] + normal = surface_normal[:, 3:6] + assert surface.shape[0] == pc_size + pc_sharpedge_size + + geo_points = 0.0 + normal = torch.nn.functional.normalize(normal, p=2, dim=1) + if return_normal: + surface = torch.cat([surface, normal], dim=-1) + if return_sharpedge_label: + surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1) + return surface, geo_points + + +def sample_bbox_points_from_trimesh(mesh, aabb, num_points, seed=42): + _faces = mesh.faces + _vertices = mesh.vertices + _faces = np.reshape(_faces, (-1)) + num_parts = aabb.shape[0] + _points = _points = torch.from_numpy(_vertices[_faces]) + _part_mask = torch.all( + (_points[None, :, :3] >= aabb[:, :1]) & (_points[None, :, :3] <= aabb[:, 1:]), + dim=-1, + ) + _part_mask = torch.any(torch.reshape(_part_mask, (num_parts, -1, 3)), dim=-1) + faces_idx_in_bbox = [torch.nonzero(x).squeeze(-1).numpy() for x in _part_mask] + # in case some parts are empty(inside surface) + valid_parts_mask = torch.tensor( + [len(x) > 0 for x in faces_idx_in_bbox], dtype=torch.bool, device=_points.device + ) + aabb = aabb[valid_parts_mask] + # print(len(faces_idx_in_bbox), len(aabb)) + faces_idx_in_bbox = [x for x in faces_idx_in_bbox if len(x) > 0] + num_valid_parts = len(faces_idx_in_bbox) + # process valid parts + mesh_in_bbox = mesh.submesh(faces_idx_in_bbox, append=False) + points, normals = [], [] + for part in mesh_in_bbox: + # part_points, face_idx = part.sample(num_points, return_index=True) + part_points, face_idx = trimesh.sample.sample_surface( + part, num_points, seed=seed + ) + part_normals = part.face_normals[face_idx] + points.append(torch.from_numpy(part_points)) + normals.append(torch.from_numpy(part_normals)) + out = torch.concat( + [torch.stack(points, dim=0), torch.stack(normals, dim=0)], dim=-1 + ) + out = torch.concat( + [ + out, + torch.zeros( + [num_valid_parts, num_points, 1], dtype=out.dtype, device=out.device + ), + ], + dim=-1, + ) # add sharp edge label + return out, valid_parts_mask + + +def sample_surface_inbbox( + rng, + object_surface_raw, + aabb, + pc_size_bbox, + return_normal=True, + return_sharpedge_label=True, +): + """ + Sample surface points within the bounding box defined by aabb. 
+ Args: + object_surface_raw: Raw surface points from the object + aabb: [K,2,3] Axis-aligned bounding box defined by min and max corners + pc_size_bbox: Number of points to sample within the bounding box + Returns: + part_surface_inbbox: Sampled surface points within the bounding box + """ + num_parts = aabb.shape[0] + object_all_surface = torch.from_numpy( + np.concatenate( + [ + object_surface_raw["random_surface"], + object_surface_raw["sharp_surface"], + ], + axis=0, + ) + ) # [N,6] + sharpedge_labels = torch.concat( + [ + torch.zeros(object_surface_raw["random_surface"].shape[0], 1), + torch.ones(object_surface_raw["sharp_surface"].shape[0], 1), + ], + dim=0, + ) + sampled_masks = torch.all( + (object_all_surface[None, :, :3] >= aabb[:, :1]) + & (object_all_surface[None, :, :3] <= aabb[:, 1:]), + dim=-1, + ) + surfaces = [] + valid_index = [] + for idx, sampled_mask in enumerate(sampled_masks): + part_surface_inbbox = object_all_surface[sampled_mask] + sharpedge_label = sharpedge_labels[sampled_mask] + # TODO: drop inside parts + if part_surface_inbbox.shape[0] == 0: + continue + try: + ind = rng.choice(part_surface_inbbox.shape[0], pc_size_bbox, replace=False) + except ValueError: + ind = np.arange(part_surface_inbbox.shape[0]) + ind = np.concatenate([ + ind, + rng.choice( + part_surface_inbbox.shape[0], + pc_size_bbox - part_surface_inbbox.shape[0], + replace=True, + ), + ]) + part_surface_inbbox = part_surface_inbbox[ind] + sharpedge_label = sharpedge_label[ind] + # point feat + surface = part_surface_inbbox[:, 0:3] + normal = part_surface_inbbox[:, 3:6] + # TODO: check normal + # normal = torch.nn.functional.normalize(normal, p=2, dim=1) + if return_normal: + surface = torch.cat([surface, normal], dim=-1) + if return_sharpedge_label: + surface = torch.cat( + [surface, sharpedge_label], + dim=-1, + ) + surfaces.append(surface) + valid_index.append(idx) + surface = torch.stack(surfaces, dim=0) + return surface, torch.tensor(valid_index) + + +def explode_mesh(mesh, explosion_scale=0.4): + + if isinstance(mesh, trimesh.Scene): + scene = mesh + elif isinstance(mesh, trimesh.Trimesh): + print("Warning: Single mesh provided, can't create exploded view") + scene = trimesh.Scene(mesh) + return scene + else: + print(f"Warning: Unexpected mesh type: {type(mesh)}") + scene = mesh + + if len(scene.geometry) <= 1: + print("Only one geometry found - nothing to explode") + return scene + + print(f"[EXPLODE_MESH] Starting mesh explosion with scale {explosion_scale}") + print(f"[EXPLODE_MESH] Processing {len(scene.geometry)} parts") + + exploded_scene = trimesh.Scene() + + part_centers = [] + geometry_names = [] + + for geometry_name, geometry in scene.geometry.items(): + if hasattr(geometry, "vertices"): + transform = scene.graph[geometry_name][0] + vertices_global = trimesh.transformations.transform_points( + geometry.vertices, transform + ) + center = np.mean(vertices_global, axis=0) + part_centers.append(center) + geometry_names.append(geometry_name) + print(f"[EXPLODE_MESH] Part {geometry_name}: center = {center}") + + if not part_centers: + print("No valid geometries with vertices found") + return scene + + part_centers = np.array(part_centers) + global_center = np.mean(part_centers, axis=0) + + print(f"[EXPLODE_MESH] Global center: {global_center}") + + for i, (geometry_name, geometry) in enumerate(scene.geometry.items()): + if hasattr(geometry, "vertices"): + if i < len(part_centers): + part_center = part_centers[i] + direction = part_center - global_center + + direction_norm = 
np.linalg.norm(direction) + if direction_norm > 1e-6: + direction = direction / direction_norm + else: + direction = np.random.randn(3) + direction = direction / np.linalg.norm(direction) + + offset = direction * explosion_scale + else: + offset = np.zeros(3) + + original_transform = scene.graph[geometry_name][0].copy() + + new_transform = original_transform.copy() + new_transform[:3, 3] = new_transform[:3, 3] + offset + + exploded_scene.add_geometry( + geometry, transform=new_transform, geom_name=geometry_name + ) + + print( + f"[EXPLODE_MESH] Part {geometry_name}: moved by" + f" {np.linalg.norm(offset):.4f}" + ) + + print("[EXPLODE_MESH] Mesh explosion complete") + return exploded_scene + + +def generate_dense_grid_points( + bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_depth: int = 7, + indexing: str = "ij", + octree_resolution: int = None, +): + length = bbox_max - bbox_min + num_cells = octree_resolution + if octree_resolution is None: + length = bbox_max - bbox_min + num_cells = np.exp2(octree_depth) + + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + xyz = xyz.reshape(-1, 3) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + + +def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float): + """ + PyTorch implementation with the dimension-handling issue fixed. + Args: + input_tensor: shape [D, D, D], torch.float16 + alpha: scalar offset value + Returns: + mask: shape [D, D, D], torch.int32 surface mask + """ + device = input_tensor.device + D = input_tensor.shape[0] + signed_val = 0.0 + + # Add the offset and handle invalid values + val = input_tensor + alpha + valid_mask = val > -9000 # assume -9000 marks invalid values + + # Improved neighbor-fetching helper (keeps dimensions consistent) + def get_neighbor(t, shift, axis): + """Shift along the given axis while keeping dimensions consistent""" + if shift == 0: + return t.clone() + + # Determine the padding axis (input is [D, D, D], corresponding to the z, y, x axes) + pad_dims = [0, 0, 0, 0, 0, 0] # format: [x_before, x_after, y_before, y_after, z_before, z_after] + + # Set the padding according to the axis + if axis == 0: # x axis (last dimension) + pad_idx = 0 if shift > 0 else 1 + pad_dims[pad_idx] = abs(shift) + elif axis == 1: # y axis (middle dimension) + pad_idx = 2 if shift > 0 else 3 + pad_dims[pad_idx] = abs(shift) + elif axis == 2: # z axis (first dimension) + pad_idx = 4 if shift > 0 else 5 + pad_dims[pad_idx] = abs(shift) + + # Apply the padding (add batch and channel dims so the input fits F.pad) + padded = F.pad( + t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode="replicate" + ) # reverse the order to match F.pad's convention + + # Build the dynamic slice indices + slice_dims = [slice(None)] * 3 # initialize to full slices + if axis == 0: # x axis (dim=2) + if shift > 0: + slice_dims[0] = slice(shift, None) + else: + slice_dims[0] = slice(None, shift) + elif axis == 1: # y axis (dim=1) + if shift > 0: + slice_dims[1] = slice(shift, None) + else: + slice_dims[1] = slice(None, shift) + elif axis == 2: # z axis (dim=0) + if shift > 0: + slice_dims[2] = slice(shift, None) + else: + slice_dims[2] = slice(None, shift) + + # Apply the slices and restore the dimensions + padded = padded.squeeze(0).squeeze(0) + sliced = padded[slice_dims] + return sliced + + # Fetch the neighbors in each direction (dimensions stay consistent) + left = get_neighbor(val, 1, axis=0) # x direction + right = get_neighbor(val, -1, axis=0) + back = get_neighbor(val, 1, axis=1) # y direction + front = get_neighbor(val, -1, axis=1) + down = get_neighbor(val, 1, axis=2) # z direction + up = get_neighbor(val, -1, axis=2) + + # Handle invalid boundary values (where keeps dimensions consistent) + def safe_where(neighbor): + return torch.where(neighbor > -9000, neighbor, val) + + left = safe_where(left) + right = safe_where(right) + back
= safe_where(back) + front = safe_where(front) + down = safe_where(down) + up = safe_where(up) + + # Compute sign consistency (cast to float32 for precision) + sign = torch.sign(val.to(torch.float32)) + neighbors_sign = torch.stack( + [ + torch.sign(left.to(torch.float32)), + torch.sign(right.to(torch.float32)), + torch.sign(back.to(torch.float32)), + torch.sign(front.to(torch.float32)), + torch.sign(down.to(torch.float32)), + torch.sign(up.to(torch.float32)), + ], + dim=0, + ) + + # Check whether all signs agree + same_sign = torch.all(neighbors_sign == sign, dim=0) + + # Build the final mask + mask = (~same_sign).to(torch.int32) + return mask * valid_mask.to(torch.int32) + + +@torch.no_grad() +def extract_geometry_fast( + geometric_func: Callable, + device: torch.device, + batch_size: int = 1, + bounds: Union[Tuple[float], List[float], float] = ( + -1.25, + -1.25, + -1.25, + 1.25, + 1.25, + 1.25, + ), + octree_depth: int = 7, + num_chunks: int = 10000, + disable: bool = True, + mc_level: float = -1 / 512, + octree_resolution: int = None, + diffdmc=None, + rotation_matrix=None, + mc_mode="mc", + dtype=torch.float16, + min_resolution: int = 95, +): + """ + + Args: + geometric_func: + device: + bounds: + octree_depth: + batch_size: + num_chunks: + disable: + + Returns: + + """ + + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + if octree_resolution is None: + octree_resolution = 2**octree_depth + + assert ( + octree_resolution >= 256 + ), "octree resolution must be at least 256 for fast inference" + + resolutions = [] + if octree_resolution < min_resolution: + resolutions.append(octree_resolution) + while octree_resolution >= min_resolution: + resolutions.append(octree_resolution) + octree_resolution = octree_resolution // 2 + resolutions.reverse() + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype) + dilate.weight = torch.nn.Parameter( + torch.ones(dilate.weight.shape, dtype=dtype, device=device) + ) + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=resolutions[0], + indexing="ij", + ) + + grid_size = np.array(grid_size) + xyz_samples = torch.FloatTensor(xyz_samples).to(device).half() + + if mc_level == -1: + print( + f"Training with soft labels, inference with sigmoid and marching cubes" + f" level 0." + ) + elif mc_level == 0: + print(f"VAE Trained with TSDF, inference with marching cubes level 0.") + else: + print( + "VAE Trained with Occupancy, inference with marching cubes level" + f" {mc_level}." + ) + batch_logits = [] + for start in tqdm( + range(0, xyz_samples.shape[0], num_chunks), + desc=f"MC Level {mc_level} Implicit Function:", + disable=disable, + leave=False, + ): + queries = xyz_samples[start : start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + logits = geometric_func(batch_queries) + if mc_level == -1: + mc_level = 0 + print( + f"Training with soft labels, inference with sigmoid and marching cubes" + f" level 0."
+ ) + logits = torch.sigmoid(logits) * 2 - 1 + batch_logits.append(logits) + + grid_logits = ( + torch.cat(batch_logits, dim=1) + .view((batch_size, grid_size[0], grid_size[1], grid_size[2])) + .half() + ) + + for octree_depth_now in resolutions[1:]: + grid_size = np.array([octree_depth_now + 1] * 3) + resolution = bbox_size / octree_depth_now + next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device) + if octree_depth_now == resolutions[-1]: + next_logits = torch.full( + next_index.shape, float("nan"), dtype=dtype, device=device + ) + else: + next_logits = torch.full( + next_index.shape, -10000.0, dtype=dtype, device=device + ) + + FN = extract_near_surface_volume_fn + curr_points = FN(grid_logits.squeeze(0), mc_level) + curr_points += grid_logits.squeeze(0).abs() < min( + 0.95, 0.95 * 128 * 4 / octree_depth_now + ) + if octree_depth_now > 510: + expand_num = 0 + else: + expand_num = 1 + for i in range(expand_num): + curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0) + (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0) + next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1 + for i in range(1): + next_index = dilate(next_index.unsqueeze(0)).squeeze(0) + nidx = torch.where(next_index > 0) + next_points = torch.stack(nidx, dim=1) + next_points = next_points * torch.tensor( + resolution, device=device + ) + torch.tensor(bbox_min, device=device) + batch_logits = [] + for start in tqdm( + range(0, next_points.shape[0], num_chunks), + desc=f"MC Level {octree_depth_now + 1} Implicit Function:", + disable=disable, + leave=False, + ): + queries = next_points[start : start + num_chunks, :] + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + logits = geometric_func(batch_queries) + if mc_level == -1: + mc_level = 0 + print( + f"Training with soft labels, inference with sigmoid and marching" + f" cubes level 0." 
+ ) + logits = torch.sigmoid(logits) * 2 - 1 + batch_logits.append(logits) + grid_logits = torch.cat(batch_logits, dim=1).half() + next_logits[nidx] = grid_logits[0] + grid_logits = next_logits.unsqueeze(0) + # s_mc = time.time() + mesh_v_f = [] + has_surface = np.zeros((batch_size,), dtype=np.bool_) + for i in range(batch_size): + try: + if mc_mode == "mc": + if len(resolutions) > 1: + mask = (next_index > 0).cpu().numpy() + grid_logits = grid_logits.cpu().numpy() + vertices, faces, normals, _ = measure.marching_cubes( + grid_logits[i], mc_level, method="lewiner", mask=mask + ) + else: + vertices, faces, normals, _ = measure.marching_cubes( + grid_logits[i].cpu().numpy(), mc_level, method="lewiner" + ) + vertices = vertices / (grid_size - 1) * bbox_size + bbox_min + # vertices[:, [0, 1]] = vertices[:, [1, 0]] + elif mc_mode == "dmc": + torch.cuda.empty_cache() + grid_logits = -grid_logits[i] + grid_logits = grid_logits.to(torch.float32).contiguous() + verts, faces = diffdmc( + grid_logits, deform=None, return_quads=False, normalize=False + ) + verts = verts * torch.tensor(resolution, device=device) + torch.tensor( + bbox_min, device=device + ) + vertices = verts.detach().cpu().numpy() + faces = faces.detach().cpu().numpy()[:, ::-1] + elif mc_mode == "odc": + # https://github.com/KAIST-Visual-AI-Group/ODC + from .occupancy_dual_contouring import occupancy_dual_contouring + import torch.nn.functional as F + + odc = occupancy_dual_contouring("cuda") + + size = grid_logits.shape[-1] + grid_logits = grid_logits.reshape(1, 1, size, size, size) + + def implicit_function(xyz): + xyz = xyz.reshape(1, -1, 1, 1, 3).float() + # print(grid_logits.dtype, xyz.dtype) + outputs = F.grid_sample(grid_logits.float(), xyz) + outputs = -outputs.reshape(-1) + return outputs + + num_cells = ( + octree_resolution + if octree_resolution is not None + else np.exp2(octree_depth) + ) + vertices, triangles = odc.extract_mesh( + implicit_function, + isolevel=mc_level, + min_coord=bbox_min, + max_coord=bbox_max, + num_grid=1024, + ) + vertices = vertices.detach().cpu().numpy() + faces = triangles.detach().cpu().numpy()[:, ::-1] + else: + raise ValueError(f"Unknown marching cubes mode: {mc_mode}") + mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) + has_surface[i] = True + + except ValueError: + traceback.print_exc() + mesh_v_f.append((None, None)) + has_surface[i] = False + + except RuntimeError: + traceback.print_exc() + mesh_v_f.append((None, None)) + has_surface[i] = False + return mesh_v_f, has_surface + + +def pymeshlab2trimesh(mesh: pymeshlab.MeshSet): + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + mesh.save_current_mesh(temp_file.name) + mesh = trimesh.load(temp_file.name) + # 检查加载的对象类型 + if isinstance(mesh, trimesh.Scene): + combined_mesh = trimesh.Trimesh() + # 如果是Scene,遍历所有的geometry并合并 + for geom in mesh.geometry.values(): + combined_mesh = trimesh.util.concatenate([combined_mesh, geom]) + mesh = combined_mesh + return mesh + + +def trimesh2pymeshlab(mesh: trimesh.Trimesh): + with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file: + if isinstance(mesh, trimesh.scene.Scene): + for idx, obj in enumerate(mesh.geometry.values()): + if idx == 0: + temp_mesh = obj + else: + temp_mesh = temp_mesh + obj + mesh = temp_mesh + mesh.export(temp_file.name) + mesh = pymeshlab.MeshSet() + mesh.load_new_mesh(temp_file.name) + return mesh + + +def remove_overlength_edge(mesh: pymeshlab.MeshSet, max_length: float): + 
mesh.apply_filter("compute_selection_by_edge_length", threshold=max_length) + mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False) + mesh.apply_filter("meshing_remove_selected_vertices_and_faces") + return mesh + + +def remove_floater(mesh: pymeshlab.MeshSet): + mesh.apply_filter( + "compute_selection_by_small_disconnected_components_per_face", nbfaceratio=0.005 + ) + mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False) + mesh.apply_filter("meshing_remove_selected_vertices_and_faces") + return mesh + + +def fix_mesh(mesh: trimesh.Trimesh): + ms = trimesh2pymeshlab(mesh) + ms = remove_overlength_edge(ms, max_length=8 / 512) + ms = remove_floater(ms) + mesh = pymeshlab2trimesh(ms) + return mesh diff --git a/XPart/partgen/utils/misc.py b/XPart/partgen/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..6c85aa18141e7f64f316048141fdb02dacb69f78 --- /dev/null +++ b/XPart/partgen/utils/misc.py @@ -0,0 +1,206 @@ +import os +import torch +import logging +import importlib +from typing import Union +from functools import wraps + +from omegaconf import OmegaConf, DictConfig, ListConfig + + +def get_logger(name): + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + return logger + + +logger = get_logger("hy3dgen.partgen") + + +class synchronize_timer: + """Synchronized timer to count the inference time of `nn.Module.forward`. + + Supports both context manager and decorator usage. + + Example as context manager: + ```python + with synchronize_timer('name') as t: + run() + ``` + + Example as decorator: + ```python + @synchronize_timer('Export to trimesh') + def export_to_trimesh(mesh_output): + pass + ``` + """ + + def __init__(self, name=None): + self.name = name + + def __enter__(self): + """Context manager entry: start timing.""" + if os.environ.get("HY3DGEN_DEBUG", "0") == "1": + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self.start.record() + return lambda: self.time + + def __exit__(self, exc_type, exc_value, exc_tb): + """Context manager exit: stop timing and log results.""" + if os.environ.get("HY3DGEN_DEBUG", "0") == "1": + self.end.record() + torch.cuda.synchronize() + self.time = self.start.elapsed_time(self.end) + if self.name is not None: + logger.info(f"{self.name} takes {self.time} ms") + + def __call__(self, func): + """Decorator: wrap the function to time its execution.""" + + @wraps(func) + def wrapper(*args, **kwargs): + with self: + result = func(*args, **kwargs) + return result + + return wrapper + + +def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]: + config_file = OmegaConf.load(config_file) + + if "base_config" in config_file.keys(): + if config_file["base_config"] == "default_base": + base_config = OmegaConf.create() + # base_config = get_default_config() + elif config_file["base_config"].endswith(".yaml"): + base_config = get_config_from_file(config_file["base_config"]) + else: + raise ValueError( + f"{config_file} must be `.yaml` file or it contains `base_config` key." 
+ ) + + config_file = {key: value for key, value in config_file.items() if key != "base_config"} + + return OmegaConf.merge(base_config, config_file) + + return config_file + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config, **kwargs): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + cls = get_obj_from_str(config["target"]) + + if config.get("from_pretrained", None): + return cls.from_pretrained( + config["from_pretrained"], + use_safetensors=config.get("use_safetensors", False), + variant=config.get("variant", "fp16"), + ) + + params = config.get("params", dict()) + # params.update(kwargs) + # instance = cls(**params) + kwargs.update(params) + instance = cls(**kwargs) + + return instance + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def instantiate_non_trainable_model(config): + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + + return model + + +def smart_load_model( + model_path, +): + original_model_path = model_path + # try local path + base_dir = os.environ.get("HY3DGEN_MODELS", "~/.cache/xpart") + model_fld = os.path.expanduser(os.path.join(base_dir, model_path)) + logger.info(f"Try to load model from local path: {model_path}") + if not os.path.exists(model_path): + logger.info("Model path does not exist, trying to download from Hugging Face") + try: + from huggingface_hub import snapshot_download + + # only download the specified sub-directory + path = snapshot_download( + repo_id=original_model_path, + # allow_patterns=[f"{subfolder}/*"], # key change: pattern-match the subfolder + local_dir=model_fld, + ) + model_path = path # os.path.join(path, subfolder) # keep the path-joining logic unchanged + except ImportError: + logger.warning( + "You need to install HuggingFace Hub to load models from the hub."
+ ) + raise RuntimeError(f"Model path {model_path} not found") + except Exception as e: + raise e + + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model path {original_model_path} not found") + + return model_path + + +def init_from_ckpt(model, ckpt, prefix="model", ignore_keys=()): + if "state_dict" not in ckpt: + # deepspeed ckpt + state_dict = {} + ckpt = ckpt["module"] if "module" in ckpt else ckpt + for k in ckpt.keys(): + new_k = k.replace("_forward_module.", "") + state_dict[new_k] = ckpt[k] + else: + state_dict = ckpt["state_dict"] + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if ik in k: + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + state_dict = { + k.replace(prefix + ".", ""): v + for k, v in state_dict.items() + if k.startswith(prefix) + } + missing, unexpected = model.load_state_dict(state_dict, strict=False) + print(f"Restored with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") diff --git a/app.py b/app.py index 5cc265e4965b21fd51f9accf7aafbebd89067761..19186df7b40686d451ca053e74d38e104d17a2a1 100644 --- a/app.py +++ b/app.py @@ -1,14 +1,191 @@ import gradio as gr -import spaces +import os +import sys +import argparse +import numpy as np +import trimesh +from pathlib import Path import torch +import pytorch_lightning as pl +import spaces + +sys.path.append('P3-SAM') +from demo.auto_mask import AutoMask +from demo.auto_mask_no_postprocess import AutoMask as AutoMaskNoPostProcess +sys.path.append('XPart') +from partgen.partformer_pipeline import PartFormerPipeline +from partgen.utils.misc import get_config_from_file + +automask = AutoMask() +automask_no_postprocess = AutoMaskNoPostProcess(automask_instance=automask) + +def _load_pipeline(): + pl.seed_everything(2026, workers=True) + cfg_path = str(Path(__file__).parent / "XPart/partgen/config" / "infer.yaml") + config = get_config_from_file(cfg_path) + assert hasattr(config, "ckpt") or hasattr( + config, "ckpt_path" + ), "ckpt or ckpt_path must be specified in config" + pipeline = PartFormerPipeline.from_pretrained( + config=config, + verbose=True, + ignore_keys=config.get("ignore_keys", []), + ) + + device = "cuda" + pipeline.to(device=device, dtype=torch.float32) + return pipeline + +_PIPELINE = _load_pipeline() + +output_path = 'P3-SAM/results/gradio' +os.makedirs(output_path, exist_ok=True) -zero = torch.Tensor([0]).cuda() -print(zero.device) # <-- 'cpu' 🤔 +@spaces.GPU +def segment(mesh_path, connectivity=True, postprocess=True, postprocess_threshold=0.95, seed=42, gr_state=None): + if mesh_path is None: + gr.Warning("No Input Mesh") + gr_state[0] = (None, None) + return None, None + mesh = trimesh.load(mesh_path, force='mesh', process=False) + if connectivity: + aabb, face_ids, mesh = automask.predict_aabb(mesh, seed=seed, is_parallel=False, post_process=postprocess, threshold=postprocess_threshold) + else: + aabb, face_ids, mesh = automask_no_postprocess.predict_aabb(mesh, seed=seed, is_parallel=False, post_process=False) + color_map = {} + unique_ids = np.unique(face_ids) + for i in unique_ids: + if i == -1: + continue + part_color = np.random.rand(3) * 255 + color_map[i] = part_color + face_colors = [] + for i in face_ids: + if i == -1: + face_colors.append([0, 0, 0]) + else: + face_colors.append(color_map[i]) + face_colors = np.array(face_colors).astype(np.uint8) + mesh_save = mesh.copy() + mesh_save.visual.face_colors = 
face_colors + + file_path = os.path.join(output_path, 'segment_mesh.glb') + mesh_save.export(file_path) + face_id_save_path = os.path.join(output_path, 'face_id.npy') + np.save(face_id_save_path, face_ids) + gr_state[0] = (aabb, mesh_path) + return file_path, face_id_save_path @spaces.GPU -def greet(n): - print(zero.device) # <-- 'cuda:0' 🤗 - return f"Hello {zero + n} Tensor" +def generate(mesh_path, seed=42, gr_state=None): + if mesh_path is None: + gr.Warning("No Input Mesh") + gr_state[0] = (None, None) + return None, None, None + if gr_state[0][0] is None or mesh_path != gr_state[0][1]: + gr.Warning("Please segment the mesh first") + return None, None, None + + aabb = gr_state[0][0] + # Ensure deterministic behavior per request + try: + pl.seed_everything(int(seed), workers=True) + except Exception: + pl.seed_everything(2026, workers=True) + additional_params = {"output_type": "trimesh"} + obj_mesh, (out_bbox, mesh_gt_bbox, explode_object) = _PIPELINE( + mesh_path=mesh_path, + aabb=aabb, + octree_resolution=512, + **additional_params, + ) + # Export all results to temporary files for Gradio Model3D + obj_path = os.path.join(output_path, 'obj_mesh.glb') + out_bbox_path = os.path.join(output_path, 'out_bbox.glb') + explode_path = os.path.join(output_path, 'explode.glb') + obj_mesh.export(obj_path) + out_bbox.export(out_bbox_path) + explode_object.export(explode_path) + return obj_path, out_bbox_path, explode_path + +with gr.Blocks() as demo: + gr.Markdown( +''' +# ☯️ Hunyuan3D Part:P3-SAM&XPart +This demo allows you to generate parts given a 3D model using Hunyuan3D-Part. +First segment the 3D model using P3-SAM and then generate parts using XPart. +''' + ) + with gr.Row(): + with gr.Column(): + # P3-SAM + gr.Markdown( +''' +## P3-SAM: Native 3D Part Segmentation + +[Paper](https://arxiv.org/abs/2509.06784) | [Project Page](https://murcherful.github.io/P3-SAM/) | [Code](https://github.com/Tencent-Hunyuan/Hunyuan3D-Part/P3-SAM/) | [Model](https://huggingface.co/tencent/Hunyuan3D-Part) + +This is a demo of P3-SAM, a native 3D part segmentation method that can segment a mesh into different parts. +Input a mesh and push the "Segment" button to get the segmentation results. +''' + ) + p3sam_button = gr.Button("Segment") + p3sam_input = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Input Mesh") + p3sam_output = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Segmentation Result") + p3sam_face_id_output = gr.File(label='Face ID') + p3sam_conectivity = gr.Checkbox(value=True, label="Connectivity") + p3sam_postprocess = gr.Checkbox(value=True, label="Post-processing") + p3sam_postprocess_threshold = gr.Number(value=0.95, label="Post-processing Threshold") + p3sam_seed = gr.Number(value=42, label="Random Seed") + gr.Markdown( +''' +P3-SAM will clean your mesh. To get face-aligned labels, you can download the "Segmentation Result" and "Face ID". +You can also use the "Connectivity" and "Post-processing" options to control the behavior of the algorithm. +The "Post-processing" will merge the small parts according to the threshold. The smaller the threshold, the more parts will be merged. 
+''' + ) + gr.Examples(examples=[ + 'P3-SAM/demo/assets/1.glb', + 'P3-SAM/demo/assets/2.glb', + 'P3-SAM/demo/assets/4.glb', + 'XPart/data/000.glb', + 'XPart/data/001.glb', + 'XPart/data/002.glb', + 'XPart/data/003.glb', + 'XPart/data/004.glb', + ], + inputs = [p3sam_input], + example_labels=[ + 'Female Warrior', + 'Suspended Island', + 'Beetle Car', + 'Koi Fish', + 'Motorcycle', + 'Gundam', + 'Computer Desk', + 'Coffee Machine' + ] + ) + with gr.Column(): + # XPart + gr.Markdown( +''' +## XPart: High-fidelity and Structure-coherent Shape Decomposition + +[Paper](https://arxiv.org/abs/2509.08643) | [Project Page](https://yanxinhao.github.io/Projects/X-Part/) | [Code](https://github.com/Tencent-Hunyuan/Hunyuan3D-Part/XPart/) | [Model](https://huggingface.co/tencent/Hunyuan3D-Part) + +This is a demo of XPart, a high-fidelity and structure-coherent shape-decomposition method that can generate parts from a 3D model. +Input a mesh, segment it using P3-SAM on the left, and push the "Generate" button to get the generated parts. +''' ) + xpart_button = gr.Button("Generate") + xpart_output = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Generated Parts") + xpart_output_bbox = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Generated Parts with BBox") + xpart_output_exploded = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Exploded Object") + xpart_seed = gr.Number(value=42, label="Random Seed") + gr_state = gr.State(value=[(None, None)]) + p3sam_button.click(segment, inputs=[p3sam_input, p3sam_conectivity, p3sam_postprocess, p3sam_postprocess_threshold, p3sam_seed, gr_state], outputs=[p3sam_output, p3sam_face_id_output]) + xpart_button.click(generate, inputs=[p3sam_input, xpart_seed, gr_state], outputs=[xpart_output, xpart_output_bbox, xpart_output_exploded]) + -demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text()) -demo.launch() +if __name__ == '__main__': + demo.launch(server_name='0.0.0.0', server_port=8080) diff --git a/app_test.py b/app_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc265e4965b21fd51f9accf7aafbebd89067761 --- /dev/null +++ b/app_test.py @@ -0,0 +1,14 @@ +import gradio as gr +import spaces +import torch + +zero = torch.Tensor([0]).cuda() +print(zero.device) # <-- 'cpu' 🤔 + +@spaces.GPU +def greet(n): + print(zero.device) # <-- 'cuda:0' 🤗 + return f"Hello {zero + n} Tensor" + +demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text()) +demo.launch() diff --git a/requirements.txt b/requirements.txt index 17b616310eced7b4f8c29ca737e17a5646f86e95..5463b23d2260767e1b34e9f1efd2654df9919848 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ pandas==2.2.2 # 3D Mesh Processing trimesh==4.4.7 +pymeshlab==2022.2.post3 # Configuration Management omegaconf==2.3.0 @@ -34,4 +35,10 @@ onnxruntime==1.16.3 torchmetrics==1.6.0 timm +numba +fpsample +# sonata +spconv-cu126 +torch-scatter -f https://data.pyg.org/whl/torch-2.8.0+cu126.html +git+https://github.com/Dao-AILab/flash-attention.git@master \ No newline at end of file
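For reference, a minimal sketch of how the two-stage segment-then-generate flow wired up in app.py can be driven outside the Gradio UI. It assumes the `automask` instance and `_PIPELINE` object constructed in app.py; the asset path, seed, post-processing threshold, and export filenames are illustrative defaults taken from the demo, not a fixed API.

import trimesh

mesh_path = "P3-SAM/demo/assets/1.glb"  # any of the bundled example assets
mesh = trimesh.load(mesh_path, force="mesh", process=False)

# Stage 1: P3-SAM proposes per-part axis-aligned bounding boxes (aabb: [K, 2, 3])
# and per-face part labels (face_ids); post-processing merges small parts.
aabb, face_ids, mesh = automask.predict_aabb(
    mesh, seed=42, is_parallel=False, post_process=True, threshold=0.95
)

# Stage 2: XPart generates one mesh per proposed box, conditioned on the whole shape.
obj_mesh, (out_bbox, mesh_gt_bbox, explode_object) = _PIPELINE(
    mesh_path=mesh_path,
    aabb=aabb,
    octree_resolution=512,
    output_type="trimesh",
)
obj_mesh.export("parts.glb")            # assembled part meshes
explode_object.export("exploded.glb")   # exploded view for inspection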