import torch
import tops
import cv2
import torchvision.transforms.functional as F
from typing import Optional, List, Union, Tuple
from .cse import from_E_to_vertex
import numpy as np
from tops import download_file
from .torch_utils import (
    denormalize_img, binary_dilation, binary_erosion,
    remove_pad, crop_box)
from torchvision.utils import _generate_color_palette
from PIL import Image, ImageColor, ImageDraw


def get_coco_keypoints():
    # From: https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/keypoints.py
    keypoints = [
        'nose',
        'left_eye',
        'right_eye',
        'left_ear',
        'right_ear',
        'left_shoulder',
        'right_shoulder',
        'left_elbow',
        'right_elbow',
        'left_wrist',
        'right_wrist',
        'left_hip',
        'right_hip',
        'left_knee',
        'right_knee',
        'left_ankle',
        'right_ankle'
    ]
    keypoint_flip_map = {
        'left_eye': 'right_eye',
        'left_ear': 'right_ear',
        'left_shoulder': 'right_shoulder',
        'left_elbow': 'right_elbow',
        'left_wrist': 'right_wrist',
        'left_hip': 'right_hip',
        'left_knee': 'right_knee',
        'left_ankle': 'right_ankle'
    }
    connectivity = {
        "nose": "left_eye",
        "left_eye": "right_eye",
        "right_eye": "nose",
        "left_ear": "left_eye",
        "right_ear": "right_eye",
        "left_shoulder": "nose",
        "right_shoulder": "nose",
        "left_elbow": "left_shoulder",
        "right_elbow": "right_shoulder",
        "left_wrist": "left_elbow",
        "right_wrist": "right_elbow",
        "left_hip": "left_shoulder",
        "right_hip": "right_shoulder",
        "left_knee": "left_hip",
        "right_knee": "right_hip",
        "left_ankle": "left_knee",
        "right_ankle": "right_knee"
    }
    connectivity_indices = [
        (sidx, keypoints.index(connectivity[kp]))
        for sidx, kp in enumerate(keypoints)
    ]
    return keypoints, keypoint_flip_map, connectivity_indices
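

# Illustrative sketch (not from the original code): what the returned
# connectivity indices look like. Each entry pairs a keypoint index with the
# index of the keypoint it connects to, following the tables above.
#
#   kp_names, flip_map, connectivity = get_coco_keypoints()
#   assert connectivity[0] == (0, 1)     # "nose" -> "left_eye"
#   assert connectivity[16] == (16, 14)  # "right_ankle" -> "right_knee"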


def get_coco_colors():
    # One color per COCO keypoint: red for the five face keypoints, blue/green
    # for the left/right arm chains, purple/orange for the left/right legs.
    return [
        *["red"] * 5,
        "blue",
        "green",
        "blue",
        "green",
        "blue",
        "green",
        "purple",
        "orange",
        "purple",
        "orange",
        "purple",
        "orange",
    ]


def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    visible: Optional[List[List[bool]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: Optional[int] = None,
    width: Optional[int] = None,
) -> torch.Tensor:
    """
    Draws keypoints on a given RGB image.
    Adapted from the torchvision source code (added in torchvision 0.12).
    The values of the input image should be uint8 between 0 and 255.
    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        keypoints (Tensor): Tensor of shape (num_instances, K, 2), the K keypoint
            locations of each instance, in the format [x, y].
        connectivity (List[Tuple[int, int]]): A list of tuples, where each tuple
            contains a pair of keypoint indices to be connected.
        visible (List[List[bool]]): Per-instance, per-keypoint visibility flags.
            Keypoints marked as not visible are skipped.
        colors (str, Tuple): Unused; this adaptation always draws with the COCO
            palette from get_coco_colors().
        radius (int): Radius of each keypoint. Defaults to 1% of the largest image side.
        width (int): Width of the lines connecting keypoints. Defaults to 1% of the
            largest image side.
    Returns:
        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other image formats are not supported")
    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")
    if width is None:
        width = int(max(max(image.shape[-2:]) * 0.01, 1))
    if radius is None:
        radius = int(max(max(image.shape[-2:]) * 0.01, 1))
    ndarr = image.permute(1, 2, 0).cpu().numpy()
    img_to_draw = Image.fromarray(ndarr)
    draw = ImageDraw.Draw(img_to_draw)
    if isinstance(keypoints, torch.Tensor):
        img_kpts = keypoints.to(torch.int64).tolist()
    else:
        assert isinstance(keypoints, np.ndarray)
        img_kpts = keypoints.astype(int).tolist()
    # NOTE: the colors argument is ignored; the COCO palette is always used.
    colors = get_coco_colors()
    for inst_id, kpt_inst in enumerate(img_kpts):
        for kpt_id, kpt in enumerate(kpt_inst):
            if visible is not None and int(visible[inst_id][kpt_id]) == 0:
                continue
            x1 = kpt[0] - radius
            x2 = kpt[0] + radius
            y1 = kpt[1] - radius
            y2 = kpt[1] + radius
            draw.ellipse([x1, y1, x2, y2], fill=colors[kpt_id], outline=None, width=0)
        if connectivity is not None:
            for connection in connectivity:
                if connection[1] >= len(kpt_inst) or connection[0] >= len(kpt_inst):
                    continue
                # Skip the connection unless both endpoints are visible.
                if visible is not None and (
                        int(visible[inst_id][connection[1]]) == 0
                        or int(visible[inst_id][connection[0]]) == 0):
                    continue
                start_pt_x = kpt_inst[connection[0]][0]
                start_pt_y = kpt_inst[connection[0]][1]
                end_pt_x = kpt_inst[connection[1]][0]
                end_pt_y = kpt_inst[connection[1]][1]
                draw.line(
                    ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
                    width=width, fill=colors[connection[1]]
                )
    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
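

# Example usage (a minimal sketch; the image size and keypoint values are
# assumptions for illustration):
#
#   img = torch.zeros((3, 128, 128), dtype=torch.uint8)
#   kpts = torch.tensor([[[64., 40.], [58., 32.], [70., 32.]]])  # (1, K, 2), pixel coords
#   out = draw_keypoints(img, kpts, connectivity=[(1, 0), (2, 0)])
#   # out: (3, 128, 128) uint8, with a dot per keypoint and a line per
#   # connection, colored by the COCO palette above.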


def visualize_keypoints(img, keypoints):
    img = img.clone()
    keypoints = keypoints.clone()
    # Keypoints are expected in normalized [0, 1] coordinates; scale to pixels.
    keypoints[:, :, 0] *= img.shape[-1]
    keypoints[:, :, 1] *= img.shape[-2]
    _, _, connectivity = get_coco_keypoints()
    connectivity = np.array(connectivity)
    visible = None
    if keypoints.shape[-1] == 3:
        visible = keypoints[:, :, 2] > 0
    for idx in range(img.shape[0]):
        img[idx] = draw_keypoints(
            img[idx], keypoints[idx:idx+1].long(), colors="red",
            connectivity=connectivity,
            visible=visible[idx:idx+1] if visible is not None else None)
    return img


def visualize_batch(
        img: torch.Tensor, mask: torch.Tensor,
        vertices: torch.Tensor = None,
        E_mask: torch.Tensor = None,
        embed_map: torch.Tensor = None,
        semantic_mask: torch.Tensor = None,
        embedding: torch.Tensor = None,
        keypoints: torch.Tensor = None,
        maskrcnn_mask: torch.Tensor = None,
        **kwargs) -> torch.ByteTensor:
    img = denormalize_img(img).mul(255).round().clamp(0, 255).byte()
    img = draw_mask(img, mask)
    if maskrcnn_mask is not None and maskrcnn_mask.shape == mask.shape:
        img = draw_mask(img, maskrcnn_mask)
    if vertices is not None or embedding is not None:
        assert E_mask is not None
        assert embed_map is not None
        img, E_mask, embedding, embed_map, vertices = tops.to_cpu([
            img, E_mask, embedding, embed_map, vertices
        ])
        img = draw_cse(img, E_mask, embedding, embed_map, vertices)
    elif semantic_mask is not None:
        img = draw_segmentation_masks(img, semantic_mask)
    if keypoints is not None:
        img = visualize_keypoints(img, keypoints)
    return img
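

# Example usage (a minimal sketch; the shapes and the all-ones mask are
# assumptions, and img is expected in the normalized range that
# denormalize_img undoes):
#
#   img = torch.randn(2, 3, 64, 64)
#   mask = torch.ones(2, 1, 64, 64)
#   out = visualize_batch(img, mask)  # (2, 3, 64, 64) uint8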


def draw_cse(
        img: torch.Tensor, E_seg: torch.Tensor,
        embedding: torch.Tensor = None,
        embed_map: torch.Tensor = None,
        vertices: torch.Tensor = None, t=0.7):
    """
    Visualizes continuous surface embeddings (CSE) by coloring each vertex id
    with a JET colormap and alpha-blending the result onto img.
    E_seg: 1 for areas with an embedding.
    """
    assert img.dtype == torch.uint8
    img = img.view(-1, *img.shape[-3:])
    E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:])
    if vertices is None:
        assert embedding is not None
        assert embed_map is not None
        embedding = embedding.view(-1, *embedding.shape[-3:])
        vertices = torch.stack(
            [from_E_to_vertex(e[None], e_seg[None].logical_not().float(), embed_map)
             for e, e_seg in zip(embedding, E_seg)])
    i = np.arange(0, 256, dtype=np.uint8).reshape(1, -1)
    colormap_JET = torch.from_numpy(cv2.applyColorMap(i, cv2.COLORMAP_JET)[0])
    color_embed_map, _ = np.load(download_file(
        "https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"), allow_pickle=True)
    color_embed_map = torch.from_numpy(color_embed_map).float()[:, 0]
    color_embed_map -= color_embed_map.min()
    color_embed_map /= color_embed_map.max()
    vertx2idx = (color_embed_map * 255).long()
    vertx2colormap = colormap_JET[vertx2idx]
    vertices = vertices.view(-1, *vertices.shape[-2:])
    E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:])
    # This operation might be good to do on cpu...
    E_color = vertx2colormap[vertices.long()]
    E_color = E_color.to(E_seg.device)
    E_color = E_color.permute(0, 3, 1, 2)
    E_color = E_color * E_seg.byte()
    m = E_seg.bool().repeat(1, 3, 1, 1)
    img[m] = (img[m] * (1 - t) + t * E_color[m]).byte()
    return img
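

# Minimal sketch of calling draw_cse with precomputed vertex ids (the shapes
# and the vertex count of 27554 are assumptions based on the CSE SMPL mesh;
# note the function downloads mds_d=256.npy on first use):
#
#   img = torch.zeros((1, 3, 64, 64), dtype=torch.uint8)
#   E_seg = torch.ones((1, 1, 64, 64))
#   vertices = torch.randint(0, 27554, (1, 64, 64))
#   out = draw_cse(img, E_seg, vertices=vertices)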


def draw_cse_all(
        embedding: List[torch.Tensor], E_mask: List[torch.Tensor],
        im: torch.Tensor, boxes_XYXY: list, embed_map: torch.Tensor, t=0.7):
    """
    Draws CSE visualizations for all detections onto a single image.
    E_mask: 1 for areas with an embedding.
    """
    assert len(im.shape) == 3, im.shape
    assert im.dtype == torch.uint8
    N = len(E_mask)
    im = im.clone()
    for i in range(N):
        assert len(E_mask[i].shape) == 2
        assert len(embedding[i].shape) == 3
        assert embed_map.shape[1] == embedding[i].shape[0]
        assert len(boxes_XYXY[i]) == 4
        E = embedding[i]
        x0, y0, x1, y1 = boxes_XYXY[i]
        # Resize the embedding and its mask to the box size before drawing.
        E = F.resize(E, (y1 - y0, x1 - x0), antialias=True)
        s = E_mask[i].float()
        s = (F.resize(s.squeeze()[None], (y1 - y0, x1 - x0), antialias=True) > 0).float()
        box = boxes_XYXY[i]
        im_ = crop_box(im, box)
        s = remove_pad(s, box, im.shape[1:])
        E = remove_pad(E, box, im.shape[1:])
        E_color = draw_cse(img=im_, E_seg=s[None], embedding=E[None], embed_map=embed_map)[0]
        E_color = E_color.to(im.device)
        s = s.bool().repeat(3, 1, 1)
        crop_box(im, box)[s] = (im_[s] * (1 - t) + t * E_color[s]).byte()
    return im
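

# Minimal sketch for draw_cse_all (all shapes and the random embeddings are
# assumptions, only meant to exercise the expected interface: per-instance
# embeddings and masks at box resolution, plus XYXY boxes in image coordinates):
#
#   embed_map = torch.randn(27554, 16)
#   embedding = [torch.randn(16, 32, 32)]
#   E_mask = [torch.ones(32, 32)]
#   im = torch.zeros((3, 64, 64), dtype=torch.uint8)
#   out = draw_cse_all(embedding, E_mask, im, [[8, 8, 40, 40]], embed_map)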


def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:
    """
    Draws segmentation masks on a given RGB image.
    The values of the input image should be uint8 between 0 and 255.
    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (list or None): List containing the colors of the masks. The colors can
            be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            A single color (string or RGB tuple) is replicated for every mask.
            By default, random colors are generated for each mask.
    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")
    num_masks = masks.size()[0]
    if num_masks == 0:
        return image
    if colors is None:
        colors = _generate_color_palette(num_masks)
    # A single color (string or RGB tuple) is replicated for every mask.
    if isinstance(colors, (str, tuple)):
        colors = [colors] * num_masks
    if num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("RGB colors must be tuples of 3 ints")
    out_dtype = torch.uint8
    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        color = torch.tensor(color, dtype=out_dtype, device=masks.device)
        colors_.append(color)
    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]
    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)
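

# Example usage (a minimal sketch with assumed shapes):
#
#   img = torch.zeros((3, 64, 64), dtype=torch.uint8)
#   masks = torch.zeros((2, 64, 64), dtype=torch.bool)
#   masks[0, :32], masks[1, 32:] = True, True
#   out = draw_segmentation_masks(img, masks, alpha=0.5, colors=["red", "blue"])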


def draw_mask(im: torch.Tensor, mask: torch.Tensor, t=0.2, color=(255, 255, 255), visualize_instances=True):
    """
    Visualizes the regions where mask == 0 by alpha-blending them with `color`,
    and traces the mask boundary (white outer border, black inner border).
    Supports multiple instances.
    mask shape: [N, C, H, W], where C indexes the instances in the same image.
    """
    orig_imshape = im.shape
    if mask.numel() == 0:
        return im
    assert len(mask.shape) in (3, 4), mask.shape
    mask = mask.view(-1, *mask.shape[-3:])
    im = im.view(-1, *im.shape[-3:])
    assert im.dtype == torch.uint8, im.dtype
    assert 0 <= t <= 1
    if not visualize_instances:
        mask = mask.any(dim=1, keepdim=True)
    mask = mask.bool()
    kernel = torch.ones((3, 3), dtype=mask.dtype, device=mask.device)
    outer_border = binary_dilation(mask, kernel).logical_xor(mask)
    outer_border = outer_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0
    inner_border = binary_erosion(mask, kernel).logical_xor(mask)
    inner_border = inner_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0
    mask = (mask == 0).any(dim=1, keepdim=True).repeat(1, 3, 1, 1)
    color = torch.tensor(color).to(im.device).byte().view(1, 3, 1, 1)
    color = color.repeat(im.shape[0], 1, *im.shape[-2:])
    im[mask] = (im[mask] * (1 - t) + t * color[mask]).byte()
    im[outer_border] = 255
    im[inner_border] = 0
    return im.view(*orig_imshape)
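

# Example usage (a minimal sketch with assumed shapes): blend white into
# everything outside a centered box and trace the box boundary.
#
#   im = torch.full((1, 3, 64, 64), 128, dtype=torch.uint8)
#   mask = torch.zeros((1, 1, 64, 64), dtype=torch.bool)
#   mask[..., 16:48, 16:48] = True
#   out = draw_mask(im, mask, t=0.3)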


def draw_cropped_masks(im: torch.Tensor, mask: torch.Tensor, boxes: torch.Tensor, **kwargs):
    for i, box in enumerate(boxes):
        x0, y0, x1, y1 = box
        orig_shape = (y1 - y0, x1 - x0)
        # Resize each instance mask to its box size, then draw it into the crop.
        m = F.resize(mask[i], orig_shape, F.InterpolationMode.NEAREST).squeeze()[None]
        m = remove_pad(m, box, im.shape[-2:])
        crop_box(im, box).copy_(draw_mask(crop_box(im, box), m))
    return im
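

# Example usage (a minimal sketch; the low-resolution per-box mask is an
# assumption about the expected input format):
#
#   im = torch.zeros((3, 128, 128), dtype=torch.uint8)
#   boxes = torch.tensor([[16, 16, 80, 80]])
#   mask = torch.zeros((1, 1, 32, 32))
#   mask[..., 8:24, 8:24] = 1
#   out = draw_cropped_masks(im, mask, boxes)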


def draw_cropped_keypoints(im: torch.Tensor, all_keypoints: torch.Tensor, boxes: torch.Tensor, **kwargs):
    n_boxes = boxes.shape[0]
    tops.assert_shape(all_keypoints, (n_boxes, 17, 3))
    im = im.clone()
    _, _, connectivity = get_coco_keypoints()
    connectivity = np.array(connectivity)
    for i, box in enumerate(boxes):
        x0, y0, x1, y1 = box
        orig_shape = (y1 - y0, x1 - x0)
        keypoints = all_keypoints[i].clone()
        # Keypoints are normalized to [0, 1] within the box; scale to pixels.
        keypoints[:, 0] *= orig_shape[1]
        keypoints[:, 1] *= orig_shape[0]
        # Read the visibility channel before truncating coordinates to integers.
        visible = keypoints[:, 2] > .5
        keypoints = keypoints.long()
        # Remove padding from keypoints before visualization
        keypoints[:, 0] += min(x0, 0)
        keypoints[:, 1] += min(y0, 0)
        im_with_kp = draw_keypoints(
            crop_box(im, box), keypoints[None], colors="red",
            connectivity=connectivity, visible=visible[None])
        crop_box(im, box).copy_(im_with_kp)
    return im
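

# Example usage (a minimal sketch; the random keypoints are assumptions):
# keypoints are normalized to [0, 1] within each box, with a visibility score
# in the last channel.
#
#   im = torch.zeros((3, 128, 128), dtype=torch.uint8)
#   boxes = torch.tensor([[10, 10, 60, 90]])
#   kps = torch.rand((1, 17, 3))  # (N, 17, [x, y, visibility])
#   out = draw_cropped_keypoints(im, kps, boxes)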