Skip to content
Snippets Groups Projects
Commit 4dfd5946 authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Delete vanilla_at.py

parent c9aea40c
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# VanillaAdversarialTrainingDefense class implements vanilla adversarial training
class VanillaAdversarialTrainingDefense(BaseDefense):
def __init__(self, student_model, teacher_model=None, adv_lambda=1.0):
super().__init__(student_model, teacher_model) # Initialize base class
self.adv_lambda = adv_lambda # Weight for adversarial loss
# Loss function combining clean and adversarial cross-entropy loss
def loss_function(self, x, y, x_adv=None, **kwargs):
if x_adv is None:
x_adv = x # If no adversarial examples, use clean examples
logits_adv = self.student_model(x_adv) # Forward pass for adversarial input
loss_adv = F.cross_entropy(logits_adv, y) # Cross-entropy loss for adversarial input
if self.adv_lambda < 1.0:
logits_clean = self.student_model(x) # Forward pass for clean input
loss_clean = F.cross_entropy(logits_clean, y) # Cross-entropy loss for clean input
loss = self.adv_lambda * loss_adv + (1.0 - self.adv_lambda) * loss_clean # Weighted loss
else:
loss = loss_adv
return loss
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment