# Code adapted from https://github.com/chs20/RobustVLM/tree/main
import torch
import math


class APGD:
    """Masked APGD attack: a thin wrapper around :func:`apgd`.

    ``model`` must return a loss tensor with one entry per batch element. The
    entries are summed before differentiation, so only batch size 1 is
    supported (``apgd`` asserts this).

    Args:
        model: callable mapping an input tensor to the loss to maximise.
        norm: threat model, one of ``'linf'``/``'Linf'``, ``'l2'``/``'L2'``,
            ``'l1'``/``'L1'``, ``'l0'``/``'L0'``.
        eps: radius of the threat model.
        mask_out: which entries along dim 1 are masked out, i.e. receive a
            zero gradient and no random initial noise. ``'context'`` masks all
            but the last entry, ``'query'`` masks the last entry, an ``int``
            masks that index, and ``'none'`` masks nothing.
        initial_stepsize: initial step size as a multiple of ``eps``
            (``alpha`` in :func:`apgd`).
        decrease_every: fraction of the iterations between step-size checks
            (Linf/L2 only; :func:`apgd` defaults to 0.22).
        decrease_every_max: minimum fraction between step-size checks
            (``n_iter_min`` in :func:`apgd`; defaults to 0.06).
        random_init: start from a random point of the threat model
            (Linf/L2 only).
    """

    def __init__(self, model, norm, eps, mask_out='context', initial_stepsize=None,
                 decrease_every=None, decrease_every_max=None, random_init=False):
        self.model = model
        self.norm = norm
        self.eps = eps
        self.initial_stepsize = initial_stepsize
        self.decrease_every = decrease_every
        self.decrease_every_max = decrease_every_max
        self.random_init = random_init
        # 'none' is stored as None so that _set_mask leaves the mask all ones.
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None

    def perturb(self, data_clean, iterations, pert_init=None, verbose=False):
        """Attack ``data_clean`` for ``iterations`` steps; return the highest-loss point found."""
        mask = self._set_mask(data_clean)
        data_adv, _, _ = apgd(
            self.model, data_clean, norm=self.norm, eps=self.eps,
            n_iter=iterations, use_rs=self.random_init, mask=mask,
            alpha=self.initial_stepsize, n_iter_2=self.decrease_every,
            n_iter_min=self.decrease_every_max, pert_init=pert_init,
            verbose=verbose,
        )
        return data_adv

    def _set_mask(self, data):
        """Return a 0/1 tensor shaped like ``data``; zeros mark entries (along dim 1) not to attack."""
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __str__(self):
        return 'APGD'


def L1_projection(x2, y2, eps1):
    """Compute the correction that projects ``x2 + y2`` onto the L1 ball ∩ [0, 1] box.

    Args:
        x2: center of the L1 ball; flattened to (bs, -1) internally.
        y2: current perturbation (``x2 + y2`` is the point to be projected).
        eps1: radius of the L1 ball.

    Returns:
        ``delta`` with the shape of ``x2`` such that
        ``||y2 + delta||_1 <= eps1`` and ``0 <= x2 + y2 + delta <= 1``.
    """
    x = x2.clone().float().view(x2.shape[0], -1)
    y = y2.clone().float().view(y2.shape[0], -1)
    sigma = y.clone().sign()
    # u <= 0 is the (negative) amount by which each coordinate violates the box.
    u = torch.min(1 - x - y, x + y)
    u = torch.min(torch.zeros_like(y), u)
    l = -torch.clone(y).abs()
    d = u.clone()

    # Sorted breakpoints of the piecewise-linear L1 norm as a function of the
    # soft-threshold level; the threshold is then found by binary search.
    bs, indbs = torch.sort(-torch.cat((u, l), 1), dim=1)
    bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1, device=bs.device)), 1)

    inu = 2 * (indbs < u.shape[1]).float() - 1
    size1 = inu.cumsum(dim=1)

    s1 = -u.sum(dim=1)

    c = eps1 - y.clone().abs().sum(dim=1)
    c5 = s1 + c < 0
    # Rows that are still outside the ball after box clipping.
    c2 = c5.nonzero().squeeze(1)

    s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1)

    if c2.numel() != 0:
        lb = torch.zeros_like(c2).float()
        ub = torch.ones_like(lb) * (bs.shape[1] - 1)

        # Computed in double precision: a float32 round-trip of a large
        # breakpoint count can under-count the required iterations.
        nitermax = math.ceil(math.log2(bs.shape[1]))
        counter = 0

        while counter < nitermax:
            counter4 = torch.floor((lb + ub) / 2.)
            # .long() keeps the index on the same device as c2.
            counter2 = counter4.long()

            c8 = s[c2, counter2] + c[c2] < 0
            ind3 = c8.nonzero().squeeze(1)
            ind32 = (~c8).nonzero().squeeze(1)
            if ind3.numel() != 0:
                lb[ind3] = counter4[ind3]
            if ind32.numel() != 0:
                ub[ind32] = counter4[ind32]

            counter += 1

        lb2 = lb.long()
        alpha = (-s[c2, lb2] - c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2]
        d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -l[c2])

    return (sigma * d).view(x2.shape)


def L0_projection(x_adv, x, eps, step_size, lam=0.01):
    """Proximal step for a ``lam * ||.||_0`` penalty under an eps box and the [0, 1] range.

    Each coordinate of ``x_adv - x`` is first clipped to ``[-eps, eps]`` and
    to the [0, 1] image range. It is then kept only when keeping it lowers the
    per-coordinate objective ``0.5 * (z - pert)**2 + step_size * lam * [z != 0]``
    compared with setting it to zero; otherwise it is set to zero.
    """
    pert = x_adv - x
    pert_proj = torch.clamp(pert, -eps, eps)
    x_adv_temp = torch.clamp(x + pert_proj, 0., 1.)
    pert_proj = x_adv_temp - x
    keep = pert ** 2 - (pert_proj - pert) ** 2 > 2 * step_size * lam
    pert = torch.where(keep, pert_proj, torch.zeros_like(pert_proj))
    return torch.clamp(x + pert, 0.0, 1.0)


def L1_norm(x, keepdim=False):
    """Per-sample L1 norm over all non-batch dims (broadcastable shape if ``keepdim``)."""
    z = x.abs().view(x.shape[0], -1).sum(-1)
    if keepdim:
        z = z.view(-1, *[1] * (len(x.shape) - 1))
    return z


def L2_norm(x, keepdim=False):
    """Per-sample L2 norm over all non-batch dims (broadcastable shape if ``keepdim``)."""
    z = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
    if keepdim:
        z = z.view(-1, *[1] * (len(x.shape) - 1))
    return z


def L0_norm(x):
    """Per-sample count of non-zero entries over all non-batch dims."""
    return (x != 0.).view(x.shape[0], -1).sum(-1)


def dlr_loss(x, y, reduction='none'):
    """Untargeted DLR loss per sample for scores ``x`` of shape (batch, n_classes).

    ``reduction`` is accepted for API compatibility but ignored.
    """
    x_sorted, ind_sorted = x.sort(dim=1)
    ind = (ind_sorted[:, -1] == y).float()

    return -(x[torch.arange(x.shape[0]), y] - x_sorted[:, -2] * ind -
             x_sorted[:, -1] * (1. - ind)) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)


def dlr_loss_targeted(x, y, y_target):
    """Targeted DLR loss per sample for scores ``x`` of shape (batch, n_classes)."""
    x_sorted, ind_sorted = x.sort(dim=1)
    u = torch.arange(x.shape[0])

    return -(x[u, y] - x[u, y_target]) / (x_sorted[:, -1] - .5 * (
        x_sorted[:, -3] + x_sorted[:, -4]) + 1e-12)


def check_oscillation(x, j, k, y5, k3=0.75):
    """Flag (1.0) each column of ``x`` whose value rose in at most ``k * k3`` of the last ``k`` steps before ``j``.

    ``y5`` is unused and kept for API compatibility.
    """
    t = torch.zeros(x.shape[1]).to(x.device)
    for counter5 in range(k):
        t += (x[j - counter5] > x[j - counter5 - 1]).float()

    return (t <= k * k3 * torch.ones_like(t)).float()


def _loss_and_grad(model, x_adv, mask=None, need_grad=True):
    """Evaluate ``model`` at ``x_adv`` and optionally its (masked) input gradient.

    ``x_adv`` is switched to ``requires_grad`` for the forward pass and
    detached in place afterwards. Returns ``(loss_indiv, loss, grad)`` with
    detached losses; ``grad`` is ``None`` when ``need_grad`` is False.
    """
    x_adv.requires_grad_()
    with torch.enable_grad():
        loss_indiv = model(x_adv)
        loss = loss_indiv.sum()
    grad = None
    if need_grad:
        grad = torch.autograd.grad(loss, [x_adv])[0].detach()
        if mask is not None:
            grad *= mask
    x_adv.detach_()
    return loss_indiv.detach(), loss.detach(), grad


def _l2_project(x, x_new, eps):
    """Project ``x_new`` onto the L2 ball of radius ``eps`` around ``x``, then clip to [0, 1]."""
    delta = x_new - x
    delta_norm = L2_norm(delta, keepdim=True)
    return torch.clamp(
        x + delta / (delta_norm + 1e-12) * torch.min(eps * torch.ones_like(x), delta_norm),
        0.0, 1.0)


def apgd(model, x, norm, eps, n_iter=10, use_rs=False, mask=None, alpha=None,
         n_iter_2=None, n_iter_min=None, pert_init=None, verbose=False, is_train=True):
    """Maximise ``model``'s loss within the ``norm``-ball of radius ``eps`` around ``x``.

    Adapted from https://github.com/fra31/robust-finetuning. Iterates are
    clipped to [0, 1], so ``x`` is presumably an input in that range.

    Args:
        model: callable returning one loss value per batch element.
        x: clean input; the batch dimension must be 1.
        norm: ``'Linf'``, ``'L2'``, ``'L1'`` or ``'L0'`` (a lowercase ``l`` is accepted).
        eps: radius of the threat model (per-coordinate bound for L0).
        n_iter: number of iterations.
        use_rs: random start (Linf/L2 only). Masked-out entries get no noise.
        mask: optional tensor multiplied into every gradient (and the random
            start); zeros freeze the corresponding entries.
        alpha: initial step size as a multiple of ``eps``
            (default 2 for Linf/L2, 1 for L1/L0).
        n_iter_2: fraction of ``n_iter`` between step-size checks (Linf/L2, default 0.22).
        n_iter_min: minimum fraction between checks (Linf/L2, default 0.06).
        pert_init: optional initial perturbation (incompatible with ``use_rs``).
        verbose: print progress roughly every 10% of the iterations.
        is_train: selects the initial sparsity level for L1 (0.05 vs 0.2).

    Returns:
        ``(x_best, loss_best, x_best_adv)``: the highest-loss iterate, its
        loss, and the last iterate.

    Raises:
        ValueError: unknown ``norm``.
        NotImplementedError: ``use_rs`` with L1/L0.
    """
    assert x.shape[0] == 1  # only support batch size 1 for now
    norm = norm.replace('l', 'L')
    if norm not in ('Linf', 'L2', 'L1', 'L0'):
        raise ValueError(f'Unknown norm: {norm}')
    device = x.device
    ndims = len(x.shape) - 1

    # ---- starting point ----
    if not use_rs:
        x_adv = x.clone()
    else:
        if norm == 'Linf':
            t = torch.zeros_like(x).uniform_(-eps, eps).detach()
        elif norm == 'L2':
            t = torch.randn(x.shape).to(device).detach()
        else:
            raise NotImplementedError(f'Random init is not implemented for {norm}')
        # Masked-out entries receive zero gradient, so noise added there would
        # never be removed; keep them clean.
        if mask is not None:
            t = t * mask
        if norm == 'Linf':
            x_adv = x + t
        else:
            x_adv = x + eps * torch.ones_like(x).detach() * t / (L2_norm(t, keepdim=True) + 1e-12)
    if pert_init is not None:
        assert not use_rs
        assert pert_init.shape == x.shape, f'pert_init.shape: {pert_init.shape}, x.shape: {x.shape}'
        x_adv = x + pert_init

    x_adv = x_adv.clamp(0., 1.)
    x_best = x_adv.clone()
    x_best_adv = x_adv.clone()
    loss_steps = torch.zeros([n_iter, x.shape[0]], device=device)

    # ---- schedule parameters ----
    n_fts = math.prod(x.shape[1:])
    if norm in ('Linf', 'L2'):
        n_iter_2_frac = 0.22 if n_iter_2 is None else n_iter_2
        n_iter_min_frac = 0.06 if n_iter_min is None else n_iter_min
        n_iter_2 = max(int(n_iter_2_frac * n_iter), 1)
        n_iter_min = max(int(n_iter_min_frac * n_iter), 1)
        size_decr = max(int(0.03 * n_iter), 1)
        k = n_iter_2
        thr_decr = .75
        alpha = 2. if alpha is None else alpha
    else:  # 'L1' / 'L0'
        k = max(int(.04 * n_iter), 1)
        init_topk = .05 if is_train else .2
        topk = init_topk * torch.ones([x.shape[0]], device=device)
        sp_old = n_fts * torch.ones_like(topk)
        adasp_redstep = 1.5
        adasp_minstep = 10.
        alpha = 1. if alpha is None else alpha

    step_size = alpha * eps * torch.ones([x.shape[0], *[1] * ndims], device=device)
    counter3 = 0

    loss_indiv, loss, grad = _loss_and_grad(model, x_adv, mask)
    grad_best = grad.clone()
    loss_best = loss_indiv.clone()
    loss_best_last_check = loss_best.clone()
    reduced_last_check = torch.ones_like(loss_best)

    u = torch.arange(x.shape[0], device=device)
    x_adv_old = x_adv.clone().detach()

    for i in range(n_iter):
        # ---- gradient step ----
        x_adv = x_adv.detach()
        grad2 = x_adv - x_adv_old  # momentum term
        x_adv_old = x_adv.clone()
        loss_curr = loss.detach().mean()

        a = 0.75 if i > 0 else 1.0

        if norm == 'Linf':
            x_adv_1 = x_adv + step_size * torch.sign(grad)
            x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1, x - eps), x + eps), 0.0, 1.0)
            x_adv_1 = torch.clamp(torch.min(torch.max(
                x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a),
                x - eps), x + eps), 0.0, 1.0)

        elif norm == 'L2':
            x_adv_1 = x_adv + step_size * grad / (L2_norm(grad, keepdim=True) + 1e-12)
            x_adv_1 = _l2_project(x, x_adv_1, eps)
            x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
            x_adv_1 = _l2_project(x, x_adv_1, eps)

        elif norm == 'L1':
            # Only the top-`topk` fraction of gradient coordinates is used.
            grad_topk = grad.abs().view(x.shape[0], -1).sort(-1)[0]
            topk_curr = torch.clamp((1. - topk) * n_fts, min=0, max=n_fts - 1).long()
            grad_topk = grad_topk[u, topk_curr].view(-1, *[1] * ndims)
            sparsegrad = grad * (grad.abs() >= grad_topk).float()
            n_active = sparsegrad.sign().abs().view(x.shape[0], -1).sum(dim=-1).view(-1, *[1] * ndims)
            x_adv_1 = x_adv + step_size * sparsegrad.sign() / (n_active + 1e-10)

            delta_u = x_adv_1 - x
            delta_p = L1_projection(x, delta_u, eps)
            x_adv_1 = x + delta_u + delta_p

        else:  # 'L0'
            L1normgrad = grad / (grad.abs().view(grad.shape[0], -1).sum(
                dim=-1, keepdim=True) + 1e-12).view(grad.shape[0], *[1] * ndims)
            x_adv_1 = x_adv + step_size * L1normgrad * n_fts
            # TODO: add momentum
            # Without this projection the iterate is unbounded (and may leave [0, 1]).
            x_adv_1 = L0_projection(x_adv_1, x, eps, step_size)

        x_adv = x_adv_1.to(dtype=x_adv.dtype) + 0.

        # ---- loss and gradient at the new point (last backward pass is skipped) ----
        loss_indiv, loss, grad_new = _loss_and_grad(model, x_adv, mask, need_grad=i < n_iter - 1)
        if grad_new is not None:
            grad = grad_new

        x_best_adv = x_adv + 0.
        if verbose and (i % max(n_iter // 10, 1) == 0 or i == n_iter - 1):
            if norm == 'L1':
                str_stats = ' - step size: {:.5f} - topk: {:.2f}'.format(
                    step_size.mean(), topk.mean() * n_fts)
            else:
                str_stats = ' - step size: {:.5f}'.format(step_size.mean())
            print('iteration: {} - best loss: {:.6f} curr loss {:.6f} {}'.format(
                i, loss_best.sum(), loss_curr, str_stats))

        # ---- track best point and adapt the step size ----
        y1 = loss_indiv.detach().clone()
        loss_steps[i] = y1 + 0
        ind = (y1 > loss_best).nonzero().squeeze()
        x_best[ind] = x_adv[ind].clone()
        grad_best[ind] = grad[ind].clone()
        loss_best[ind] = y1[ind] + 0

        counter3 += 1

        if counter3 == k:
            if norm in ('Linf', 'L2'):
                fl_oscillation = check_oscillation(loss_steps, i, k, loss_best, k3=thr_decr)
                fl_reduce_no_impr = (1. - reduced_last_check) * (
                    loss_best_last_check >= loss_best).float()
                fl_oscillation = torch.max(fl_oscillation, fl_reduce_no_impr)
                reduced_last_check = fl_oscillation.clone()
                loss_best_last_check = loss_best.clone()

                if fl_oscillation.sum() > 0:
                    # Halve the step and restart from the best point found so far.
                    ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze()
                    step_size[ind_fl_osc] /= 2.0
                    x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone()
                    grad[ind_fl_osc] = grad_best[ind_fl_osc].clone()

                counter3 = 0
                k = max(k - size_decr, n_iter_min)

            elif norm == 'L1':
                # adjust sparsity
                sp_curr = L0_norm(x_best - x)
                fl_redtopk = (sp_curr / sp_old) < .95
                topk = sp_curr / n_fts / 1.5
                step_size[fl_redtopk] = alpha * eps
                step_size[~fl_redtopk] /= adasp_redstep
                step_size.clamp_(alpha * eps / adasp_minstep, alpha * eps)
                sp_old = sp_curr.clone()

                x_adv[fl_redtopk] = x_best[fl_redtopk].clone()
                grad[fl_redtopk] = grad_best[fl_redtopk].clone()

                counter3 = 0

    return x_best, loss_best, x_best_adv