import pickle
import torch
import torchvision
from pathlib import Path
from dp2 import utils
import tops
try:
    import clip
except ImportError:
    print("Could not import clip.")
from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric

# Lazily-initialized CLIP image encoder and its preprocessing pipeline.
clip_model = None
clip_preprocess = None
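
# Illustrative sketch (an assumption, not part of the original module): the
# torch_fidelity helpers imported above reduce an (N, D) feature matrix to
# Gaussian statistics (mean, covariance) and compare two such Gaussians with
# the Frechet distance. A roughly equivalent NumPy/SciPy version:
def _frechet_distance_sketch(feats_a, feats_b):
    import numpy as np
    import scipy.linalg
    mu_a, cov_a = feats_a.mean(axis=0), np.cov(feats_a, rowvar=False)
    mu_b, cov_b = feats_b.mean(axis=0), np.cov(feats_b, rowvar=False)
    # sqrtm may return a complex result due to numerical error; keep the real part.
    covmean = scipy.linalg.sqrtm(cov_a @ cov_b).real
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a + cov_b - 2.0 * covmean))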

@torch.no_grad()  # Feature extraction only; gradients are never needed here.
def compute_fid_clip(
        dataloader, generator,
        cache_directory,
        data_len=None,
        **kwargs
) -> dict:
    """
    FID-CLIP following the description in "The Role of ImageNet Classes in
    Frechet Inception Distance", Tuomas Kynkäänniemi et al.

    Args:
        dataloader: yields dicts with an "img" key (plus any conditioning the generator needs).
        generator: model whose output dict contains the generated image under "img".
        cache_directory: directory used to cache the real-data feature statistics.
        data_len (int): upper bound on the number of samples; defaults to
            len(dataloader) * dataloader.batch_size.
    """
    global clip_model, clip_preprocess
    if clip_model is None:
        # Load CLIP ViT-B/32 on CPU, then move only the visual encoder to GPU.
        clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
        normalize_fn = preprocess.transforms[-1]
        img_mean = normalize_fn.mean
        img_std = normalize_fn.std
        clip_model = tops.to_cuda(clip_model.visual)
        # Replicate CLIP's preprocessing for tensors already in [0, 1]:
        # bicubic resize to 224x224 followed by CLIP's channel normalization.
        clip_preprocess = tops.to_cuda(torch.nn.Sequential(
            torchvision.transforms.Resize(
                (224, 224),
                interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
            torchvision.transforms.Normalize(img_mean, img_std)
        ))
    cache_directory = Path(cache_directory)
    if data_len is None:
        data_len = len(dataloader) * dataloader.batch_size
    # Statistics for the real data are cached; only recompute them on a cache miss.
    fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl")
    has_fid_cache = fid_cache_path.is_file()
    if not has_fid_cache:
        fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
    # ViT-B/32 produces 512-dimensional image embeddings.
    fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
    eidx = 0
    n_samples_seen = 0
    for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."):
        sidx = eidx
        eidx = sidx + batch["img"].shape[0]
        n_samples_seen += batch["img"].shape[0]
        with torch.cuda.amp.autocast(tops.AMP()):
            fakes = generator(**batch)["img"]
            real_data = batch["img"]
            # Map images from the generator's value range back to [0, 1]
            # before applying CLIP preprocessing.
            fakes = utils.denormalize_img(fakes)
            real_data = utils.denormalize_img(real_data)
            if not has_fid_cache:
                real_data = clip_preprocess(real_data)
                fid_features_real[sidx:eidx] = clip_model(real_data)
            fakes = clip_preprocess(fakes)
            fid_features_fake[sidx:eidx] = clip_model(fakes)
    # Drop unused rows (the last batch may be smaller), then gather across ranks.
    fid_features_fake = fid_features_fake[:n_samples_seen]
    fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
    if has_fid_cache:
        if tops.rank() == 0:
            with open(fid_cache_path, "rb") as fp:
                fid_stat_real = pickle.load(fp)
    else:
        fid_features_real = fid_features_real[:n_samples_seen]
        fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
        assert fid_features_real.shape == fid_features_fake.shape
        if tops.rank() == 0:
            fid_stat_real = fid_features_to_statistics(fid_features_real)
            cache_directory.mkdir(exist_ok=True, parents=True)
            with open(fid_cache_path, "wb") as fp:
                pickle.dump(fid_stat_real, fp)
    if tops.rank() == 0:
        print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
        fid_stat_fake = fid_features_to_statistics(fid_features_fake)
        fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
        return dict(fid_clip=fid_)
    # Non-zero ranks only contribute features; they return a placeholder value.
    return dict(fid_clip=-1)
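
# Minimal usage sketch, assuming hypothetical `build_dataloader`/`build_generator`
# helpers (placeholders, not part of dp2) that satisfy the interfaces above:
#
#   dataloader = build_dataloader()   # batches are dicts with an "img" key
#   generator = build_generator()     # generator(**batch) returns {"img": ...}
#   metrics = compute_fid_clip(dataloader, generator, cache_directory="fid_cache")
#   print(metrics["fid_clip"])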