Commit b4557744 authored by Mina Moshfegh

Delete fgsm.py

parent d3ef8a13
import torch
import torch.nn as nn

from .base_attack import BaseAttack


# This class implements the basic FGSM (Fast Gradient Sign Method) attack.
# It's a straightforward single-step approach.
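# Following Goodfellow et al. (2015, "Explaining and Harnessing Adversarial
# Examples"), the attack takes a single step of size epsilon along the sign
# of the input gradient:
#
#     x_adv = x + epsilon * sign(grad_x L(f(x), y))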
class FGSMAttack(BaseAttack):
    def __init__(self, model, epsilon=0.3, clamp_min=0, clamp_max=1):
        # We call the parent constructor to store the model
        super().__init__(model)

        # Epsilon controls how big the adversarial step is;
        # clamp_min/clamp_max define the valid input range
        self.epsilon = epsilon
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
    def generate(self, x, y):
        # Find which device the model is currently on,
        # so we can move the input data there
        device = next(self.model.parameters()).device
        x = x.to(device)
        y = y.to(device)

        # Make a copy of x that allows gradient calculation
        x = x.clone().detach().requires_grad_(True)

        # Forward pass: get predictions and compute the loss
        logits = self.model(x)
        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(logits, y)

        # Backprop to get the gradient of the loss w.r.t. the input
        loss.backward()

        # FGSM step: move in the direction of the sign of the gradient
        x_adv = x + self.epsilon * x.grad.sign()

        # Clamp values back into the valid input range
        x_adv = torch.clamp(x_adv, self.clamp_min, self.clamp_max)

        # Zero out the gradient so it doesn't affect future operations
        x.grad.zero_()

        # Detach so callers don't backprop through the attack's graph
        return x_adv.detach()
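
For context, a minimal usage sketch (not part of the commit): it assumes this module sat in an attacks package next to base_attack.py, that BaseAttack.__init__ simply stores the model, and it substitutes a hypothetical toy classifier and random data for a real setup.

import torch
import torch.nn as nn

from attacks.fgsm import FGSMAttack  # hypothetical package path for this module

# Hypothetical toy classifier standing in for a real model
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

# Random MNIST-shaped batch in [0, 1] standing in for real data
x = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))

attack = FGSMAttack(model, epsilon=0.1)
x_adv = attack.generate(x, y)

# The perturbation is bounded by epsilon in the L-infinity norm
print((x_adv - x).abs().max())                            # <= 0.1
print((model(x_adv).argmax(dim=1) == y).float().mean())   # accuracy on the adversarial batch

Because the attack takes a single epsilon-sized step and then clamps to [clamp_min, clamp_max], the resulting perturbation always stays within an L-infinity ball of radius epsilon around the original input.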