Skip to content
Snippets Groups Projects
joint_training.py 1.21 KiB
Newer Older
Mina Moshfegh's avatar
Mina Moshfegh committed
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense

class JointTrainingDefense(BaseDefense):
    """Joint training on both clean and adversarial examples.

    The training objective is a convex combination of the cross-entropy on
    clean inputs and the cross-entropy on their adversarial counterparts::

        loss = joint_lambda * CE(f(x), y) + (1 - joint_lambda) * CE(f(x_adv), y)

    Args:
        student_model: Model being trained; called as ``student_model(x)`` to
            produce logits.
        teacher_model: Forwarded to ``BaseDefense``; not used by this class's
            loss.
        joint_lambda: Weight of the clean loss, in ``[0, 1]``. The adversarial
            loss receives ``1 - joint_lambda``.

    Raises:
        ValueError: If ``joint_lambda`` is outside ``[0, 1]`` (or is NaN).
    """

    def __init__(self, student_model, teacher_model=None, joint_lambda: float = 0.5):
        # Outside [0, 1] one of the two terms would get a negative weight,
        # i.e. training would push that loss *up*. The chained comparison is
        # also False for NaN, so NaN is rejected here.
        if not 0.0 <= joint_lambda <= 1.0:
            raise ValueError(f"joint_lambda must be in [0, 1], got {joint_lambda!r}")
        super().__init__(student_model, teacher_model)  # Initialize base class
        self.joint_lambda = joint_lambda  # Weight for clean vs adversarial loss

    def loss_function(self, x, y, x_adv=None, **kwargs) -> "torch.Tensor":
        """Return the joint clean/adversarial cross-entropy loss.

        Args:
            x: Clean input batch.
            y: Targets for ``F.cross_entropy``; used for both ``x`` and
                ``x_adv``, so ``x_adv`` must be aligned with ``y``
                sample-for-sample.
            x_adv: Adversarial counterpart of ``x``. When ``None``, only the
                clean loss is returned, without any ``joint_lambda`` weighting.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Loss tensor (a scalar under ``F.cross_entropy``'s default ``mean``
            reduction).
        """
        logits_clean = self.student_model(x)           # Forward pass for clean input
        loss_clean = F.cross_entropy(logits_clean, y)  # Cross-entropy loss for clean input

        # With no adversarial batch there is nothing to mix with. Weighting the
        # clean loss by joint_lambda here (as a constant 0.0 adversarial term
        # would) silently shrinks the gradient. For joint_lambda == 0 it would
        # zero the gradient entirely.
        if x_adv is None:
            return loss_clean

        logits_adv = self.student_model(x_adv)         # Forward pass for adversarial input
        loss_adv = F.cross_entropy(logits_adv, y)      # Cross-entropy loss for adversarial input

        # Convex combination of clean and adversarial losses.
        return self.joint_lambda * loss_clean + (1.0 - self.joint_lambda) * loss_adv