Commit ae009052 authored by Mina Moshfegh

Upload New File

parent affe77ae
import torch
import torch.nn as nn

from .base_attack import BaseAttack


# PGD is an iterative version of FGSM: it takes several smaller steps and
# projects the perturbation back into the epsilon-ball after each one.
class PGDAttack(BaseAttack):
    def __init__(self, model, epsilon=8 / 255, alpha=2 / 255, num_steps=10,
                 random_init=True, clamp_min=0, clamp_max=1):
        super().__init__(model)
        # epsilon: maximum total perturbation (L-infinity radius)
        # alpha: step size per iteration
        # num_steps: number of iterations
        self.epsilon = epsilon
        self.alpha = alpha
        self.num_steps = num_steps
        self.random_init = random_init
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

    def generate(self, x, y):
        # Make sure autograd is on (the attack needs input gradients) and put
        # the data on the same device as the model.
        torch.set_grad_enabled(True)
        device = next(self.model.parameters()).device
        x = x.to(device)
        y = y.to(device)

        # Start from a copy of the original input
        x_adv = x.clone().detach()

        # Optional random initialization within the epsilon-ball
        if self.random_init:
            x_adv = x_adv + torch.empty_like(x_adv).uniform_(-self.epsilon, self.epsilon)
            # Keep the randomized start point inside the valid data range
            x_adv = torch.clamp(x_adv, self.clamp_min, self.clamp_max)

        loss_fn = nn.CrossEntropyLoss()

        # Iterative PGD loop
        for _ in range(self.num_steps):
            x_adv.requires_grad_(True)

            # Forward pass and loss
            logits = self.model(x_adv)
            loss = loss_fn(logits, y)

            # Gradient of the loss w.r.t. the adversarial input
            loss.backward()

            with torch.no_grad():
                # Sign of the gradient for each pixel
                grad_sign = x_adv.grad.sign()
                # Step by alpha in the direction of the sign
                x_adv = x_adv + self.alpha * grad_sign
                # Project the total perturbation back into [-epsilon, epsilon]
                perturbation = torch.clamp(x_adv - x, min=-self.epsilon, max=self.epsilon)
                # Clamp the adversarial example to the valid data range
                x_adv = torch.clamp(x + perturbation, self.clamp_min, self.clamp_max)

        # Detach so the caller doesn't carry gradients or graph overhead
        return x_adv.detach()
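

# A minimal usage sketch, kept as an illustration only: the toy classifier and the
# random batch below are assumptions, not part of this repository, and running the
# file requires the package that provides `base_attack` to be importable.
if __name__ == "__main__":
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in classifier
    images = torch.rand(4, 3, 32, 32)       # fake batch in the [0, 1] range
    labels = torch.randint(0, 10, (4,))     # fake integer class labels

    attack = PGDAttack(toy_model, epsilon=8 / 255, alpha=2 / 255, num_steps=10)
    adv_images = attack.generate(images, labels)

    # The perturbation should stay within epsilon in the L-infinity norm.
    print("max |delta|:", (adv_images - images).abs().max().item())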