# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch
import math
from vlm_eval.attacks.attack import Attack


class FWnucl(Attack):
    def __init__(self, model, *args, iters=200, img_range=(-1, 1), ver=False,
                 targeted=False, eps=5, mask_out='none', **kwargs):
        '''
        Implementation of the nuclear group norm attack.

        args:
            model: Callable, PyTorch classifier.
            iters: Int, number of attack iterations.
            ver: Bool, print progress if True.
            img_range: Tuple of ints/floats, lower and upper bound of image
                       entries.
            targeted: Bool, given label is used as a target label if True.
            eps: Float, radius of the nuclear group norm ball.
            mask_out: 'none', 'context', 'query', or Int; images selected
                      along dimension 1 of the input are excluded from the
                      perturbation.
        '''
super().__init__(model, img_range=img_range, targeted=targeted)
self.iters = iters
self.ver = ver
self.eps = eps
        self.gr = (math.sqrt(5) + 1) / 2  # golden ratio, used in the line search
        self.mask_out = mask_out if mask_out != 'none' else None

    def _set_mask(self, data):
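        '''
        Build a binary mask of the same shape as data that zeroes out the
        images (along dimension 1) selected by self.mask_out, so they receive
        no perturbation: 'context' masks out every image except the last one,
        'query' masks out the last image, an integer masks out that index,
        and None leaves everything unmasked.
        '''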
mask = torch.ones_like(data)
if self.mask_out == 'context':
mask[:, :-1, ...] = 0
elif self.mask_out == 'query':
mask[:, -1, ...] = 0
elif isinstance(self.mask_out, int):
mask[:, self.mask_out, ...] = 0
elif self.mask_out is None:
pass
else:
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
return mask

    def __loss_fn(self, x):
'''
Compute loss depending on self.targeted.
'''
if self.targeted:
return -self.model(x).sum()
else:
return self.model(x).sum()

    def __call__(self, x, *args, **kwargs):
        '''
        Perform the nuclear group norm attack on a batch of images x.

        args:
            x: Tensor of images; dimension 1 indexes the individual
               (context/query) images used for masking, and dimensions 3 to 5
               are interpreted as channels, height, and width. Any additional
               arguments are ignored.

        Returns a tensor of the same shape as x containing adversarial
        examples.
        '''
for param in self.model.model.parameters():
param.requires_grad = False
        x = x.to(self.device)
        mask_out = self._set_mask(x)
noise = torch.zeros_like(x)
noise.requires_grad = True
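        # Frank-Wolfe loop: compute the gradient of the loss w.r.t. the noise
        # (zeroed for masked-out images), solve the LMO over the nuclear group
        # norm ball, pick the step size by golden-section line search, and
        # move to the convex combination (1 - gamma) * noise + gamma * s,
        # which keeps the iterate inside the feasible ball.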
for t in range(self.iters):
if self.ver:
print(f'\rIteration {t+1}/{self.iters}', end='')
loss = self.__loss_fn(x + noise * mask_out)
loss.backward()
noise.grad.data = noise.grad.data * mask_out
s = self.__groupNuclearLMO(noise.grad.data, eps=self.eps)
with torch.no_grad():
gamma = self.__lineSearch(x=x, s=s, noise=noise)
noise = (1 - gamma) * noise + gamma * s
noise.requires_grad = True
if self.ver and t % 20 == 0:
print(f"Iteration: {t}, Loss: {loss.item()}")
        # clamp to the valid image range (assumes the Attack base class
        # stores the img_range passed to it)
        x = torch.clamp(x + noise, *self.img_range)
if self.ver:
print("")
return x.detach()

    def __lineSearch(self, x, s, noise, steps=25):
'''
Perform line search for the step size.
'''
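        # Golden-section search: keep a bracket [a, b] for the step size,
        # evaluate the loss at the interior points c and d, and shrink the
        # bracket towards the side containing the larger loss value.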
a = torch.zeros(x.shape[1], device=self.device).view(-1, 1, 1, 1)
b = torch.ones(x.shape[1], device=self.device).view(-1, 1, 1, 1)
c = b - (b - a) / self.gr
d = a + (b - a) / self.gr
sx = s - noise
for i in range(steps):
loss1 = self.__loss_fn(x + noise + (c * sx).view(*x.shape))
loss2 = self.__loss_fn(x + noise + (d * sx).view(*x.shape))
mask = loss1 > loss2
b[mask] = d[mask]
mask = torch.logical_not(mask)
a[mask] = c[mask]
c = b - (b - a) / self.gr
d = a + (b - a) / self.gr
return (b + a) / 2

    def __groupNuclearLMO(self, x, eps=5):
'''
LMO for the nuclear group norm ball.
'''
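        # x is the 6-D gradient tensor; dimension 1 indexes the images and
        # dimensions 3-5 are channels, height, and width (dimensions 0 and 2
        # must be singleton for the views below). The LMO splits each channel
        # into size x size pixel groups, picks the group whose gradient has
        # the largest nuclear norm summed over the color channels, and returns
        # a perturbation that is zero everywhere except in that group, where
        # it equals eps times the outer product of the leading singular vector
        # pair of the gradient patch.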
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
size = 32 if H > 64 else 4
# turn batch of images into batch of size by size pixel groups per
# color channel
xrgb = [x.view(B, C, H, W)[:, c, :, :] for c in range(C)]
xrgb = [xc.unfold(1, size, size).unfold(2, size, size) for xc in xrgb]
xrgb = [xc.reshape(-1, size, size) for xc in xrgb]
# compute nuclear norm of each patch (sum norms over color channels)
norms = torch.linalg.svdvals(xrgb[0])
for xc in xrgb[1:]:
norms += torch.linalg.svdvals(xc)
norms = norms.sum(-1).reshape(B, -1)
# only keep the patch g* with the largest nuclear norm for each image
idxs = norms.argmax(dim=1).view(-1, 1)
xrgb = [xc.reshape(B, -1, size, size) for xc in xrgb]
xrgb = [xc[torch.arange(B).view(-1, 1), idxs].view(B, size, size)
for xc in xrgb]
        # build the flat pixel indices of the kept patch within each H x W
        # channel: first locate the patch's top-left corner from the patch
        # index, then enumerate the patch entries row by row in image
        # coordinates
off = (idxs % (W / size)).long() * size
off += torch.floor(idxs / (W / size)).long() * W * size
idxs = torch.arange(0, size**2,
device=self.device).view(1, -1).repeat(B, 1) + off
off = torch.arange(0, size,
device=self.device).view(-1, 1).repeat(1, size)
off = off * W - off * size
idxs += off.view(1, -1)
# compute singular vector pairs corresponding to largest singular value
# and final perturbation (LMO solution)
        pert = torch.zeros_like(x).view(B, C, H, W)
        idx = torch.arange(B).view(-1, 1)
        for i, xc in enumerate(xrgb):
            U, _, Vh = torch.linalg.svd(xc)
            u = U[:, :, 0].view(B, size, 1)
            v = Vh.transpose(-2, -1)[:, :, 0].view(B, size, 1)
            pert_gr = torch.bmm(u, v.transpose(-2, -1)).reshape(B, size * size)
            # pert_ch is a view into pert, so this scatter writes the
            # eps-scaled rank-1 patch directly into channel i of pert
            pert_ch = pert[:, i, :, :].view(B, -1)
            pert_ch[idx, idxs] = pert_gr * eps
        return pert.view(*x.shape)
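

# Usage sketch (hypothetical names; the attack only assumes that `model(x)`
# returns a tensor whose sum is used as the attack objective, that the wrapped
# module is reachable as `model.model`, and that dimension 1 of `x` indexes
# the context/query images referenced by `mask_out`):
#
#   attack = FWnucl(vlm_loss_wrapper, iters=200, eps=5, mask_out='context')
#   x_adv = attack(x)  # same shape as x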