Skip to content
Snippets Groups Projects
Commit effe03c5 authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Upload New File

parent 6fd5bf60
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# VanillaAdversarialTrainingDefense class implements vanilla adversarial training
class VanillaAdversarialTrainingDefense(BaseDefense):
def __init__(self, student_model, teacher_model=None, adv_lambda=1.0):
super().__init__(student_model, teacher_model) # Initialize base class
self.adv_lambda = adv_lambda # Weight for adversarial loss
# Loss function combining clean and adversarial cross-entropy loss
def loss_function(self, x, y, x_adv=None, **kwargs):
if x_adv is None:
x_adv = x # If no adversarial examples, use clean examples
logits_adv = self.student_model(x_adv) # Forward pass for adversarial input
loss_adv = F.cross_entropy(logits_adv, y) # Cross-entropy loss for adversarial input
if self.adv_lambda < 1.0:
logits_clean = self.student_model(x) # Forward pass for clean input
loss_clean = F.cross_entropy(logits_clean, y) # Cross-entropy loss for clean input
loss = self.adv_lambda * loss_adv + (1.0 - self.adv_lambda) * loss_clean # Weighted loss
else:
loss = loss_adv
return loss
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment