# Code taken and adapted from https://github.com/wagnermoritz/GSE
import math

import torch
import torch.nn.functional as F

from vlm_eval.attacks.attack import Attack


class StrAttack(Attack):
    '''
    Structured (group-wise) adversarial attack optimised with ADMM.

    Implementation of StrAttack: https://arxiv.org/abs/1808.01664
    Adapted from https://github.com/KaidiXu/StrAttack
    '''

    def __init__(self, model, *args, targeted=False, img_range=(0, 1), kappa=0,
                 max_iter=100, ver=False, search_steps=2, max_c=1e10, rho=1,
                 mask_out='none', c=2.5, retrain=False, **kwargs):
        '''
        args:
        model:         Callable model. __call__ disables gradients on
                       `model.model.parameters()`, so that attribute must exist.
        targeted:      Bool. If True the summed model output is minimised,
                       otherwise it is maximised (see __get_z_grad).
        img_range:     Tuple of ints/floats, lower and upper bound of image
                       entries.
        kappa:         Accepted for interface compatibility; unused here.
        max_iter:      Int, number of ADMM iterations per search step.
        ver:           Bool, print progress if True.
        search_steps:  Int, number of binary search steps.
        max_c:         Float, upper bound for the regularization parameter.
        rho:           Float, ADMM penalty parameter.
        mask_out:      'none' (perturb everything), 'context', 'query', or an
                       int index along dim 1 of the input; see _set_mask.
        c:             Float, initial regularization parameter.
        retrain:       Bool, run the pruning/retraining phase after the
                       main ADMM search.
        *args, **kwargs: accepted and ignored.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.max_iter = max_iter
        self.ver = ver
        self.search_steps = search_steps
        self.max_c = max_c
        self.rho = rho
        self.c = c
        self.retrain = retrain
        # The string 'none' is normalised to None so _set_mask can test it.
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None

    def _set_mask(self, data):
        '''
        Build a {0, 1} mask with the shape of `data` selecting which entries
        may be perturbed. Masking is applied along dim 1:
          'context' -> every index of dim 1 except the last is zeroed,
          'query'   -> the last index of dim 1 is zeroed,
          int k     -> index k of dim 1 is zeroed,
          None      -> nothing is zeroed.
        Raises NotImplementedError for any other value of self.mask_out.
        '''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, imgs, *args, **kwargs):
        '''
        Perform StrAttack on a batch of images.

        args:
        imgs:      Rank-6 tensor. Dim 1 is treated as the batch, and the group
                   norm is computed on a view of shape
                   (sh[1], sh[3], sh[4], sh[5]), so sh[0] * sh[2] must be 1
                   (e.g. [1, N, 1, C, H, W]). Moved to self.device.
        *args, **kwargs: accepted and ignored.

        Returns a detached tensor with the shape of imgs containing the
        adversarial images, multiplied by the perturbation mask.
        NOTE(review): because of that final multiplication, entries of
        masked-out images are returned as zeros rather than clean pixels;
        confirm callers expect this.
        '''

        # Only gradients w.r.t. the perturbation are needed, not the weights.
        for param in self.model.model.parameters():
            param.requires_grad = False

        c_ = self.c
        imgs = imgs.to(self.device)
        sh = imgs.shape
        batch_size = sh[1]
        mask_out = self._set_mask(imgs)

        alpha, tau, gamma = 5, 2, 1
        # Per-entry perturbation bound of 1.0, zero wherever the mask blocks
        # perturbations.
        eps = torch.full_like(imgs, 1.0) * mask_out
        # Group (filter) size: 8 when the last dim exceeds 32, otherwise 2.
        filterSize = 8 if sh[-1] > 32 else 2
        stride = filterSize
        # convolution kernel used to compute norm of each group
        slidingM = torch.ones((1, sh[3], filterSize, filterSize), device=self.device)

        # Binary-search state for the regularization constant c.
        cs = torch.ones(batch_size, device=self.device) * c_
        lower_bound = torch.zeros(batch_size)
        upper_bound = torch.ones(batch_size) * self.max_c

        # Best-so-far bookkeeping across all search steps.
        o_bestl2 = torch.full_like(torch.randn(batch_size), 1e10, dtype=torch.float)
        o_bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)
        o_bestattack = imgs.clone()
        o_besty = torch.ones_like(imgs)

        for step in range(self.search_steps):

            bestl2 = torch.full_like(o_bestl2, 1e10, dtype=torch.float)
            bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)

            # ADMM variables: z (consensus) and duals u, v, s.
            z, v, u, s = (torch.zeros_like(imgs) for _ in range(4))

            for iter_ in range(self.max_iter):
                if (not iter_ % 10 or iter_ == self.max_iter - 1) and self.ver:
                    print(f'\rIteration: {iter_+1}/{self.max_iter}, ' +
                          f'Search Step: {step+1}/{self.search_steps}', end='')

                # first update step (7) / Proposition 1
                delta = self.rho / (self.rho + 2 * gamma) * (z - u / self.rho)

                # w: b clipped to the feasible box [img_range[0] - imgs, img_range[1] - imgs]
                # intersected with [-eps, eps].
                b = (z - s / self.rho) * mask_out
                tmp = torch.minimum(self.img_range[1] - imgs, eps)
                # NOTE(review): the original author flagged a shape issue on this line
                # (1x5x'5'x3x224x224 instead of 1x5x1x3x224x224).
                w = torch.where(b.view(*sh) > tmp.view(*sh), tmp, b)
                tmp = torch.maximum(self.img_range[0] - imgs, -eps)
                w = torch.where(b.view(*sh) < tmp.view(*sh), tmp, w)

                # y: group-wise soft thresholding of c, using per-group L2 norms
                # computed by a strided all-ones convolution.
                c = z - v / self.rho
                cNorm = torch.sqrt(F.conv2d(c.view(sh[1], sh[3], sh[4], sh[5]) ** 2,
                                            slidingM, stride=stride))
                # Avoid division by zero for all-zero groups.
                cNorm = torch.where(cNorm == 0, torch.full_like(cNorm, 1e-12), cNorm)
                # Broadcast each group norm back to its filterSize x filterSize patch.
                cNorm = F.interpolate(cNorm, scale_factor=filterSize)
                y = torch.clamp((1 - tau / (self.rho * cNorm.unsqueeze(0).unsqueeze(3))), 0) * c

                # second update step (8) / equation (15)
                z_grads = self.__get_z_grad(imgs, z.clone(), cs)
                eta = alpha * math.sqrt(iter_ + 1)
                coeff = (1 / (eta + 3 * self.rho))
                z = coeff * (eta * z + self.rho * (delta + w + y) + u + s + v - z_grads)

                # third update step (9)
                u = u + self.rho * (delta - z) * mask_out
                v = v + self.rho * (y - z) * mask_out
                s = s + self.rho * (w - z) * mask_out
                # get info for binary search
                x = imgs + y * mask_out
                l2s = torch.sum((z ** 2).reshape(z.size(1), -1), dim=-1)

                for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
                    if l2 < bestl2[i]:
                        bestl2[i] = l2
                    if l2 < o_bestl2[i]:
                        o_bestl2[i] = l2
                        o_bestattack[:, i] = x_.detach().unsqueeze(0).clone()
                        o_besty[:, i] = y[:, i]

            # No success check is performed here, so c only ever grows
            # (bisection once an upper bound < 1e9 exists, otherwise x5).
            for i in range(batch_size):
                lower_bound[i] = max(lower_bound[i], cs[i])
                if upper_bound[i] < 1e9:
                    cs[i] = (lower_bound[i] + upper_bound[i]) / 2
                else:
                    cs[i] *= 5

        del v, u, s, z_grads, w, tmp

        if self.retrain:
            cs = torch.full_like(o_bestl2, 5.0, dtype=torch.float)
            zeros = torch.zeros_like(imgs)

            for step in range(8):
                bestl2 = torch.full_like(cs, 1e10, dtype=torch.float, device=self.device)
                bestscore = torch.full_like(cs, -1, dtype=torch.float, device=self.device)

                # Pruning threshold: the 3% quantile of the non-zero magnitudes
                # of the best perturbation; entries at or below it are fixed to 0.
                Nz = o_besty[o_besty != 0]
                e0 = torch.quantile(Nz.abs(), 0.03)
                A2 = torch.where(o_besty.abs() <= e0, 0, 1)
                z1 = o_besty
                u1 = torch.zeros_like(imgs)
                tmpc = self.rho / (self.rho + gamma / 100)

                for j in range(100):
                    if self.ver and not j % 10:
                        print(f'\rRetrain iteration: {step+1}/8, ' +
                              f'Search Step: {j+1}/100', end='')

                    # deltA: zero the pruned entries, clip the rest to the box.
                    tmpA = (z1 - u1) * tmpc
                    tmpA1 = torch.where(o_besty.abs() <= e0, zeros, tmpA)
                    cond = torch.logical_and(tmpA > torch.minimum(self.img_range[1] - imgs, eps),
                                             o_besty.abs() > e0)
                    tmpA2 = torch.where(cond, torch.minimum(self.img_range[1] - imgs, eps), tmpA1)
                    cond = torch.logical_and(tmpA < torch.maximum(self.img_range[0] - imgs, -eps),
                                             o_besty.abs() > e0)
                    deltA = torch.where(cond, torch.maximum(self.img_range[0] - imgs, -eps), tmpA2)

                    x = imgs + deltA * mask_out
                    grad = self.__get_z_grad(imgs, deltA, cs)

                    # Linearised z1-update: weighted sum of the proximal term
                    # (alpha * z1) and the ADMM term (rho * (deltA + u1)), minus the
                    # gradient restricted to unpruned entries.
                    stepsize = 1 / (alpha + 2 * self.rho)
                    z1 = stepsize * (alpha * z1 + self.rho * (deltA + u1) - grad * A2)
                    u1 = u1 + deltA - z1

                    # NOTE(review): `l2s` is not recomputed in this loop; it still
                    # holds the last main-loop values, which o_bestl2 already
                    # accounts for, so the o_bestl2 branch below cannot fire as
                    # written. TODO: decide whether to recompute it from deltA.
                    for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
                        if l2 < bestl2[i]:
                            bestl2[i] = l2
                        if l2 < o_bestl2[i]:
                            o_bestl2[i] = l2
                            o_bestattack[:, i] = x_.detach().unsqueeze(0).clone()
                            # Batch lives on dim 1, matching the main loop.
                            o_besty[:, i] = deltA[:, i]

                # bestscore is never updated above, so the first branch is
                # currently never taken.
                for i in range(batch_size):
                    if (bestscore[i] != -1 and bestl2[i] == o_bestl2[i]):
                        upper_bound[i] = min(upper_bound[i], cs[i])
                        if upper_bound[i] < 1e9:
                            cs[i] = (lower_bound[i] + upper_bound[i]) / 2
                    else:
                        lower_bound[i] = max(lower_bound[i], cs[i])
                        if upper_bound[i] < 1e9:
                            cs[i] = (lower_bound[i] + upper_bound[i]) / 2
                        else:
                            cs[i] *= 5

        if self.ver:
            print('')

        return (o_bestattack * mask_out).detach()

    def __get_z_grad(self, imgs, z, cs):
        '''
        Compute and return gradient of loss wrt. z.

        loss = mean(cs) * sum(model(imgs + z)) if targeted, else its negation.
        NOTE(review): the model output is presumably a per-sample loss; confirm.
        `z` is flagged as requiring grad for the backward pass and detached in
        place afterwards, so callers should pass a tensor they may modify.
        '''
        z.requires_grad = True
        tmp = self.model(z + imgs).sum() if self.targeted else -self.model(z + imgs).sum()
        loss = torch.mean(cs.to(self.device) * tmp)
        z_grad_data = torch.autograd.grad(loss, [z])[0].detach()
        z.detach_()
        return z_grad_data