Newer
Older
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# Joint training defense: optimizes a weighted blend of clean and adversarial cross-entropy.
class JointTrainingDefense(BaseDefense):
    """Defense whose loss mixes cross-entropy on clean and adversarial inputs.

    The clean term is weighted by ``joint_lambda`` and the adversarial term
    by ``1 - joint_lambda``.
    """

    def __init__(self, student_model, teacher_model=None, joint_lambda=0.5):
        """Set up the defense.

        Args:
            student_model: Forwarded to ``BaseDefense.__init__``.
            teacher_model: Forwarded to ``BaseDefense.__init__``; not used
                directly by this class.
            joint_lambda: Weight applied to the clean loss; the adversarial
                loss receives ``1 - joint_lambda``.
        """
        super().__init__(student_model, teacher_model)
        self.joint_lambda = joint_lambda

    def loss_function(self, x, y, x_adv=None, **kwargs):
        """Return the weighted sum of clean and adversarial cross-entropy.

        Args:
            x: Clean inputs passed to ``self.student_model``.
            y: Targets handed to ``F.cross_entropy`` for both terms.
            x_adv: Optional adversarial inputs. When ``None`` the adversarial
                term is ``0.0``, so the result is the clean loss scaled by
                ``joint_lambda`` (it is NOT rescaled to the full clean loss).
            **kwargs: Accepted and ignored.

        Returns:
            ``joint_lambda * loss_clean + (1 - joint_lambda) * loss_adv``.
        """
        # NOTE(review): self.student_model is presumably set by
        # BaseDefense.__init__ -- confirm against the base class.
        clean_term = F.cross_entropy(self.student_model(x), y)

        if x_adv is None:
            # No adversarial batch supplied: the adversarial term contributes nothing.
            adv_term = 0.0
        else:
            adv_term = F.cross_entropy(self.student_model(x_adv), y)

        clean_weight = self.joint_lambda
        return clean_weight * clean_term + (1 - clean_weight) * adv_term