# ------------------------------------------------------------------------------ # Copyright 2025 2toINF (https://github.com/2toINF) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ------------------------------------------------------------------------------ from __future__ import annotations from typing import Iterable, Tuple, Dict, Type import torch import torch.nn as nn # ============================================================================= # Registry # ============================================================================= ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {} def register_action(name: str): """Decorator for registering a new action space.""" def _wrap(cls): key = name.lower() if key in ACTION_REGISTRY: raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}") ACTION_REGISTRY[key] = cls cls.name = key return cls return _wrap def build_action_space(name: str, **kwargs) -> "BaseActionSpace": """Instantiate a registered action space by name.""" key = name.lower() if key not in ACTION_REGISTRY: raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}") return ACTION_REGISTRY[key](**kwargs) # ============================================================================= # Base class # ============================================================================= class BaseActionSpace(nn.Module): """ Abstract base class for all action-space definitions. Each subclass defines: - `dim_action`: dimension of the action vector. - `gripper_idx`: indices of gripper channels. - `compute_loss(pred, target)`: supervised loss for this space. - `preprocess(proprio, action, mode)`: pre-step modifications. - `postprocess(action)`: post-step corrections (e.g. apply sigmoid). """ name: str = "base" dim_action: int = 0 gripper_idx: Tuple[int, ...] = () def __init__(self): super().__init__() # --------------------------------------------------------------------- # Core supervised loss # --------------------------------------------------------------------- def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: raise NotImplementedError def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: """Alias for compute_loss.""" return self.compute_loss(pred, target) # --------------------------------------------------------------------- # Space-level hooks # --------------------------------------------------------------------- def preprocess( self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train", ) -> Tuple[torch.Tensor, torch.Tensor]: """Default: return unchanged.""" return proprio, action def postprocess(self, action: torch.Tensor) -> torch.Tensor: """Default: return unchanged.""" return action # ============================================================================= # Utilities # ============================================================================= def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None: bad = [i for i in idx if i < 0 or i >= D] if bad: raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}") # ============================================================================= # Implementations # ============================================================================= @register_action("ee6d") class EE6DActionSpace(BaseActionSpace): """End-effector layout with xyz, 6D rotation, and gripper channels.""" dim_action = 20 gripper_idx = (9, 19) GRIPPER_SCALE = 1.0 XYZ_SCALE = 500.0 ROT_SCALE = 10.0 POS_IDX_1 = (0, 1, 2) POS_IDX_2 = (10, 11, 12) ROT_IDX_1 = (3, 4, 5, 6, 7, 8) ROT_IDX_2 = (13, 14, 15, 16, 17, 18) def __init__(self): super().__init__() self.mse = nn.MSELoss() self.bce = nn.BCEWithLogitsLoss() def compute_loss(self, pred, target): assert pred.shape == target.shape, "pred/target shapes must match" B, T, D = pred.shape _ensure_indices_valid(D, self.gripper_idx, "gripper_idx") # Gripper BCE g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE # XYZ position pos_loss = ( self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) ) * self.XYZ_SCALE # Rotation 6D rot_loss = ( self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) ) * self.ROT_SCALE return { "position_loss": pos_loss, "rotate6D_loss": rot_loss, "gripper_loss": gripper_loss, } def preprocess(self, proprio, action, mode="train"): """Zero-out gripper channels in proprio/action.""" proprio_m = proprio.clone() action_m = action.clone() proprio_m[..., self.gripper_idx] = 0.0 action_m[..., self.gripper_idx] = 0.0 return proprio_m, action_m def postprocess(self, action: torch.Tensor) -> torch.Tensor: """Apply sigmoid to gripper logits.""" if action.size(-1) > max(self.gripper_idx): action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx]) return action @register_action("joint") class JointActionSpace(BaseActionSpace): """Joint-space layout with joints + gripper only.""" dim_action = 14 gripper_idx = (6, 13) GRIPPER_SCALE = 0.1 JOINTS_SCALE = 1.0 def __init__(self): super().__init__() self.mse = nn.MSELoss() self.bce = nn.BCEWithLogitsLoss() def compute_loss(self, pred, target): assert pred.shape == target.shape B, T, D = pred.shape _ensure_indices_valid(D, self.gripper_idx, "gripper_idx") g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx)) joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE return { "joints_loss": joints_loss, "gripper_loss": gripper_loss, } def preprocess(self, proprio, action, mode="train"): """Zero-out gripper channels in proprio/action.""" proprio_m = proprio.clone() action_m = action.clone() proprio_m[..., self.gripper_idx] = 0.0 action_m[..., self.gripper_idx] = 0.0 return proprio_m, action_m def postprocess(self, action: torch.Tensor) -> torch.Tensor: """Apply sigmoid to gripper logits.""" if action.size(-1) > max(self.gripper_idx): action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx]) return action @register_action("agibot_ee6d") class AGIBOTEE6DActionSpace(BaseActionSpace): """AGI-bot variant of EE6DActionSpace using MSE for all components.""" dim_action = 20 gripper_idx = (9, 19) GRIPPER_SCALE = 10.0 XYZ_SCALE = 500.0 ROT_SCALE = 10.0 POS_IDX_1 = (0, 1, 2) POS_IDX_2 = (10, 11, 12) ROT_IDX_1 = (3, 4, 5, 6, 7, 8) ROT_IDX_2 = (13, 14, 15, 16, 17, 18) def __init__(self): super().__init__() self.mse = nn.MSELoss() def compute_loss(self, pred, target): assert pred.shape == target.shape B, T, D = pred.shape _ensure_indices_valid(D, self.gripper_idx, "gripper_idx") gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE pos_loss = ( self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) ) * self.XYZ_SCALE rot_loss = ( self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) ) * self.ROT_SCALE return { "position_loss": pos_loss, "rotate6D_loss": rot_loss, "gripper_loss": gripper_loss, } def preprocess(self, proprio, action, mode="train"): """No preprocessing applied in AGIBOT variant.""" return proprio, action def postprocess(self, action: torch.Tensor) -> torch.Tensor: """AGIBOT does not postprocess.""" return action # ============================================================================= # Exports # ============================================================================= __all__ = [ "BaseActionSpace", "build_action_space", "register_action", "EE6DActionSpace", "JointActionSpace", "AGIBOTEE6DActionSpace", "ACTION_REGISTRY", ]