Newer
Older
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# VanillaAdversarialTrainingDefense class implements vanilla adversarial training
class VanillaAdversarialTrainingDefense(BaseDefense):
    """Adversarial-training defense built on a weighted cross-entropy loss.

    The training loss mixes cross-entropy on adversarial inputs with
    cross-entropy on clean inputs; ``adv_lambda`` sets the mix.
    """

    def __init__(self, student_model, teacher_model=None, adv_lambda=1.0):
        """Initialise the base defense and remember the adversarial weight.

        Args:
            student_model: Model passed to ``BaseDefense.__init__``.
                ``loss_function`` calls it as ``self.student_model`` --
                assumes the base class stores it under that name (TODO confirm).
            teacher_model: Forwarded unchanged to ``BaseDefense.__init__``;
                this class never uses it directly.
            adv_lambda: Weight of the adversarial loss term. Any value that
                is not ``< 1.0`` selects the pure adversarial loss; smaller
                values blend in the clean loss with weight
                ``1.0 - adv_lambda``.
        """
        super().__init__(student_model, teacher_model)
        self.adv_lambda = adv_lambda

    def loss_function(self, x, y, x_adv=None, **kwargs):
        """Return the (optionally blended) cross-entropy training loss.

        Args:
            x: Clean inputs fed to the student model.
            y: Targets passed to ``F.cross_entropy``.
            x_adv: Adversarial inputs. When ``None`` the loss is computed on
                the clean inputs only.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``adv_lambda * loss_adv + (1 - adv_lambda) * loss_clean`` when
            ``adv_lambda < 1.0``, otherwise ``loss_adv``. Without adversarial
            inputs, the cross-entropy on ``x``.
        """
        if x_adv is None:
            # Without adversarial examples both terms of the blend would be
            # evaluated on the very same tensor, so the weighted sum collapses
            # to the plain clean loss. One forward pass is enough; a second
            # one only wastes compute and repeats any training-mode side
            # effects of the model for the same batch.
            return F.cross_entropy(self.student_model(x), y)

        loss_adv = F.cross_entropy(self.student_model(x_adv), y)

        if self.adv_lambda < 1.0:
            # The clean forward pass is only needed when it contributes.
            loss_clean = F.cross_entropy(self.student_model(x), y)
            return (
                self.adv_lambda * loss_adv
                + (1.0 - self.adv_lambda) * loss_clean
            )

        return loss_adv