# vanilla_at.py (1.33 KiB)
# Committed by Mina Moshfegh
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense

# VanillaAdversarialTrainingDefense class implements vanilla adversarial training
class VanillaAdversarialTrainingDefense(BaseDefense):
    """Adversarial-training loss blending adversarial and clean cross-entropy.

    The loss is ``adv_lambda * CE(adv) + (1 - adv_lambda) * CE(clean)``.
    When ``adv_lambda`` is not below 1.0, only the adversarial term is used
    and the clean forward pass is skipped.
    """

    def __init__(self, student_model, teacher_model=None, adv_lambda=1.0):
        """Forward the models to ``BaseDefense`` and store the loss weight.

        Args:
            student_model: Passed to ``BaseDefense.__init__``; presumably
                exposed there as ``self.student_model``, which
                ``loss_function`` calls (confirm against the base class).
            teacher_model: Passed to ``BaseDefense.__init__``; not used by
                this class directly.
            adv_lambda: Weight of the adversarial loss term. The clean term
                receives ``1.0 - adv_lambda``. Values of 1.0 or more select
                the adversarial loss alone.
        """
        super().__init__(student_model, teacher_model)  # Initialize base class
        self.adv_lambda = adv_lambda                    # Weight for adversarial loss

    def loss_function(self, x, y, x_adv=None, **kwargs):
        """Return the (optionally blended) cross-entropy training loss.

        Args:
            x: Clean input batch fed to ``self.student_model``.
            y: Targets handed to ``F.cross_entropy``.
            x_adv: Adversarial counterpart of ``x``. When omitted (or when it
                is the very same object as ``x``) the clean batch is used.
            **kwargs: Accepted and ignored, so callers may pass extra
                arguments without error.

        Returns:
            The loss produced by ``F.cross_entropy`` (or the weighted sum of
            the adversarial and clean losses).
        """
        if x_adv is None or x_adv is x:
            # No distinct adversarial batch: both terms of the blend would be
            # computed on the same input, and lambda*L + (1 - lambda)*L == L.
            # A single forward pass avoids redundant compute and a duplicate
            # pass through any train-mode stateful/stochastic layers.
            return F.cross_entropy(self.student_model(x), y)

        logits_adv = self.student_model(x_adv)     # Forward pass for adversarial input
        loss_adv = F.cross_entropy(logits_adv, y)  # Cross-entropy loss for adversarial input

        if self.adv_lambda < 1.0:
            logits_clean = self.student_model(x)            # Forward pass for clean input
            loss_clean = F.cross_entropy(logits_clean, y)   # Cross-entropy loss for clean input
            # Convex combination of the adversarial and clean objectives.
            loss = self.adv_lambda * loss_adv + (1.0 - self.adv_lambda) * loss_clean
        else:
            # Pure adversarial training: the clean forward pass is not needed.
            loss = loss_adv

        return loss