# Code taken and adapted from https://github.com/wagnermoritz/GSE
import math

import torch
import torchvision
import torch.nn.functional as F

from vlm_eval.attacks.attack import Attack


# required input size: batch_size x num_media x num_frames x channels x height x width
class GSEAttack(Attack):
    def __init__(self, model, *args, mask_out='none', ver=False,
                 img_range=(-1, 1), search_steps=4, targeted=False,
                 sequential=False, search_factor=2, gb_size=5, sgm=1.5,
                 mu=1, sigma=0.0025, iters=200, k_hat=10, q=0.25, **kwargs):
        '''
        Implementation of the GSE attack.

        args:
            model:         Callable, PyTorch classifier.
            mask_out:      Masks out context images if set to 'context', query
                           images if set to 'query', and nothing if set to 'none'.
            ver:           Bool, print progress if True.
            img_range:     Tuple of ints/floats, lower and upper bound of image entries.
            search_steps:  Int, number of steps for the section search on the
                           trade-off parameter.
            targeted:      Bool, the given label is used as a target label if True.
            sequential:    Bool, perturbations are computed sequentially for all
                           images in the batch if True. For a fair comparison
                           with the Homotopy attack.
            search_factor: Float, factor for increasing/decreasing the trade-off
                           parameter until an upper/lower bound for the section
                           search is found.
            gb_size:       Odd int, size of the Gaussian blur kernel.
            sgm:           Float, sigma of the Gaussian blur kernel.
            mu:            Float, trade-off parameter for the 2-norm regularization.
            sigma:         Float, step size.
            iters:         Int, number of iterations.
            k_hat:         Int, number of iterations before transitioning to NAG.
            q:             Float, inverse of the increase factor for adjust_lambda.
        '''
        super().__init__(model, img_range=img_range, targeted=targeted)
        self.ver = ver
        self.search_steps = search_steps
        self.sequential = sequential
        self.search_factor = search_factor
        self.gb_size = gb_size
        self.sgm = sgm
        self.mu = mu
        self.sigma = sigma
        self.iters = iters
        self.k_hat = k_hat
        self.q = q
        self.mask_out = mask_out if mask_out != 'none' else None

    def adjust_lambda(self, lam, noise):
        '''
        Adjust the trade-off parameters (lambda) to update the search space:
        lambda shrinks in and around the support of the current perturbation
        and grows by a factor of 1/q everywhere else.
        '''
        # binary map of perturbed pixels, blurred to also cover their neighborhood
        x = noise.detach().clone().abs().mean(dim=1, keepdim=True).sign()
        gb = torchvision.transforms.GaussianBlur((self.gb_size, self.gb_size),
                                                 sigma=self.sgm)
        x = gb(x) + 1
        x = torch.where(x == 1, self.q, x)
        lam /= x[:, 0, :, :]
        return lam
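
    # Illustrative sketch (an addition, not from the original repo): the blurred
    # sign map equals 1 away from the perturbation and lies in (1, 2] on and
    # around its support, so the division above *raises* lambda by 1/q (= 4 for
    # the default q = 0.25) where nothing has been perturbed yet and *lowers*
    # it near existing perturbations; this is what groups the changed pixels.
    # With `attack` a GSEAttack instance, something like:
    #
    #   noise = torch.zeros(1, 3, 32, 32)
    #   noise[0, :, 10:14, 10:14] = 1.0          # one perturbed 4x4 block
    #   lam = attack.adjust_lambda(torch.ones(1, 32, 32), noise)
    #   # lam is 4 far away from the block and < 1 inside it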

    def section_search(self, x, steps=50):
        '''
        Section search for finding the maximal lambda such that the
        perturbation is non-zero after the first iteration.
        '''
        # x has shape batch_size x num_media x num_frames x channels x height x width
        noise = torch.zeros_like(x, requires_grad=True)
        loss = (-self.model(x + noise).sum()
                + self.mu * torch.norm(noise.view(x.size(1), x.size(3), x.size(4), x.size(5)),
                                       p=2, dim=(1, 2, 3)).sum())
        grad = torch.autograd.grad(loss, [noise])[0].detach()
        noise.detach_()
        ones = torch.ones_like(x.view(x.size(1), x.size(3), x.size(4), x.size(5)))[:, 0, :, :]

        def num_nonzero(bound):
            # "0-norm" of the first prox step for a given trade-off bound
            return torch.norm(self.prox(grad.clone().view(x.size(1), x.size(3), x.size(4), x.size(5)) * self.sigma,
                                        ones * bound * self.sigma),
                              p=0, dim=(1, 2, 3))

        # define upper and lower bounds for the section search
        lb = torch.zeros((x.size(1),), dtype=torch.float, device=self.device).view(-1, 1, 1)
        ub = lb.clone() + 0.001
        mask = num_nonzero(ub) != 0
        while mask.any():
            ub[mask] *= 2
            mask = num_nonzero(ub) != 0

        # perform the search: bisect between lb (perturbation non-zero)
        # and ub (perturbation zero)
        for _ in range(steps):
            cur = (ub + lb) / 2
            mask = num_nonzero(cur) == 0
            ub[mask] = cur[mask]
            mask = torch.logical_not(mask)
            lb[mask] = cur[mask]

        cur = (lb + ub).view(-1) / 2
        return 0.01 * cur

    def __call__(self, x, y, *args, **kwargs):
        '''
        Call the attack for a batch of images x, or sequentially for all
        images in x depending on self.sequential.

        args:
            x: Tensor of shape [B, N, F, C, H, W], batch of images.
            y: Tensor of shape [B], batch of labels.

        Returns a tensor of the same shape as x containing adversarial examples.
        '''
        if self.sequential:
            result = x.clone()
            for i, (x_, y_) in enumerate(zip(x, y)):
                result[i] = self.perform_att(x_.unsqueeze(0), y_.unsqueeze(0),
                                             mu=self.mu, sigma=self.sigma,
                                             k_hat=self.k_hat).detach()
            return result
        else:
            return self.perform_att(x, y, mu=self.mu, sigma=self.sigma, k_hat=self.k_hat)

    def _set_mask(self, data):
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask
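
    # Hedged usage sketch (illustrative, not from the original repo): with
    # mask_out='context' the mask zeroes every media image except the last one
    # along dim 1, so only the query image receives a perturbation:
    #
    #   x = torch.zeros(2, 3, 1, 3, 224, 224)    # B x N_media x F x C x H x W
    #   m = attack._set_mask(x)                  # `attack` is a GSEAttack
    #   # m[:, :-1].sum() == 0 and m[:, -1].min() == 1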

    def perform_att(self, x, y, mu, sigma, k_hat):
        '''
        Perform the GSE attack on a batch of images x with corresponding
        labels y. (y is accepted for a uniform interface but unused by the
        untargeted loss below.)
        '''
        x = x.to(self.device)
        # input has shape batch_size x num_media x num_frames x channels x height x width;
        # B below is the number of media images, batch and frame dims are assumed to be 1
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        lams = self.section_search(x)
        mask_out = self._set_mask(x).view(B, C, H, W)

        # save x and lams for resetting them at the beginning of every
        # section search step
        save_x = x.clone()
        save_lams = lams.clone()
        # upper and lower bounds for the section search
        ub_lams = torch.full_like(lams, torch.inf)
        lb_lams = torch.full_like(lams, 0.0)
        # tensor for saving successful adversarial examples in the inner loop
        result = x.clone()
        # tensor for saving the best adversarial example so far
        result2 = x.clone()
        best_l0 = torch.full((B,), torch.inf, device=self.device).type(x.type())

        # section search
        for step in range(self.search_steps):
            x = save_x.clone()
            lams = save_lams.clone()
            lam = torch.ones_like(x.view(B, C, H, W))[:, 0, :, :] * lams.view(-1, 1, 1)
            # tensor for tracking for which images adv. examples have been found
            active = torch.ones(B, dtype=torch.bool, device=self.device)
            # set the initial perturbation to zero
            noise = torch.zeros_like(x, requires_grad=True)
            noise_old = noise.clone()
            lr = 1

            # attack
            for j in range(self.iters):
                if self.ver:
                    print(f'\rSearch step {step + 1}/{self.search_steps}, '
                          + f'Prox.Grad. Iteration {j + 1}/{self.iters}, '
                          + f'Images left: {x.shape[1]}', end='')
                if len(x) == 0:
                    break
                self.model.model.zero_grad()
                loss = (-self.model(x + noise).sum()
                        + mu * (torch.norm(noise.view(B, C, H, W), p=2, dim=(1, 2, 3)) ** 2).sum())
                noise_grad_data = torch.autograd.grad(loss, [noise])[0].detach().view(B, C, H, W)
                with torch.no_grad():
                    noise_grad_data = noise_grad_data * mask_out  # mask_out has shape B x C x H x W
                    lr_ = (1 + math.sqrt(1 + 4 * lr ** 2)) / 2
                    if j == k_hat:
                        # freeze the support: pixels whose lambda grew beyond its
                        # initial value stay unperturbed from here on
                        lammask = (lam > lams.view(-1, 1, 1))[:, None, :, :]
                        lammask = lammask.repeat(1, C, 1, 1)
                        noise_old = noise.clone()
                    if j < k_hat:
                        # proximal gradient phase
                        noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
                        noise = self.prox(noise.view(B, C, H, W), lam * sigma).view(1, B, 1, C, H, W)
                        noise_tmp = noise.clone()
                        noise = lr / lr_ * noise + (1 - (lr / lr_)) * noise_old
                        noise_old = noise_tmp.clone()
                        lam = self.adjust_lambda(lam, noise.view(B, C, H, W))
                    else:
                        # NAG phase on the frozen support
                        noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
                        noise_tmp = noise.clone()
                        noise = lr / lr_ * noise + (1 - (lr / lr_)) * noise_old
                        noise_old = noise_tmp.clone()
                        noise[lammask.view(1, B, 1, C, H, W)] = 0
                    # clamp the adv. example to the valid range
                    x_adv = torch.clamp(x + noise, *self.img_range)
                    noise = x_adv - x
                    lr = lr_
                noise.requires_grad = True

            # section search bookkeeping:
            # no adv. example found => decrease the upper bound and the current lambda;
            # adv. example found => save it if its "0-norm" is better than that of the
            # previous adv. example, increase the lower bound and the current lambda.
            # NOTE: `active` is never updated in this adaptation, so only the
            # first branch is taken.
            for i in range(B):
                if active[i]:
                    ub_lams[i] = save_lams[i]
                    save_lams[i] = 0.95 * lb_lams[i] + 0.05 * save_lams[i]
                else:
                    l0 = self.l20((result[i] - save_x[i]).unsqueeze(0)).to(self.device)
                    if l0 < best_l0[i]:
                        best_l0[i] = l0
                        result2[i] = result[i].clone()
                    if torch.isinf(ub_lams[i]):
                        lb_lams[i] = save_lams[i]
                        save_lams[i] *= self.search_factor
                    else:
                        lb_lams[i] = save_lams[i]
                        save_lams[i] = (ub_lams[i] + save_lams[i]) / 2

        if self.ver:
            print('')
        return x_adv

    def extract_patches(self, x):
        '''
        Extracts and returns all overlapping size x size patches from the
        image batch x.
        '''
        B, C, _, _ = x.shape
        size = 8
        # identity kernel: one output channel per pixel of a patch
        kernel = torch.eye(size ** 2).view(size ** 2, 1, size, size)
        kernel = kernel.repeat(C, 1, 1, 1).to(x.device)
        out = F.conv2d(x, kernel, groups=C)
        out = out.view(B, C, size, size, -1)
        out = out.permute(0, 4, 1, 2, 3)
        return out.contiguous()

    def l20(self, x):
        '''
        Computes d_{2,0}(x[i]) for all perturbations x[i] in the batch x as
        described in section 3.2.
        '''
        B, N, M, C, _, _ = x.shape
        l20s = []
        for b in range(B):
            for n in range(N):
                for m in range(M):
                    x_ = x[b, n, m]  # select the specific perturbation x[b, n, m]
                    # unsqueeze to get the 4D input extract_patches expects
                    patches = self.extract_patches(x_.unsqueeze(0))
                    l2s = torch.norm(patches, p=2, dim=(2, 3, 4))
                    l20s.append((l2s != 0).float().sum().item())
        return torch.tensor(l20s)
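
    # Worked example (an addition for illustration): d_{2,0} counts the 8x8
    # patches whose perturbation is non-zero, so a single perturbed pixel in
    # the interior of the image is counted once for each of the 8*8 = 64
    # overlapping patches containing it:
    #
    #   delta = torch.zeros(1, 1, 1, 3, 32, 32)  # B x N x F x C x H x W
    #   delta[..., 16, 16] = 1.0                 # one non-zero pixel
    #   # attack.l20(delta) -> tensor([64.])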

    def prox(self, grad_loss_noise, lam):
        '''
        Computes the proximal operator of the 1/2-norm of the gradient of the
        adversarial loss w.r.t. the current noise.
        '''
        lam = lam[:, None, :, :]
        sh = list(grad_loss_noise.shape)
        lam = lam.expand(*sh)
        p_lam = (54 ** (1 / 3) / 4) * lam ** (2 / 3)

        mask1 = (grad_loss_noise > p_lam)
        mask2 = (torch.abs(grad_loss_noise) <= p_lam)
        mask3 = (grad_loss_noise < -p_lam)
        mask4 = mask1 | mask3

        # phi is only valid (arccos argument in [-1, 1]) on mask4; entries
        # outside that mask are overwritten below, so NaNs there are harmless
        phi_lam_x = torch.arccos((lam / 8) * (torch.abs(grad_loss_noise) / 3) ** (-1.5))
        grad_loss_noise[mask4] = ((2 / 3) * torch.abs(grad_loss_noise[mask4])
                                  * (1 + torch.cos((2 * math.pi) / 3
                                                   - (2 * phi_lam_x[mask4]) / 3))).to(torch.float32)
        grad_loss_noise[mask3] = -grad_loss_noise[mask3]
        grad_loss_noise[mask2] = 0
        return grad_loss_noise
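

# Minimal sanity check for the prox operator above (an addition, not part of
# the original repo; it assumes the module's imports resolve in your
# environment and relies on `prox` ignoring `self`, which lets us call it
# unbound). Entries with magnitude at most (54**(1/3)/4) * lam**(2/3) must map
# to exactly zero, and surviving entries must keep their sign.
if __name__ == '__main__':
    g = torch.randn(2, 3, 8, 8) * 2
    lam = torch.full((2, 8, 8), 0.5)
    out = GSEAttack.prox(None, g.clone(), lam)  # prox never touches `self`
    thresh = (54 ** (1 / 3) / 4) * 0.5 ** (2 / 3)
    assert torch.all(out[g.abs() <= thresh] == 0)
    nz = out != 0
    assert torch.all(out.sign()[nz] == g.sign()[nz])
    print('prox sanity check passed')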