import json
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import ViTForImageClassification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

labels = json.loads(Path("labels.json").read_text())
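# Assumption: labels.json holds the class names as a JSON list indexed by
# the models' predicted class ids.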
# Load ResNet-50 pre-trained on ImageNet
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to(device)
resnet50.eval()

# Create a ResNet feature extractor returning the last conv block and the logits
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
fc_layer_weights = resnet50.fc.weight
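# CAM uses these weights directly: the fc row for the predicted class scores
# how much each layer4 channel contributes to that class (Zhou et al., 2016).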
# Load the MRI-classification ViT from the Hugging Face Hub
vit = ViTForImageClassification.from_pretrained("raedinkhaled/vit-base-mri").to(device)
vit.eval()

# Standard ImageNet normalization statistics
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)
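# Assuming the usual ViT-base patch size of 16, the 224x224 input yields
# 14x14 = 196 patches, which is the grid the rollout mask reshapes to below.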
# Example images shipped with the Space
examples = sorted([f.as_posix() for f in Path("examples").glob("*")])
def get_cam(img_tensor):
    output = feature_extractor(img_tensor)
    cnn_features = output["layer4"].squeeze()  # (channels, h, w)
    class_id = output["fc"].argmax().item()
    # Weight the feature maps by the fc weights of the predicted class
    cam = fc_layer_weights[class_id].matmul(cnn_features.flatten(1))
    cam = cam.reshape(cnn_features.shape[1], cnn_features.shape[2])
    return cam.cpu().numpy(), labels[class_id]
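# CAM recap: with fc weights w^c over layer4 feature maps f_k, the class
# activation map is CAM_c(x, y) = sum_k w^c_k * f_k(x, y); the matmul over
# the flattened maps above computes exactly this sum.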
def get_attention_mask(img_tensor):
    result = vit(img_tensor, output_attentions=True)
    class_id = result[0].argmax().item()
    # Stack the per-layer attentions: (layers, heads, tokens, tokens)
    attention_probs = torch.stack(result[1]).squeeze(1)
    # Average the attention at each layer over all heads
    attention_probs = torch.mean(attention_probs, dim=1)
    # Mix in the identity to account for the residual connection
    residual = torch.eye(attention_probs.size(-1)).to(device)
    attention_probs = 0.5 * attention_probs + 0.5 * residual
    # Re-normalize each layer's rows
    attention_probs = attention_probs / attention_probs.sum(dim=-1).unsqueeze(-1)
    # Roll the attention out across layers
    attention_rollout = attention_probs[0]
    for i in range(1, attention_probs.size(0)):
        attention_rollout = torch.matmul(attention_probs[i], attention_rollout)
    # Attention between the [CLS] token and the image patches
    mask = attention_rollout[0, 1:]
    mask_size = int(np.sqrt(mask.size(0)))
    mask = mask.reshape(mask_size, mask_size)
    return mask.cpu().numpy(), labels[class_id]
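# Attention rollout (Abnar & Zuidema, 2020): each layer's head-averaged
# attention A_l is mixed with the identity to model the skip connection,
# row-normalized, and the layers are chained by matrix multiplication; the
# [CLS] row of the product scores every image patch.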
def plot_mask_on_image(image, mask):
    # Min-max normalize the mask to [0, 1]
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = (255 * mask).astype(np.uint8)
    mask = cv2.resize(mask, image.size)
    # applyColorMap returns BGR; convert to RGB to match the PIL image
    heatmap = cv2.cvtColor(cv2.applyColorMap(mask, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    result = heatmap * 0.3 + np.array(image) * 0.5
    return result.astype(np.uint8)
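# The blend weights 0.3 and 0.5 keep every pixel at most 0.8 * 255 = 204,
# so the uint8 cast cannot overflow.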
def inference(img):
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        cam, resnet_label = get_cam(img_tensor)
        attention_mask, vit_label = get_attention_mask(img_tensor)
    cam_result = plot_mask_on_image(img, cam)
    rollout_result = plot_mask_on_image(img, attention_mask)
    return resnet_label, cam_result, vit_label, rollout_result
if __name__ == "__main__":
    interface = gr.Interface(
        fn=inference,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=[
            gr.Label(num_top_classes=1, label="ResNet Label"),
            gr.Image(label="ResNet CAM"),
            gr.Label(num_top_classes=1, label="ViT Label"),
            gr.Image(label="raedinkhaled/vit-base-mri Attention Rollout"),
        ],
        examples=examples,
        title="Transformer Explainability on Our Pre-Trained Model",
        live=True,
    )
    interface.launch()
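# Usage note (standard Gradio APIs): interface.queue().launch() serializes
# concurrent requests, and interface.launch(share=True) creates a temporary
# public URL for the demo.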