# Code adapted from https://github.com/wagnermoritz/GSE
import math

import torch

from vlm_eval.attacks.attack import Attack


class SAIF(Attack):
    def __init__(self, model, *args, targeted=False, img_range=(-1, 1), steps=200,
                 r0=1, ver=False, k=10000, eps=16. / 255., mask_out='none', **kwargs):
        '''
        Adapted from: https://github.com/wagnermoritz/GSE/tree/main
        Implementation of the sparse Frank-Wolfe attack SAIF
        https://arxiv.org/pdf/2212.07495.pdf

        args:
            model: Callable, PyTorch classifier.
            targeted: Bool, given label is used as a target label if True.
            img_range: Tuple of ints/floats, lower and upper bound of image entries.
            steps: Int, number of FW iterations.
            r0: Int, initial exponent for the step size computation.
            ver: Bool, print progress if True.
            k: Int, sparsity parameter; at most k entries are perturbed.
            eps: Float, L-infinity bound on the perturbation magnitude.
            mask_out: 'none', 'context', 'query', or Int. Selects which images
                along dimension 1 are excluded from the perturbation.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.r0 = r0
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.ver = ver
        self.k = k
        self.eps = eps
        self.mask_out = mask_out if mask_out != 'none' else None

    def _set_mask(self, data):
        # Binary mask that zeroes out the images (along dim 1) that must not
        # be perturbed.
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0  # perturb only the last (query) image
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0  # perturb only the context images
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x):
        '''
        Perform the attack on a batch of images x.

        args:
            x: Tensor of shape [1, B, T, C, H, W]; the leading batch
               dimension must be 1.

        Returns a tuple of a tensor of the same shape as x containing the
        adversarial examples, and the L0 norm of the perturbation.
        '''
        assert x.shape[0] == 1, "Only support batch size 1 for now"
        for param in self.model.model.parameters():
            param.requires_grad = False
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        x = x.to(self.device)
        batchidx = torch.arange(B).view(-1, 1)
        mask_out = self._set_mask(x)

        # compute p_0 (perturbation magnitude) and s_0 (sparsity mask) from
        # the gradient of the loss at the clean input
        x_ = x.clone()
        x_.requires_grad = True
        out = self.model(x_)
        loss = -out.sum() if not self.targeted else out.sum()
        x__grad = torch.autograd.grad(loss, [x_])[0].detach() * mask_out
        p = -self.eps * x__grad.sign()
        p = p.detach().half()
        # s_0 selects the k entries with the most negative gradient
        ksmallest = torch.topk(-x__grad.view(B, -1), self.k, dim=1)[1]
        ksmask = torch.zeros((B, C * H * W), device=self.device)
        ksmask[batchidx, ksmallest] = 1
        s = torch.logical_and(ksmask.view(*x.shape), x__grad < 0).float()
        s = s.detach().half()

        r = self.r0
        for t in range(self.steps):
            if self.ver:
                print(f'\rIteration {t + 1}/{self.steps}', end='')
            p.requires_grad = True
            s.requires_grad = True
            D = self.Loss_fn(x, s, p, mask_out)
            D.backward()
            mp = p.grad * mask_out
            ms = s.grad * mask_out

            with torch.no_grad():
                # inf-norm LMO
                v = (-self.eps * mp.sign()).half()
                # 1-norm LMO
                ksmallest = torch.topk(-ms.view(B, -1), self.k, dim=1)[1]
                ksmask = torch.zeros((B, C * H * W), device=self.device)
                ksmask[batchidx, ksmallest] = 1
                ksmask = ksmask.view(*x.shape) * mask_out
                z = torch.logical_and(ksmask, ms < 0).float().half()

                # halve the step size until the FW update makes primal progress
                mu = 1 / (2 ** r * math.sqrt(t + 1))
                while self.Loss_fn(x, s + mu * (z - s), p + mu * (v - p), mask_out) > D:
                    r += 1
                    if r >= 50:
                        break
                    mu = 1 / (2 ** r * math.sqrt(t + 1))

                # convex combination with the LMO solutions
                p = p + mu * (v - p)
                s = s + mu * (z - s)

                # keep x + p inside the valid image range
                x_adv = torch.clamp(x + p, *self.img_range)
                p = x_adv - x

            if self.ver and t % 10 == 0:
                print(f' Loss: {D.item()}')

        if self.ver:
            print('')
        return (x + s * p * mask_out).detach(), torch.norm(s * p, p=0).item()

    def Loss_fn(self, x, s, p, mask_out):
        out = self.model(x + s * p * mask_out).sum()
        return out if self.targeted else -out
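

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original repo).
# It assumes the `Attack` base class stores the wrapped model as `self.model`
# and exposes `self.device`, `self.targeted`, and `self.img_range`, as in the
# GSE reference implementation. `ToyVLM` is a hypothetical stand-in for a VLM
# wrapper: a callable returning one loss/score per image, with the underlying
# network reachable via `.model`. Depending on the PyTorch version, the
# half-precision tensors used by the attack may require a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class ToyVLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = nn.Conv2d(3, 1, kernel_size=3, padding=1)

        def forward(self, x):
            # Flatten [1, B, T, C, H, W] to [B*T, C, H, W] and return one
            # scalar score per image.
            flat = x.view(-1, *x.shape[-3:]).float()
            return self.model(flat).mean(dim=(1, 2, 3))

    attack = SAIF(ToyVLM(), steps=5, k=100, eps=16. / 255., img_range=(0, 1),
                  mask_out='context')
    x = torch.rand(1, 2, 1, 3, 32, 32)  # two images, one frame each
    x_adv, l0 = attack(x)
    print(f'perturbed entries (L0): {l0:.0f}')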