Spaces:
Runtime error
Runtime error
| import torch | |
| import lzma | |
| from dp2.detection.base import BaseDetector | |
| from .utils import combine_cse_maskrcnn_dets | |
| from .models.cse import CSEDetector | |
| from .models.mask_rcnn import MaskRCNNDetector | |
| from .models.keypoint_maskrcnn import KeypointMaskRCNN | |
| from .structures import CSEPersonDetection, PersonDetection | |
| from pathlib import Path | |
class CSEPersonDetector(BaseDetector):
    """Detector that pairs a Mask R-CNN model with a CSE model.

    ``forward`` runs both models on an image, merges their outputs with
    ``combine_cse_maskrcnn_dets`` and wraps the result in a single
    ``CSEPersonDetection``.
    """

    def __init__(
        self,
        score_threshold: float,
        mask_rcnn_cfg: dict,
        cse_cfg: dict,
        cse_post_process_cfg: dict,
        **kwargs
    ) -> None:
        """Build the two underlying detectors and store post-processing options.

        Args:
            score_threshold: Passed as ``score_thres`` to both
                ``MaskRCNNDetector`` and ``CSEDetector``.
            mask_rcnn_cfg: Keyword arguments for ``MaskRCNNDetector``.
            cse_cfg: Keyword arguments for ``CSEDetector``.
            cse_post_process_cfg: Post-processing options. Must contain the
                key ``"iou_combine_threshold"``; the remaining entries are
                forwarded to ``CSEPersonDetection``. The mapping passed in is
                not modified.
            **kwargs: Forwarded to ``BaseDetector.__init__``.

        Raises:
            KeyError: If ``"iou_combine_threshold"`` is missing from
                ``cse_post_process_cfg``.
        """
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
        self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
        # Work on a shallow copy: popping from the caller's mapping would
        # silently strip "iou_combine_threshold" from a config object that may
        # be reused (e.g. to build a second detector), which then fails with
        # KeyError.
        self.post_process_cfg = dict(cse_post_process_cfg)
        self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from an LZMA-compressed ``torch.save`` cache file.

        The file is expected to contain an iterable of state mappings, each
        holding under ``"cls"`` an object exposing ``from_state_dict``.

        Args:
            cache_path: Path to the LZMA-compressed cache file.

        Returns:
            A list with one restored detection per cached state.
        """
        # NOTE(security): torch.load unpickles the file content; only load
        # cache files that come from a trusted source.
        with lzma.open(cache_path, "rb") as fp:
            state_dict = torch.load(fp)
        kwargs = dict(
            post_process_cfg=self.post_process_cfg,
            embed_map=self.cse_detector.embed_map,
        )
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    def forward(self, im: torch.Tensor, cse_dets=None):
        """Detect persons in ``im`` and return them as a one-element list.

        Args:
            im: Input image tensor. Its ``shape`` is passed on as
                ``orig_imshape_CHW`` (presumably channels-first; confirm
                against callers).
            cse_dets: Optional precomputed CSE detections. When ``None`` the
                CSE detector is run on ``im``.

        Returns:
            ``[CSEPersonDetection]`` built from the combined detections.
        """
        mask_dets = self.mask_rcnn(im)
        if cse_dets is None:
            cse_dets = self.cse_detector(im)
        segmentation = mask_dets["segmentation"]
        # The third value returned by the combiner is not needed here.
        segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
            segmentation, cse_dets, self.iou_combine_threshold
        )
        det = CSEPersonDetection(
            segmentation=segmentation,
            cse_dets=cse_dets,
            embed_map=self.cse_detector.embed_map,
            orig_imshape_CHW=im.shape,
            **self.post_process_cfg
        )
        return [det]
class MaskRCNNPersonDetector(BaseDetector):
    """Detector that wraps a single Mask R-CNN model.

    ``forward`` turns the model's ``"segmentation"`` output into one
    ``PersonDetection``.
    """

    def __init__(
        self,
        score_threshold: float,
        mask_rcnn_cfg: dict,
        cse_post_process_cfg: dict,
        **kwargs
    ) -> None:
        """Create the Mask R-CNN detector and keep the post-processing options.

        Args:
            score_threshold: Passed as ``score_thres`` to ``MaskRCNNDetector``.
            mask_rcnn_cfg: Keyword arguments for ``MaskRCNNDetector``.
            cse_post_process_cfg: Options forwarded to ``PersonDetection``.
            **kwargs: Forwarded to ``BaseDetector.__init__``.
        """
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(
            **mask_rcnn_cfg,
            score_thres=score_threshold,
        )
        self.post_process_cfg = cse_post_process_cfg

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from an LZMA-compressed ``torch.save`` cache file.

        Each cached entry must hold under ``"cls"`` an object exposing
        ``from_state_dict``; that factory receives the stored
        post-processing options and the entry itself.

        Args:
            cache_path: Path to the LZMA-compressed cache file.

        Returns:
            A list with one restored detection per cached entry.
        """
        # NOTE(security): torch.load unpickles the file content; only load
        # cache files that come from a trusted source.
        with lzma.open(cache_path, "rb") as cache_file:
            cached_entries = torch.load(cache_file)

        shared_args = {"post_process_cfg": self.post_process_cfg}
        restored = []
        for entry in cached_entries:
            detection_cls = entry["cls"]
            restored.append(
                detection_cls.from_state_dict(**shared_args, state_dict=entry)
            )
        return restored

    def forward(self, im: torch.Tensor):
        """Detect persons in ``im`` and return them as a one-element list.

        Args:
            im: Input image tensor. Its ``shape`` is passed on as
                ``orig_imshape_CHW``.

        Returns:
            ``[PersonDetection]`` built from the model's segmentation output.
        """
        person_masks = self.mask_rcnn(im)["segmentation"]
        detection = PersonDetection(
            person_masks,
            **self.post_process_cfg,
            orig_imshape_CHW=im.shape,
        )
        return [detection]
class KeypointMaskRCNNPersonDetector(BaseDetector):
    """Detector that wraps a ``KeypointMaskRCNN`` model.

    ``forward`` packs the model's ``"segmentation"`` and ``"keypoints"``
    outputs into one ``PersonDetection``.
    """

    def __init__(
        self,
        score_threshold: float,
        mask_rcnn_cfg: dict,
        cse_post_process_cfg: dict,
        **kwargs
    ) -> None:
        """Create the keypoint Mask R-CNN model and keep the post-processing options.

        Args:
            score_threshold: Passed as ``score_threshold`` to
                ``KeypointMaskRCNN``.
            mask_rcnn_cfg: Keyword arguments for ``KeypointMaskRCNN``.
            cse_post_process_cfg: Options forwarded to ``PersonDetection``.
            **kwargs: Forwarded to ``BaseDetector.__init__``.
        """
        super().__init__(**kwargs)
        self.mask_rcnn = KeypointMaskRCNN(
            **mask_rcnn_cfg,
            score_threshold=score_threshold,
        )
        self.post_process_cfg = cse_post_process_cfg

    def __call__(self, *args, **kwargs):
        """Delegate directly to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from an LZMA-compressed ``torch.save`` cache file.

        Each cached entry must hold under ``"cls"`` an object exposing
        ``from_state_dict``; that factory receives the stored
        post-processing options and the entry itself.

        Args:
            cache_path: Path to the LZMA-compressed cache file.

        Returns:
            A list with one restored detection per cached entry.
        """
        # NOTE(security): torch.load unpickles the file content; only load
        # cache files that come from a trusted source.
        with lzma.open(cache_path, "rb") as cache_file:
            cached_entries = torch.load(cache_file)

        shared_args = {"post_process_cfg": self.post_process_cfg}
        restored = []
        for entry in cached_entries:
            detection_cls = entry["cls"]
            restored.append(
                detection_cls.from_state_dict(**shared_args, state_dict=entry)
            )
        return restored

    def forward(self, im: torch.Tensor):
        """Detect persons in ``im`` and return them as a one-element list.

        Args:
            im: Input image tensor. Its ``shape`` is passed on as
                ``orig_imshape_CHW``.

        Returns:
            ``[PersonDetection]`` carrying the model's segmentation and
            keypoints outputs.
        """
        model_out = self.mask_rcnn(im)
        person_masks = model_out["segmentation"]
        detection = PersonDetection(
            person_masks,
            **self.post_process_cfg,
            orig_imshape_CHW=im.shape,
            keypoints=model_out["keypoints"],
        )
        return [detection]