import sys
from contextlib import redirect_stdout

import torch
import tops
from torch_fidelity.sample_similarity_lpips import (
    NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor,
    spatial_average)
class SampleSimilarityLPIPS(torch.nn.Module):
    """LPIPS perceptual similarity on a VGG16 backbone, via torch_fidelity."""

    SUPPORTED_DTYPES = {
        'uint8': torch.uint8,
        'float32': torch.float32,
    }
    def __init__(self):
        super().__init__()
        self.chns = [64, 128, 256, 512, 512]
        self.L = len(self.chns)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        with redirect_stdout(sys.stderr):
            fp = tops.download_file(URL_VGG16_LPIPS)
        state_dict = torch.load(fp, map_location="cpu")
        self.load_state_dict(state_dict)
        self.net = VGG16features()
        self.eval()
        for param in self.parameters():
            param.requires_grad = False
        # LPIPS shift/scale constants, rescaled from [-1, 1] inputs to
        # [0, 255] inputs. Note that the "std" buffer stores the *inverse*
        # std so normalize() can multiply instead of divide.
        mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
        inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
        self.register_buffer("mean", mean_rescaled)
        self.register_buffer("std", inv_std_rescaled)
    def normalize(self, x):
        # Expects x in [0, 255]. For reference, the equivalent torchvision
        # constants for inputs in [0, 1] are mean = [0.485, 0.456, 0.406],
        # std = [0.229, 0.224, 0.225].
        x = (x.float() - self.mean) * self.std
        return x
    @staticmethod
    def resize(x, size):
        # Area interpolation when strictly downsampling, bilinear otherwise.
        if x.shape[-1] > size and x.shape[-2] > size:
            x = torch.nn.functional.interpolate(x, (size, size), mode='area')
        else:
            x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
        return x
    def lpips_from_feats(self, feats0, feats1):
        diffs = {}
        for kk in range(self.L):
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
        # Weight each layer's squared difference with its learned 1x1 conv,
        # average spatially, then sum over layers.
        res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
        val = sum(res)
        return val
    def get_feats(self, x):
        assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW'
        if x.shape[-2] < 16 or x.shape[-1] < 16:  # Upscale images smaller than 16x16 by 2x
            f = 2
            size = tuple(int(f * s) for s in x.shape[-2:])
            x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False)
        in0_input = self.normalize(x)
        outs0 = self.net.forward(in0_input)
        feats = {}
        for kk in range(self.L):
            feats[kk] = normalize_tensor(outs0[kk])
        return feats
    def forward(self, in0, in1):
        feats0 = self.get_feats(in0)
        feats1 = self.get_feats(in1)
        return self.lpips_from_feats(feats0, feats1), feats0, feats1
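
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumes inputs are
# RGB batches in [0, 255], matching the rescaled normalization buffers above,
# and that network access is available for the one-time weight download. The
# batch size and 64x64 resolution are arbitrary illustrative choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = SampleSimilarityLPIPS()
    x0 = torch.randint(0, 256, (4, 3, 64, 64), dtype=torch.uint8)
    x1 = torch.randint(0, 256, (4, 3, 64, 64), dtype=torch.uint8)
    with torch.no_grad():
        lpips, _, _ = model(x0, x1)
    print(lpips.flatten())  # one LPIPS distance per image pair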