# Code taken and adapted from https://github.com/wagnermoritz/GSE
import torch

from vlm_eval.attacks.attack import Attack


class EAD(Attack):
    """Iterative attack minimising ``c * model_loss + ||d||_2^2 + beta * ||d||_1``.

    ``d`` is the perturbation ``x - x_orig``.  Each outer ("binary") step runs
    ``steps`` gradient iterations with a momentum term and a soft-thresholding
    projection, then adjusts the trade-off constant ``c``.  The structure
    appears to follow the elastic-net attack (EAD); confirm against the
    upstream repository if exact equivalence matters.

    Args:
        model: Wrapped model.  It must be callable on an input tensor and
            expose ``model.model.parameters()``.
        targeted: If true the loss is ``model(x).sum()``, otherwise its
            negation.
        img_range: Forwarded to the ``Attack`` base class.
        steps: Number of gradient iterations per binary step.
        beta: Weight of the L1 term and threshold used in ``project``.
        mask_out: ``'none'``, ``'context'``, ``'query'`` or an integer index
            along dimension 1 whose gradient is zeroed.
        ver: If true, print progress every 20 iterations.
        binary_steps: Number of outer iterations that adapt ``c``.
        step_size: Initial step size of the square-root decay schedule.
        decision_rule: Criterion for keeping the best iterate; only ``'L1'``
            is implemented.
    """

    def __init__(self, model, targeted=False, img_range=(0, 1), steps=100,
                 beta=5e-5, mask_out='none', ver=False, binary_steps=2,
                 step_size=1e-2, decision_rule='L1'):
        super().__init__(model=model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.ver = ver
        self.binary_steps = binary_steps
        self.beta = beta
        # The string 'none' is the caller-facing spelling of "no masking".
        self.mask_out = mask_out if mask_out != 'none' else None
        self.decision_rule = decision_rule
        self.step_size = step_size

    def _set_mask(self, data):
        """Return a 0/1 tensor shaped like ``data`` used to mask gradients.

        Entries set to 0 receive no gradient update in ``__call__``.
        Dimension 1 is indexed directly, so ``data`` needs at least two
        dimensions (presumably dimension 1 enumerates images, with the last
        one being the query -- TODO confirm against the caller).

        Raises:
            NotImplementedError: If ``self.mask_out`` is not a known option.
        """
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            # Zero everything except the last index along dimension 1.
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            # Zero only the last index along dimension 1.
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x_orig):
        """Run the attack on ``x_orig`` and return the best perturbed input.

        Side effect: ``requires_grad`` is set to ``False`` on every parameter
        of ``self.model.model`` and is not restored afterwards.

        Args:
            x_orig: Clean input tensor.

        Returns:
            A detached tensor with the best iterate found; it equals a copy of
            ``x_orig`` if no iterate was ever accepted.

        Raises:
            NotImplementedError: If ``self.decision_rule`` is not ``'L1'``
                (raised during the first iteration).
        """
        for param in self.model.model.parameters():
            param.requires_grad = False

        mask = self._set_mask(x_orig)

        # Trade-off constant and the bounds used to adapt it.
        c = 1e-1
        c_upper = 10e+10
        c_lower = 0

        overall_best_attack = x_orig.clone()
        overall_best_dist = torch.inf
        overall_best_loss = 1e10

        for binary_step in range(self.binary_steps):
            global_step = 0
            x = x_orig.clone().detach()
            y = x_orig.clone().detach()

            best_attack = x_orig.clone().detach()
            best_dist = torch.inf
            best_loss = 1e10

            # Start from the configured step size.  This used to be a
            # hard-coded 1e-2, which ignored the ``step_size`` argument on
            # the first iteration; the default value is unchanged.
            step_size = self.step_size

            for step in range(self.steps):
                # ``y`` is a fresh leaf every iteration, so its gradient never
                # accumulates across iterations.
                y.requires_grad = True
                _, loss = self.loss_fn(x=y, c=c, x_orig=x_orig)
                loss.backward()
                y_grad = y.grad.detach() * mask

                with torch.no_grad():
                    x_new = self.project(x=y - step_size * y_grad,
                                         x_orig=x_orig)

                    # Square-root decay, applied from the next iteration on.
                    step_size = (self.step_size
                                 * (1 - global_step / self.steps) ** 0.5)
                    global_step += 1

                    # Momentum extrapolation with weight step / (step + 3).
                    y = x_new + (step / (step + 3)) * (x_new - x)
                    x = x_new

                    loss_model, loss = self.loss_fn(x=x, c=c, x_orig=x_orig)

                    if self.ver and step % 20 == 0:
                        print(f"Binary Step: {binary_step}, Iter: {step}, Loss: {loss.item()}, L0: {(x - x_orig).norm(p=0)}, Linf: {(x - x_orig).norm(p=torch.inf)}")

                    if self.decision_rule == 'L1':
                        # Keep an iterate only if it improves BOTH the L1
                        # distance and the model loss.  Plain floats are
                        # stored so no tensors are kept for bookkeeping.
                        l1_dist = (x - x_orig).norm(p=1).item()
                        model_loss = loss_model.item()
                        if l1_dist < best_dist and model_loss < best_loss:
                            best_loss = model_loss
                            best_attack = x.clone()
                            best_dist = l1_dist
                    else:
                        raise NotImplementedError

            # Updating c
            if overall_best_dist > best_dist and best_loss < overall_best_loss:
                overall_best_loss = best_loss
                overall_best_dist = best_dist
                overall_best_attack = best_attack.clone()

                c_upper = min(c_upper, c)
                if c_upper < 1e9:
                    c = (c_upper + c_lower) / 2
            else:
                c_lower = max(c_lower, c)
                if c_upper < 1e9:
                    c = (c_lower + c_upper) / 2.0
                else:
                    # No finite upper bound has been established yet.
                    c *= 10

        # NOTE(review): this summary is printed regardless of ``self.ver``.
        print(f"Final L0: {(overall_best_attack - x_orig).norm(p=0)}, Linf: {(overall_best_attack - x_orig).norm(p=torch.inf)}")
        return overall_best_attack.detach()

    def project(self, x, x_orig):
        """Soft-threshold ``x`` towards ``x_orig`` with threshold ``self.beta``.

        Element-wise:

        * ``x - x_orig >  beta``  -> ``min(x - beta, 1.0)``
        * ``|x - x_orig| <= beta`` -> ``x_orig``
        * ``x - x_orig < -beta``  -> ``max(x + beta, 0.0)``

        NOTE(review): the bounds 0.0 and 1.0 are fixed here and do not use
        the ``img_range`` passed to the constructor.
        """
        diff = x - x_orig
        mask_1 = (diff > self.beta).float()
        mask_2 = (diff.abs() <= self.beta).float()
        mask_3 = (diff < -self.beta).float()

        # ``clamp`` with Python scalars avoids building a new scalar tensor
        # for each bound on every call.
        upper = torch.clamp(x - self.beta, max=1.0)
        lower = torch.clamp(x + self.beta, min=0.0)

        return mask_1 * upper + mask_2 * x_orig + mask_3 * lower

    def loss_fn(self, x, c, x_orig):
        """Return ``(model_loss, total_loss)`` for the candidate ``x``.

        ``model_loss`` is ``model(x).sum()``, negated when the attack is not
        targeted.  ``total_loss`` is
        ``c * model_loss + sum((x - x_orig)**2) + beta * sum(|x - x_orig|)``.
        """
        out = -self.model(x).sum() if not self.targeted else self.model(x).sum()
        # Per-sample distances, summed over the batch below.
        l2_dist = ((x - x_orig) ** 2).view(x.shape[0], -1).sum(dim=1)
        l1_dist = ((x - x_orig).abs()).view(x.shape[0], -1).sum(dim=1)
        return out, c * out + l2_dist.sum() + self.beta * l1_dist.sum()