# Attack module built on features extracted from pre-trained models
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable

from util.feature_extraction_utils import warp_image, normalize_batch
from util.prepare_utils import get_ensemble, extract_features
from lpips_pytorch import LPIPS
from tqdm import trange

tensor_transform = transforms.ToTensor()
pil_transform = transforms.ToPILImage()
class Attack(nn.Module):
    def __init__(
        self,
        models,
        dim,
        attack_type,
        eps,
        c_sim=0.5,
        net_type="alex",
        lr=0.05,
        n_iters=100,
        noise_size=0.001,
        n_starts=10,
        c_tv=None,
        sigma_gf=None,
        kernel_size_gf=None,
        combination=False,
        warp=False,
        theta_warp=None,
        V_reduction=None,
    ):
        super(Attack, self).__init__()
        self.extractor_ens = get_ensemble(
            models, sigma_gf, kernel_size_gf, combination, V_reduction, warp, theta_warp
        )
        # print("There are '{}' models in the attack ensemble".format(len(self.extractor_ens)))
        self.dim = dim
        self.eps = eps
        self.c_sim = c_sim
        self.net_type = net_type
        self.lr = lr
        self.n_iters = n_iters
        self.noise_size = noise_size
        self.n_starts = n_starts
        self.c_tv = c_tv
        self.attack_type = attack_type
        self.warp = warp
        self.theta_warp = theta_warp
        if self.attack_type == "lpips":
            self.lpips_loss = LPIPS(self.net_type)
    def execute(self, images, dir_vec, direction):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # the LPIPS module only exists for the "lpips" attack type
        if self.attack_type == "lpips":
            self.lpips_loss.to(device)
        images = Variable(images).to(device)
        dir_vec = dir_vec.to(device)
        # take norm wrt dim
        dir_vec_norm = dir_vec.norm(dim=2).unsqueeze(2).to(device)
        dist = torch.zeros(images.shape[0]).to(device)
        adv_images = images.detach().clone()
        if self.warp:
            self.face_img = warp_image(images, self.theta_warp)

        for start in range(self.n_starts):
            # keep the previous restart's adversarial images and distances
            adv_images_old = adv_images.detach().clone()
            dist_old = dist.clone()
            # random restart: initialize with uniform noise in (-noise_size, noise_size)
            noise_uniform = Variable(
                2 * self.noise_size * torch.rand(images.size()) - self.noise_size
            ).to(device)
            adv_images = Variable(
                images.detach().clone() + noise_uniform, requires_grad=True
            ).to(device)

            for i in trange(self.n_iters):
                adv_features = extract_features(
                    adv_images, self.extractor_ens, self.dim
                ).to(device)
                # feature-space objective, normalized by the target direction norm
                loss = direction * torch.mean(
                    (adv_features - dir_vec) ** 2 / dir_vec_norm
                )
                if self.c_tv is not None:
                    tv_out = self.total_var_reg(images, adv_images)
                    loss -= self.c_tv * tv_out
                if self.attack_type == "lpips":
                    lpips_out = self.lpips_reg(images, adv_images)
                    loss -= self.c_sim * lpips_out

                # signed-gradient ascent step on the loss
                grad = torch.autograd.grad(loss, [adv_images])
                adv_images = adv_images + self.lr * grad[0].sign()
                perturbation = adv_images - images
                if self.attack_type == "sgd":
                    # project the perturbation back into the l_inf ball of radius eps
                    perturbation = torch.clamp(
                        perturbation, min=-self.eps, max=self.eps
                    )
                adv_images = images + perturbation
                adv_images = torch.clamp(adv_images, min=0, max=1)

            adv_features = extract_features(
                adv_images, self.extractor_ens, self.dim
            ).to(device)
            dist = torch.mean((adv_features - dir_vec) ** 2 / dir_vec_norm, dim=[1, 2])
            # keep, per image, the restart that moved farthest in the desired direction
            if direction == 1:
                adv_images[dist < dist_old] = adv_images_old[dist < dist_old]
                dist[dist < dist_old] = dist_old[dist < dist_old]
            else:
                adv_images[dist > dist_old] = adv_images_old[dist > dist_old]
                dist[dist > dist_old] = dist_old[dist > dist_old]

        return adv_images.detach().cpu()
    def lpips_reg(self, images, adv_images):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.warp:
            # average the LPIPS penalty over the warped face crop and the full image
            face_adv = warp_image(adv_images, self.theta_warp)
            lpips_out = self.lpips_loss(
                normalize_batch(self.face_img).to(device),
                normalize_batch(face_adv).to(device),
            )[0][0][0][0] / (2 * adv_images.shape[0])
            lpips_out += self.lpips_loss(
                normalize_batch(images).to(device),
                normalize_batch(adv_images).to(device),
            )[0][0][0][0] / (2 * adv_images.shape[0])
        else:
            lpips_out = (
                self.lpips_loss(
                    normalize_batch(images).to(device),
                    normalize_batch(adv_images).to(device),
                )[0][0][0][0]
                / adv_images.shape[0]
            )
        return lpips_out
    def total_var_reg(self, images, adv_images):
        # total variation of the perturbation, averaged over both spatial axes
        perturbation = adv_images - images
        tv = torch.mean(
            torch.abs(perturbation[:, :, :, :-1] - perturbation[:, :, :, 1:])
        ) + torch.mean(
            torch.abs(perturbation[:, :, :-1, :] - perturbation[:, :, 1:, :])
        )
        return tv
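
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The model list,
# feature dimension, image size, and the shape of dir_vec (batch x ensemble
# size x feature dim, matching the dim=2 norm taken in execute) are
# illustrative assumptions; substitute the values expected by the project's
# get_ensemble / extract_features helpers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    attack = Attack(
        models=["resnet50"],   # hypothetical ensemble specification
        dim=512,               # assumed feature dimension
        attack_type="sgd",     # l_inf-projected variant; "lpips" adds the LPIPS penalty
        eps=0.05,
        lr=0.01,
        n_iters=50,
        n_starts=1,
    )
    images = torch.rand(4, 3, 112, 112)  # batch of RGB images in [0, 1]
    dir_vec = torch.randn(4, 1, 512)     # per-image target feature direction
    # direction=1 increases the normalized feature distance to dir_vec
    adv = attack.execute(images, dir_vec, direction=1)
    print(adv.shape)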