# Code taken and adapted from https://github.com/wagnermoritz/GSE

from vlm_eval.attacks.attack import Attack
import torch
import math
import torch.nn.functional as F

class StrAttack(Attack):
    def __init__(self, model, *args, targeted=False, img_range=(0, 1), kappa=0,
                 max_iter=100, ver=False, search_steps=2, max_c=1e10, rho=1, mask_out='none',
                 c=2.5, retrain=False, **kwargs):
        '''
        Implementation of StrAttack: https://arxiv.org/abs/1808.01664
        Adapted from https://github.com/KaidiXu/StrAttack

        args:
        model:         Callable model wrapper; its output on the perturbed
                       images is summed and used as the attack loss.
        targeted:      Bool; if True the loss is minimized as-is, otherwise
                       its sign is flipped (untargeted attack).
        img_range:     Tuple of ints/floats, lower and upper bound of image
                       entries.
        kappa:         Float, confidence parameter of the original attack;
                       accepted for API compatibility but unused here.
        max_iter:      Int, number of iterations.
        ver:           Bool, print progress if True.
        search_steps:  Int, number of binary search steps.
        max_c:         Float, upper bound for regularization parameter.
        rho:           Float, ADMM parameter.
        mask_out:      Which images along dim 1 to exclude from the
                       perturbation: 'context' (all but the last image),
                       'query' (the last image), an int index, or 'none'.
        c:             Float, initial regularization parameter.
        retrain:       Bool, run the refinement stage after the ADMM
                       iterations if True.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.max_iter = max_iter
        self.ver = ver
        self.search_steps = search_steps
        self.max_c = max_c
        self.rho = rho
        self.c = c
        self.retrain = retrain
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None

    def _set_mask(self, data):
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask
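
    # Example: for imgs of shape [1, 5, 1, 3, H, W], mask_out='context'
    # zeroes the mask for the first four images along dim 1, so only the
    # last (query) image is perturbed; mask_out='query' does the opposite.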

    def __call__(self, imgs, *args, **kwargs):
        '''
        Perform StrAttack on a batch of images.

        args:
        imgs:  Tensor of shape [1, B, 1, C, H, W], where dim 1 indexes the
               images of one sample (e.g. context images followed by the
               query image).

        Returns a tensor of the same shape as imgs containing the
        adversarial examples.
        '''

        for param in self.model.model.parameters():
            param.requires_grad = False

        c_ = self.c
        imgs = imgs.to(self.device)
        sh = imgs.shape
        batch_size = sh[1]
        mask_out = self._set_mask(imgs)

        alpha, tau, gamma = 5, 2, 1
        eps = torch.full_like(imgs, 1.0) * mask_out
        # filter/group size: 8 for large inputs (e.g. 224x224), 2 for
        # CIFAR- and MNIST-sized inputs
        filterSize = 8 if sh[-1] > 32 else 2
        stride = filterSize
        # convolution kernel used to compute norm of each group
        slidingM = torch.ones((1, sh[3], filterSize, filterSize), device=self.device)
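        # (convolving the squared perturbation with this ones-kernel at
        # stride filterSize and taking the square root yields the l2 norm of
        # each non-overlapping filterSize x filterSize group; the norms are
        # later upsampled back to pixel resolution)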

        cs = torch.ones(batch_size, device=self.device) * c_
        lower_bound = torch.zeros(batch_size)
        upper_bound = torch.ones(batch_size) * self.max_c

        o_bestl2 = torch.full((batch_size,), 1e10, dtype=torch.float)
        o_bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)
        o_bestattack = imgs.clone()
        o_besty = torch.ones_like(imgs)

        for step in range(self.search_steps):

            bestl2 = torch.full_like(o_bestl2, 1e10, dtype=torch.float)
            bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)

            z, v, u, s = (torch.zeros_like(imgs) for _ in range(4))

            for iter_ in range(self.max_iter):
                if (not iter_%10 or iter_ == self.max_iter - 1) and self.ver:
                    print(f'\rIteration: {iter_+1}/{self.max_iter}, ' +
                          f'Search Step: {step+1}/{self.search_steps}', end='')

                # first update step (7) / Proposition 1
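                # (closed-form ridge shrinkage: delta minimizes the l2
                # penalty gamma*||delta||^2 plus the augmented-Lagrangian
                # distance to z - u/rho)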
                delta = self.rho / (self.rho + 2 * gamma) * (z - u / self.rho)

                b = (z - s / self.rho) * mask_out
                # project onto the feasible box: the perturbation w must keep
                # imgs inside img_range and stay within the eps bound
                tmp = torch.minimum(self.img_range[1] - imgs, eps)
                w = torch.where(b.view(*sh) > tmp.view(*sh), tmp, b)
                tmp = torch.maximum(self.img_range[0] - imgs, -eps)
                w = torch.where(b.view(*sh) < tmp.view(*sh), tmp, w)
                
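                # group soft-thresholding (prox of the group-sparsity term):
                # each pixel group is scaled by max(0, 1 - tau/(rho*||group||))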
                c = z - v / self.rho
                cNorm = torch.sqrt(F.conv2d(c.view(sh[1], sh[3], sh[4], sh[5]) ** 2, slidingM, stride=stride))
                cNorm = torch.where(cNorm == 0, torch.full_like(cNorm, 1e-12), cNorm)
                cNorm = F.interpolate(cNorm, scale_factor=filterSize)
                y = torch.clamp((1 - tau / (self.rho * cNorm.unsqueeze(0).unsqueeze(3))), 0) * c
                
                # second update step (8) / equation (15)
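                # (z aggregates delta, w, y and their duals plus a gradient
                # step on the attack loss; eta ~ sqrt(iter+1) shrinks the
                # effective step size over iterations)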
                z_grads = self.__get_z_grad(imgs, z.clone(), cs)
                eta = alpha * math.sqrt(iter_ + 1)
                coeff = (1 / (eta + 3 * self.rho))
                z = coeff * (eta * z + self.rho * (delta + w + y) + u + s + v - z_grads)

                # third update step (9)
                u = u + self.rho * (delta - z) * mask_out
                v = v + self.rho * (y - z) * mask_out
                s = s + self.rho * (w - z) * mask_out
                # get info for binary search
                x = imgs + y * mask_out
                l2s = torch.sum((z ** 2).reshape(z.size(1), -1), dim=-1)
                
                for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
                    if l2 < bestl2[i]:
                        bestl2[i] = l2
                    if l2 < o_bestl2[i]:
                        o_bestl2[i] = l2
                        o_bestattack[:,i] = x_.detach().unsqueeze(0).clone()
                        o_besty[:,i] = y[:,i]
            for i in range(batch_size):
                
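                # c update: this adaptation omits the attack-success check of
                # the original implementation, so upper_bound never shrinks
                # in the main loop and cs grows by a factor of 5 per step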
                lower_bound[i] = max(lower_bound[i], cs[i])
                if upper_bound[i] < 1e9:
                    cs[i] = (lower_bound[i] + upper_bound[i]) / 2
                else:
                    cs[i] *= 5

        del v, u, s, z_grads, w, tmp
        
        if self.retrain:
            cs = torch.full_like(o_bestl2, 5.0, dtype=torch.float)
            zeros = torch.zeros_like(imgs)

            for step in range(8):
                bestl2 = torch.full_like(cs, 1e10, dtype=torch.float, device=self.device)
                bestscore = torch.full_like(cs, -1, dtype=torch.float, device=self.device)

                Nz = o_besty[o_besty != 0]
                e0 = torch.quantile(Nz.abs(), 0.03)
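                # A2 masks out coordinates of the best perturbation whose
                # magnitude is at most the 3%-quantile e0; the refinement
                # only updates the remaining (structurally important) entries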
                A2 = torch.where(o_besty.abs() <= e0, 0, 1)
                z1 = o_besty
                u1 = torch.zeros_like(imgs)
                tmpc = self.rho / (self.rho + gamma / 100)

                for j in range(100):
                    if self.ver and not j % 10:
                        print(f'\rRetrain step: {step+1}/8, ' +
                              f'Iteration: {j+1}/100', end='')

                    tmpA = (z1 - u1) * tmpc
                    tmpA1 = torch.where(o_besty.abs() <= e0, zeros, tmpA)
                    cond = torch.logical_and(tmpA >
                                             torch.minimum(self.img_range[1] - imgs, eps),
                                             o_besty.abs() > e0)
                    tmpA2 = torch.where(cond, torch.minimum(self.img_range[1] - imgs, eps),
                                        tmpA1)
                    cond = torch.logical_and(tmpA <
                                             torch.maximum(self.img_range[0] - imgs, -eps),
                                             o_besty.abs() > e0)
                    deltA = torch.where(cond, torch.maximum(self.img_range[0] - imgs, -eps),
                                        tmpA2)
                    
                    x = imgs + deltA * mask_out
                    grad = self.__get_z_grad(imgs, deltA, cs)

                    stepsize = 1 / (alpha + 2 * self.rho)
                    z1 = stepsize * (alpha * z1 + self.rho
                                     * (deltA + u1) - grad * A2)
                    u1 = u1 + deltA - z1

                    # recompute per-image perturbation norms for this stage;
                    # the l2s values left over from the ADMM loop are stale
                    l2s = torch.sum((deltA ** 2).reshape(deltA.size(1), -1), dim=-1)
                    for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
                        if l2 < bestl2[i]:
                            bestl2[i] = l2
                        if l2 < o_bestl2[i]:
                            o_bestl2[i] = l2
                            o_bestattack[:, i] = x_.detach().unsqueeze(0).clone()
                            o_besty[:, i] = deltA[:, i]


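                # bestscore is never set in this adaptation (the success
                # check was removed), so the condition below is always False
                # and the else branch runs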
                for i in range(batch_size):
                    if (bestscore[i] != -1 and bestl2[i] == o_bestl2[i]):
                        upper_bound[i] = min(upper_bound[i], cs[i])
                        if upper_bound[i] < 1e9:
                            cs[i] = (lower_bound[i] + upper_bound[i]) / 2

                    else:
                        lower_bound[i] = max(lower_bound[i], cs[i])
                        if upper_bound[i] < 1e9:
                            cs[i] = (lower_bound[i] + upper_bound[i]) / 2
                        else:
                            cs[i] *= 5

        if self.ver:
            print('')

        # o_bestattack already holds imgs plus the masked perturbation;
        # multiplying by mask_out here would zero out the masked-out images
        # entirely instead of leaving them unperturbed
        return o_bestattack.detach()


    def __get_z_grad(self, imgs, z, cs):
        '''
        Compute and return the gradient of the scaled attack loss w.r.t. z.
        '''
        z.requires_grad = True
        # the model output on imgs + z is summed into a scalar loss; the sign
        # is flipped for untargeted attacks so that gradient descent on this
        # loss moves the prediction in the adversarial direction
        tmp = self.model(z + imgs).sum() if self.targeted else -self.model(z + imgs).sum()
        loss = torch.mean(cs.to(self.device) * tmp)
        z_grad_data = torch.autograd.grad(loss, [z])[0].detach()
        z.detach_()
        return z_grad_data
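

# Usage sketch (illustrative): the wrapped model is expected to expose its
# underlying network as `model.model` and to map an image tensor of shape
# [1, B, 1, C, H, W] to a loss. `vlm_wrapper` below is hypothetical.
#
#     attack = StrAttack(vlm_wrapper, targeted=False, img_range=(0, 1),
#                        max_iter=100, search_steps=2, mask_out='context')
#     adv_imgs = attack(imgs)  # imgs: [1, B, 1, C, H, W]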