# util/prepare_utils.py
# Helper functions for extracting features from pre-trained models.
import os
import sys
import math
import numbers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
import numpy as np
import torchvision.datasets as datasets
import copy
import time
from util.feature_extraction_utils import feature_extractor, face_extractor, warp_image, de_preprocess, normalize_batch
from lpips_pytorch import LPIPS, lpips
from backbone.model_irse import IR_50, IR_101, IR_152, IR_SE_50, IR_SE_101, IR_SE_152
from backbone.model_resnet import ResNet_50, ResNet_101, ResNet_152
from tqdm import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor_transform = transforms.ToTensor()
pil_transform = transforms.ToPILImage()

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder.
    """

    # Override __getitem__, the method the dataloader calls.
    def __getitem__(self, index):
        # This is what ImageFolder normally returns: (image, label).
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # The image file path.
        path = self.imgs[index][0]
        # Make a new tuple that appends the path: (image, label, path).
        tuple_with_path = original_tuple + (path,)
        return tuple_with_path
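
# Minimal usage sketch (hedged): the directory below is a hypothetical
# placeholder; any ImageFolder-style root with class subdirectories works.
# >>> dataset = ImageFolderWithPaths('data/faces', transform=tensor_transform)
# >>> img, label, path = dataset[0]  # third element is the image's file path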

class GaussianSmoothing(nn.Module):
    """
    Apply Gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed
    separately for each channel in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the Gaussian kernel.
        sigma (float, sequence): Standard deviation of the Gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """

    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim
        # The Gaussian kernel is the product of the
        # Gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ],
            indexing='ij'  # explicit 'ij' keeps the legacy behavior on torch >= 1.10
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                torch.exp(-((mgrid - mean) / std) ** 2 / 2)
        # Make sure the values in the Gaussian kernel sum to 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to a depthwise convolutional weight.
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
        self.register_buffer('weight', kernel)
        self.groups = channels
        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )
        self.dim = dim
        # Assumes an odd kernel size, so reflect padding preserves the input size.
        self.pad_size = kernel_size[0] // 2

    def forward(self, input):
        """
        Apply Gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        # Reflect-pad every spatial dimension so the output keeps the input's
        # size (a fixed 4-value pad tuple would only be correct for dim=2).
        input = F.pad(input, (self.pad_size, self.pad_size) * self.dim, mode='reflect')
        return self.conv(input, weight=self.weight, groups=self.groups)
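
# Minimal usage sketch (hedged): the batch shape 4x3x112x112 is an illustrative
# assumption, not a requirement of the module.
# >>> smoother = GaussianSmoothing(channels=3, kernel_size=5, sigma=1.0)
# >>> blurred = smoother(torch.rand(4, 3, 112, 112))  # same shape as the input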

class dim_reduction(nn.Module):
    # Projects feature vectors onto a lower-dimensional subspace spanned by the
    # columns of V.
    def __init__(self, V):
        super(dim_reduction, self).__init__()
        self.V = V

    def forward(self, input):
        return torch.matmul(input, self.V.to(input.device))
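
# Minimal usage sketch (hedged): V here is a random stand-in for a learned
# projection matrix (e.g. top subspace directions); the shapes are illustrative.
# >>> reducer = dim_reduction(torch.rand(512, 128))
# >>> reduced = reducer(torch.rand(4, 512))  # -> shape (4, 128)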

def get_ensemble(models, sigma_gf, kernel_size_gf, combination, V_reduction, warp=False,
                 theta_warp=None):
    # Prepares an ensemble of feature extractors.
    # Returns a list of PyTorch nn models.
    feature_extractor_ensemble = []
    if sigma_gf is not None:
        # Apply Gaussian filtering during the attack.
        gaussian_filtering = GaussianSmoothing(3, kernel_size_gf, sigma_gf)
        if V_reduction is None:
            for model in models:
                feature_extractor_model = nn.DataParallel(nn.Sequential(
                    gaussian_filtering,
                    feature_extractor(model=model, warp=warp, theta_warp=theta_warp))).to(device)
                feature_extractor_ensemble.append(feature_extractor_model)
                if combination:
                    # Also add the unfiltered extractor, interleaving
                    # filtered/unfiltered copies per model.
                    feature_extractor_model = nn.DataParallel(
                        feature_extractor(model=model, warp=warp, theta_warp=theta_warp)).to(device)
                    feature_extractor_ensemble.append(feature_extractor_model)
        else:
            for i, model in enumerate(models):
                feature_extractor_model = nn.DataParallel(nn.Sequential(
                    gaussian_filtering,
                    feature_extractor(model=model, warp=warp, theta_warp=theta_warp),
                    dim_reduction(V_reduction[i]))).to(device)
                feature_extractor_ensemble.append(feature_extractor_model)
                if combination:
                    # Also add the unfiltered extractor, interleaving
                    # filtered/unfiltered copies per model.
                    feature_extractor_model = nn.DataParallel(nn.Sequential(
                        feature_extractor(model=model, warp=warp, theta_warp=theta_warp),
                        dim_reduction(V_reduction[i]))).to(device)
                    feature_extractor_ensemble.append(feature_extractor_model)
    else:
        if V_reduction is None:
            for model in models:
                feature_extractor_model = nn.DataParallel(
                    feature_extractor(model=model, warp=warp, theta_warp=theta_warp)).to(device)
                feature_extractor_ensemble.append(feature_extractor_model)
        else:
            for i, model in enumerate(models):
                feature_extractor_model = nn.DataParallel(nn.Sequential(
                    feature_extractor(model=model, warp=warp, theta_warp=theta_warp),
                    dim_reduction(V_reduction[i]))).to(device)
                feature_extractor_ensemble.append(feature_extractor_model)
    return feature_extractor_ensemble
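
# Minimal usage sketch (hedged): `models_attack` is assumed to come from
# prepare_models below; passing sigma_gf=None skips the Gaussian-filtered copies.
# >>> ensemble = get_ensemble(models_attack, sigma_gf=None, kernel_size_gf=None,
# ...                         combination=False, V_reduction=None)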

def extract_features(imgs, feature_extractor_ensemble, dim):
    # Computes the feature vectors of the images under each extractor in the
    # ensemble. Output shape: (batch size, number of models in ensemble, dim).
    features = torch.zeros(imgs.shape[0], len(feature_extractor_ensemble), dim)
    for i, feature_extractor_model in enumerate(feature_extractor_ensemble):
        features_model = feature_extractor_model(imgs)
        features[:, i, :] = features_model
    return features
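
# Minimal usage sketch (hedged): `ensemble` as built by get_ensemble above, and
# `imgs` a (N, 3, H, W) image batch; dim=512 matches the backbones' native
# embedding size when no subspace reduction is used.
# >>> feats = extract_features(imgs, ensemble, dim=512)
# >>> feats.shape  # (N, len(ensemble), 512)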

def prepare_models(model_backbones,
                   input_size,
                   model_roots,
                   kernel_size_attack,
                   sigma_attack,
                   combination,
                   using_subspace,
                   V_reduction_root):
    # Map backbone names to constructors so only the requested backbones are
    # instantiated (and repeated names get independent instances).
    backbone_dict = {'IR_50': IR_50, 'IR_152': IR_152,
                     'ResNet_50': ResNet_50, 'ResNet_152': ResNet_152}
    print("Loading Attack Backbone Checkpoint '{}'".format(model_roots))
    print('=' * 20)
    models_attack = []
    for i in range(len(model_backbones)):
        model = backbone_dict[model_backbones[i]](input_size)
        model.load_state_dict(torch.load(model_roots[i], map_location=device))
        models_attack.append(model)
    if using_subspace:
        V_reduction = []
        for i in range(len(model_backbones)):
            V_reduction.append(torch.tensor(np.load(V_reduction_root[i])))
        dim = V_reduction[0].shape[1]
    else:
        V_reduction = None
        dim = 512  # native embedding size of the backbones
    return models_attack, V_reduction, dim
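
# Minimal usage sketch (hedged): the checkpoint paths are hypothetical
# placeholders for downloaded backbone weights.
# >>> models_attack, V_reduction, dim = prepare_models(
# ...     model_backbones=['IR_50', 'ResNet_50'],
# ...     input_size=[112, 112],
# ...     model_roots=['ckpt/ir50.pth', 'ckpt/resnet50.pth'],
# ...     kernel_size_attack=7, sigma_attack=3, combination=True,
# ...     using_subspace=False, V_reduction_root=None)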

def prepare_data(query_data_root, target_data_root, freq, batch_size, warp=False, theta_warp=None):
    data = datasets.ImageFolder(query_data_root, tensor_transform)
    # Every freq-th image goes to the query set; the rest form the gallery.
    subset_query = list(range(0, len(data), freq))
    subset_gallery = [x for x in range(len(data)) if x not in subset_query]
    query_set = torch.utils.data.Subset(data, subset_query)
    gallery_set = torch.utils.data.Subset(data, subset_gallery)
    if target_data_root is not None:
        target_data = datasets.ImageFolder(target_data_root, tensor_transform)
        target_loader = torch.utils.data.DataLoader(target_data, batch_size=batch_size)
    else:
        target_loader = None
    query_loader = torch.utils.data.DataLoader(query_set, batch_size=batch_size)
    gallery_loader = torch.utils.data.DataLoader(gallery_set, batch_size=batch_size)
    return query_loader, gallery_loader, target_loader
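
# Minimal usage sketch (hedged): the data root is a hypothetical placeholder;
# freq=5 sends every 5th image to the query set and the rest to the gallery.
# >>> query_loader, gallery_loader, target_loader = prepare_data(
# ...     'data/people', target_data_root=None, freq=5, batch_size=16)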

def prepare_dir_vec(dir_vec_extractor, imgs, dim, combination):
    dir_vec = extract_features(imgs, dir_vec_extractor, dim).detach().cpu()
    if combination:
        # A combination ensemble holds two extractors per model (filtered and
        # unfiltered), so duplicate each direction vector to match its length.
        dir_vec = torch.repeat_interleave(dir_vec, 2, 1)
    return dir_vec
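
# Minimal usage sketch (hedged): `dir_vec_extractor` is assumed to be a
# non-combination ensemble (one extractor per model), so the result is
# duplicated to line up with a combination ensemble of twice the length.
# >>> dir_vec = prepare_dir_vec(dir_vec_extractor, imgs, dim=512, combination=True)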