|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
from typing import Iterable, Tuple, Dict, Type |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {} |
|
|
|
|
|
|
|
|
def register_action(name: str): |
|
|
"""Decorator for registering a new action space.""" |
|
|
def _wrap(cls): |
|
|
key = name.lower() |
|
|
if key in ACTION_REGISTRY: |
|
|
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}") |
|
|
ACTION_REGISTRY[key] = cls |
|
|
cls.name = key |
|
|
return cls |
|
|
return _wrap |
|
|
|
|
|
|
|
|
def build_action_space(name: str, **kwargs) -> "BaseActionSpace": |
|
|
"""Instantiate a registered action space by name.""" |
|
|
key = name.lower() |
|
|
if key not in ACTION_REGISTRY: |
|
|
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}") |
|
|
return ACTION_REGISTRY[key](**kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseActionSpace(nn.Module): |
|
|
""" |
|
|
Abstract base class for all action-space definitions. |
|
|
|
|
|
Each subclass defines: |
|
|
- `dim_action`: dimension of the action vector. |
|
|
- `gripper_idx`: indices of gripper channels. |
|
|
- `compute_loss(pred, target)`: supervised loss for this space. |
|
|
- `preprocess(proprio, action, mode)`: pre-step modifications. |
|
|
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid). |
|
|
""" |
|
|
|
|
|
name: str = "base" |
|
|
dim_action: int = 0 |
|
|
gripper_idx: Tuple[int, ...] = () |
|
|
|
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: |
|
|
raise NotImplementedError |
|
|
|
|
|
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: |
|
|
"""Alias for compute_loss.""" |
|
|
return self.compute_loss(pred, target) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess( |
|
|
self, |
|
|
proprio: torch.Tensor, |
|
|
action: torch.Tensor, |
|
|
mode: str = "train", |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Default: return unchanged.""" |
|
|
return proprio, action |
|
|
|
|
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor: |
|
|
"""Default: return unchanged.""" |
|
|
return action |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_indices_valid(D: int, idx: Iterable[int], name: str) -> None: |
|
|
bad = [i for i in idx if i < 0 or i >= D] |
|
|
if bad: |
|
|
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim D={D}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_action("ee6d") |
|
|
class EE6DActionSpace(BaseActionSpace): |
|
|
"""End-effector layout with xyz, 6D rotation, and gripper channels.""" |
|
|
|
|
|
dim_action = 20 |
|
|
gripper_idx = (9, 19) |
|
|
GRIPPER_SCALE = 1.0 |
|
|
XYZ_SCALE = 500.0 |
|
|
ROT_SCALE = 10.0 |
|
|
|
|
|
POS_IDX_1 = (0, 1, 2) |
|
|
POS_IDX_2 = (10, 11, 12) |
|
|
ROT_IDX_1 = (3, 4, 5, 6, 7, 8) |
|
|
ROT_IDX_2 = (13, 14, 15, 16, 17, 18) |
|
|
|
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.mse = nn.MSELoss() |
|
|
self.bce = nn.BCEWithLogitsLoss() |
|
|
|
|
|
def compute_loss(self, pred, target): |
|
|
assert pred.shape == target.shape, "pred/target shapes must match" |
|
|
B, T, D = pred.shape |
|
|
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx") |
|
|
|
|
|
|
|
|
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] |
|
|
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE |
|
|
|
|
|
|
|
|
pos_loss = ( |
|
|
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + |
|
|
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) |
|
|
) * self.XYZ_SCALE |
|
|
|
|
|
|
|
|
rot_loss = ( |
|
|
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + |
|
|
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) |
|
|
) * self.ROT_SCALE |
|
|
|
|
|
return { |
|
|
"position_loss": pos_loss, |
|
|
"rotate6D_loss": rot_loss, |
|
|
"gripper_loss": gripper_loss, |
|
|
} |
|
|
|
|
|
def preprocess(self, proprio, action, mode="train"): |
|
|
"""Zero-out gripper channels in proprio/action.""" |
|
|
proprio_m = proprio.clone() |
|
|
action_m = action.clone() |
|
|
proprio_m[..., self.gripper_idx] = 0.0 |
|
|
action_m[..., self.gripper_idx] = 0.0 |
|
|
return proprio_m, action_m |
|
|
|
|
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor: |
|
|
"""Apply sigmoid to gripper logits.""" |
|
|
if action.size(-1) > max(self.gripper_idx): |
|
|
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx]) |
|
|
return action |
|
|
|
|
|
|
|
|
@register_action("joint") |
|
|
class JointActionSpace(BaseActionSpace): |
|
|
"""Joint-space layout with joints + gripper only.""" |
|
|
|
|
|
dim_action = 14 |
|
|
gripper_idx = (6, 13) |
|
|
GRIPPER_SCALE = 0.1 |
|
|
JOINTS_SCALE = 1.0 |
|
|
|
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.mse = nn.MSELoss() |
|
|
self.bce = nn.BCEWithLogitsLoss() |
|
|
|
|
|
def compute_loss(self, pred, target): |
|
|
assert pred.shape == target.shape |
|
|
B, T, D = pred.shape |
|
|
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx") |
|
|
|
|
|
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx] |
|
|
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE |
|
|
|
|
|
joints_idx = tuple(i for i in range(D) if i not in set(self.gripper_idx)) |
|
|
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE |
|
|
|
|
|
return { |
|
|
"joints_loss": joints_loss, |
|
|
"gripper_loss": gripper_loss, |
|
|
} |
|
|
|
|
|
def preprocess(self, proprio, action, mode="train"): |
|
|
"""Zero-out gripper channels in proprio/action.""" |
|
|
proprio_m = proprio.clone() |
|
|
action_m = action.clone() |
|
|
proprio_m[..., self.gripper_idx] = 0.0 |
|
|
action_m[..., self.gripper_idx] = 0.0 |
|
|
return proprio_m, action_m |
|
|
|
|
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor: |
|
|
"""Apply sigmoid to gripper logits.""" |
|
|
if action.size(-1) > max(self.gripper_idx): |
|
|
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx]) |
|
|
return action |
|
|
|
|
|
|
|
|
@register_action("agibot_ee6d") |
|
|
class AGIBOTEE6DActionSpace(BaseActionSpace): |
|
|
"""AGI-bot variant of EE6DActionSpace using MSE for all components.""" |
|
|
|
|
|
dim_action = 20 |
|
|
gripper_idx = (9, 19) |
|
|
GRIPPER_SCALE = 10.0 |
|
|
XYZ_SCALE = 500.0 |
|
|
ROT_SCALE = 10.0 |
|
|
POS_IDX_1 = (0, 1, 2) |
|
|
POS_IDX_2 = (10, 11, 12) |
|
|
ROT_IDX_1 = (3, 4, 5, 6, 7, 8) |
|
|
ROT_IDX_2 = (13, 14, 15, 16, 17, 18) |
|
|
|
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.mse = nn.MSELoss() |
|
|
|
|
|
def compute_loss(self, pred, target): |
|
|
assert pred.shape == target.shape |
|
|
B, T, D = pred.shape |
|
|
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx") |
|
|
|
|
|
gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE |
|
|
pos_loss = ( |
|
|
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + |
|
|
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) |
|
|
) * self.XYZ_SCALE |
|
|
rot_loss = ( |
|
|
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + |
|
|
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) |
|
|
) * self.ROT_SCALE |
|
|
|
|
|
return { |
|
|
"position_loss": pos_loss, |
|
|
"rotate6D_loss": rot_loss, |
|
|
"gripper_loss": gripper_loss, |
|
|
} |
|
|
|
|
|
def preprocess(self, proprio, action, mode="train"): |
|
|
"""No preprocessing applied in AGIBOT variant.""" |
|
|
return proprio, action |
|
|
|
|
|
def postprocess(self, action: torch.Tensor) -> torch.Tensor: |
|
|
"""AGIBOT does not postprocess.""" |
|
|
return action |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
"BaseActionSpace", |
|
|
"build_action_space", |
|
|
"register_action", |
|
|
"EE6DActionSpace", |
|
|
"JointActionSpace", |
|
|
"AGIBOTEE6DActionSpace", |
|
|
"ACTION_REGISTRY", |
|
|
] |
|
|
|