-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
adversarial_trainer_trades_pytorch.py
278 lines (223 loc) · 11 KB
/
adversarial_trainer_trades_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This is a PyTorch implementation of the TRADES protocol.
| Paper link: https://proceedings.mlr.press/v97/zhang19p.html
"""
from __future__ import absolute_import, division, print_function, unicode_literals, annotations
import logging
import time
from typing import TYPE_CHECKING
import numpy as np
from tqdm.auto import trange
from art.defences.trainer.adversarial_trainer_trades import AdversarialTrainerTRADES
from art.estimators.classification.pytorch import PyTorchClassifier
from art.data_generators import DataGenerator
from art.attacks.attack import EvasionAttack
from art.utils import check_and_transform_label_format
if TYPE_CHECKING:
import torch
# Module-level logger, named after this module so it plugs into the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
# Lower bound used to clamp the clean softmax probabilities that serve as the KL-divergence target.
EPS: float = 1e-8
class AdversarialTrainerTRADESPyTorch(AdversarialTrainerTRADES):
    """
    Class performing adversarial training following TRADES protocol.

    Every training step minimises ``classifier_loss(f(x), y) + beta * KL(softmax(f(x)) || softmax(f(x_adv)))``,
    where the KL term is averaged over the batch and ``x_adv`` is generated by the configured evasion attack.

    | Paper link: https://proceedings.mlr.press/v97/zhang19p.html
    """

    def __init__(self, classifier: PyTorchClassifier, attack: EvasionAttack, beta: float):
        """
        Create an :class:`.AdversarialTrainerTRADESPyTorch` instance.

        :param classifier: Model to train adversarially.
        :param attack: attack to use for data augmentation in adversarial training
        :param beta: The scaling factor controlling tradeoff between clean loss and adversarial loss
        """
        super().__init__(classifier, attack, beta)
        # Narrow the attribute types set by the base class for static type checkers.
        self._classifier: PyTorchClassifier
        self._attack: EvasionAttack
        self._beta: float

    @staticmethod
    def _check_scheduler(scheduler) -> None:
        """
        Validate that `scheduler` is either None or a PyTorch learning-rate scheduler.

        :param scheduler: Candidate learning-rate scheduler.
        :raises ValueError: If `scheduler` is not None and not a PyTorch learning-rate scheduler.
        """
        import torch

        if scheduler is None:
            return

        lr_scheduler_module = torch.optim.lr_scheduler
        # From PyTorch 2.0 on, built-in schedulers (e.g. StepLR) derive from the public ``LRScheduler`` base, and
        # ``_LRScheduler`` is only a backward-compatibility *subclass* of it. Checking against ``_LRScheduler`` alone
        # would therefore reject valid schedulers; older PyTorch versions only provide ``_LRScheduler``.
        scheduler_base = getattr(
            lr_scheduler_module,
            "LRScheduler",
            lr_scheduler_module._LRScheduler,  # pylint: disable=protected-access
        )
        if not isinstance(scheduler, scheduler_base):
            raise ValueError("Invalid Pytorch scheduler is provided for adversarial training.")

    def fit(
        self,
        x: np.ndarray,
        y: np.ndarray,
        validation_data: tuple[np.ndarray, np.ndarray] | None = None,
        batch_size: int = 128,
        nb_epochs: int = 20,
        scheduler: "torch.optim.lr_scheduler._LRScheduler" | None = None,
        **kwargs,
    ):
        """
        Train a model adversarially with TRADES protocol.
        See class documentation for more information on the exact procedure.

        :param x: Training set.
        :param y: Labels for the training set.
        :param validation_data: Tuple consisting of validation data, (x_val, y_val). When given, accuracy on the clean
                                validation inputs is logged at the end of every epoch.
        :param batch_size: Size of batches.
        :param nb_epochs: Number of epochs to use for trainings.
        :param scheduler: Learning rate scheduler to run at the end of every epoch.
        :param kwargs: Dictionary of framework-specific arguments. Currently unused; accepted for interface
                       compatibility.
        :raises ValueError: If `scheduler` is not a PyTorch learning-rate scheduler, or if the classifier has no
                            optimizer.
        """
        logger.info("Performing adversarial training with TRADES protocol")

        self._check_scheduler(scheduler)

        nb_batches = int(np.ceil(len(x) / batch_size))
        ind = np.arange(len(x))

        logger.info("Adversarial Training TRADES")

        y = check_and_transform_label_format(y, nb_classes=self.classifier.nb_classes)

        if validation_data is not None:
            (x_test, y_test) = validation_data
            y_test = check_and_transform_label_format(y_test, nb_classes=self.classifier.nb_classes)
            # The labels are one-hot after the transformation above; keep only the class indices for scoring.
            y_test_labels = np.argmax(y_test, axis=1)

        for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
            # Shuffle the examples
            np.random.shuffle(ind)
            start_time = time.time()
            train_loss = 0.0
            train_acc = 0.0
            train_n = 0.0

            for batch_id in range(nb_batches):
                # NumPy clamps slices past the end, so the last (possibly smaller) batch needs no special case.
                batch_ind = ind[batch_id * batch_size : (batch_id + 1) * batch_size]
                x_batch = x[batch_ind].copy()
                y_batch = y[batch_ind]

                _train_loss, _train_acc, _train_n = self._batch_process(x_batch, y_batch)

                train_loss += _train_loss
                train_acc += _train_acc
                train_n += _train_n

            if scheduler:
                scheduler.step()

            train_time = time.time()

            # compute accuracy
            if validation_data is not None:
                # ``predict`` applies the classifier's (inference-mode) preprocessing itself, so the raw validation
                # inputs are passed; pre-applying fit-time preprocessing here would transform the data twice.
                output = np.argmax(self.predict(x_test), axis=1)
                nb_correct_pred = np.sum(output == y_test_labels)

                logger.info(
                    "epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
                    i_epoch,
                    train_time - start_time,
                    train_loss / train_n,
                    train_acc / train_n,
                    nb_correct_pred / x_test.shape[0],
                )
            else:
                logger.info(
                    "epoch: %s time(s): %.1f loss: %.4f acc: %.4f",
                    i_epoch,
                    train_time - start_time,
                    train_loss / train_n,
                    train_acc / train_n,
                )

    def fit_generator(
        self,
        generator: DataGenerator,
        nb_epochs: int = 20,
        scheduler: "torch.optim.lr_scheduler._LRScheduler" | None = None,
        **kwargs,
    ):
        """
        Train a model adversarially with TRADES protocol using a data generator.
        See class documentation for more information on the exact procedure.

        :param generator: Data generator.
        :param nb_epochs: Number of epochs to use for trainings.
        :param scheduler: Learning rate scheduler to run at the end of every epoch.
        :param kwargs: Dictionary of framework-specific arguments. Currently unused; accepted for interface
                       compatibility.
        :raises ValueError: If `scheduler` is not a PyTorch learning-rate scheduler, if the generator's size is None,
                            or if the classifier has no optimizer.
        """
        logger.info("Performing adversarial training with TRADES protocol")

        self._check_scheduler(scheduler)

        size = generator.size
        batch_size = generator.batch_size
        if size is not None:
            nb_batches = int(np.ceil(size / batch_size))
        else:
            raise ValueError("Size is None.")

        logger.info("Adversarial Training TRADES")

        for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
            start_time = time.time()
            train_loss = 0.0
            train_acc = 0.0
            train_n = 0.0

            for _ in range(nb_batches):
                # Create batch data; copy so the attack/preprocessing cannot modify the generator's buffers.
                x_batch, y_batch = generator.get_batch()
                x_batch = x_batch.copy()

                _train_loss, _train_acc, _train_n = self._batch_process(x_batch, y_batch)

                train_loss += _train_loss
                train_acc += _train_acc
                train_n += _train_n

            if scheduler:
                scheduler.step()

            train_time = time.time()

            logger.info(
                "epoch: %s time(s): %.1f loss: %.4f acc: %.4f",
                i_epoch,
                train_time - start_time,
                train_loss / train_n,
                train_acc / train_n,
            )

    def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> tuple[float, float, float]:
        """
        Perform the operations of TRADES for a batch of data.
        See class documentation for more information on the exact procedure.

        :param x_batch: batch of x.
        :param y_batch: batch of y.
        :return: Tuple of (batch loss multiplied by the batch size, number of adversarial examples classified
                 correctly, number of samples in the batch).
        :raises ValueError: If the classifier has no optimizer.
        """
        import torch
        from torch import nn
        import torch.nn.functional as F

        if self._classifier._optimizer is None:  # pylint: disable=protected-access
            raise ValueError("Optimizer of classifier is currently None, but is required for adversarial training.")

        n = x_batch.shape[0]

        # Generate the adversarial examples with the model in evaluation mode.
        self._classifier._model.train(mode=False)  # pylint: disable=protected-access
        x_batch_pert = self._attack.generate(x_batch, y=y_batch)

        # Apply preprocessing
        y_batch = check_and_transform_label_format(y_batch, nb_classes=self.classifier.nb_classes)

        x_preprocessed, y_preprocessed = self._classifier._apply_preprocessing(  # pylint: disable=protected-access
            x_batch, y_batch, fit=True
        )
        x_preprocessed_pert, _ = self._classifier._apply_preprocessing(  # pylint: disable=protected-access
            x_batch_pert, y_batch, fit=True
        )

        # Check label shape
        if self._classifier._reduce_labels:  # pylint: disable=protected-access
            y_preprocessed = np.argmax(y_preprocessed, axis=1)

        device = self._classifier._device  # pylint: disable=protected-access
        i_batch = torch.from_numpy(x_preprocessed).to(device)
        i_batch_pert = torch.from_numpy(x_preprocessed_pert).to(device)
        o_batch = torch.from_numpy(y_preprocessed).to(device)

        self._classifier._model.train(mode=True)  # pylint: disable=protected-access

        # Zero the parameter gradients
        self._classifier._optimizer.zero_grad()  # pylint: disable=protected-access

        # Perform prediction; the wrapped model returns a sequence whose last entry is the final output.
        model_outputs = self._classifier._model(i_batch)  # pylint: disable=protected-access
        model_outputs_pert = self._classifier._model(i_batch_pert)  # pylint: disable=protected-access

        # Form the loss function: clean loss + beta * batch-averaged KL(clean || perturbed).
        loss_clean = self._classifier._loss(model_outputs[-1], o_batch)  # pylint: disable=protected-access
        loss_kl = (1.0 / n) * nn.KLDivLoss(reduction="sum")(
            F.log_softmax(model_outputs_pert[-1], dim=1), torch.clamp(F.softmax(model_outputs[-1], dim=1), min=EPS)
        )
        loss = loss_clean + self._beta * loss_kl
        loss.backward()

        self._classifier._optimizer.step()  # pylint: disable=protected-access

        train_loss = loss.item() * o_batch.size(0)
        # Robust training accuracy: use the same (final) output as the loss terms. When labels are not reduced,
        # o_batch is still one-hot, so compare against its class indices instead.
        true_labels = o_batch.argmax(dim=1) if o_batch.dim() > 1 else o_batch
        train_acc = (model_outputs_pert[-1].argmax(dim=1) == true_labels).sum().item()
        train_n = o_batch.size(0)

        self._classifier._model.train(mode=False)  # pylint: disable=protected-access

        return train_loss, train_acc, train_n