fgm_attack.py
import torch


class FGM:
    """Fast Gradient Method (FGM) for adversarial training on the embedding layer."""

    def __init__(self, model, epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.backup = {}

    def attack(self):
        # Back up the original embedding weights, then add an epsilon-scaled
        # perturbation in the direction of the gradient, normalized by its L2 norm.
        for name, param in self.model.named_parameters():
            if param.requires_grad and 'embeddings' in name:
                if param.grad is None:
                    continue
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self):
        # Restore the original embedding weights saved in attack().
        for name, param in self.model.named_parameters():
            if param.requires_grad and 'embeddings' in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
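
The class is typically used by inserting an extra forward/backward pass between the normal backward pass and the optimizer step, so that gradients from both the clean and the perturbed embeddings accumulate before the update. The sketch below is a minimal usage example under that assumption; `model`, `optimizer`, `loss_fn`, and `dataloader` are placeholder objects not defined in this file.

    fgm = FGM(model, epsilon=1.0)

    for inputs, labels in dataloader:
        # 1. Normal forward/backward pass to populate embedding gradients.
        loss = loss_fn(model(inputs), labels)
        loss.backward()

        # 2. Perturb the embeddings and accumulate gradients from the adversarial pass.
        fgm.attack()
        adv_loss = loss_fn(model(inputs), labels)
        adv_loss.backward()

        # 3. Remove the perturbation before updating the weights.
        fgm.restore()
        optimizer.step()
        optimizer.zero_grad()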