Skip to content
Snippets Groups Projects
Commit ddf4bb2b authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Upload New File

parent 32275c39
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# JointTrainingDefense class implements joint training with both clean and adversarial examples
class JointTrainingDefense(BaseDefense):
def __init__(self, student_model, teacher_model=None, joint_lambda=0.5):
super().__init__(student_model, teacher_model) # Initialize base class
self.joint_lambda = joint_lambda # Weight for clean vs adversarial loss
# Loss function combining clean and adversarial cross-entropy loss
def loss_function(self, x, y, x_adv=None, **kwargs):
logits_clean = self.student_model(x) # Forward pass for clean input
loss_clean = F.cross_entropy(logits_clean, y) # Cross-entropy loss for clean input
loss_adv = 0.0
if x_adv is not None:
logits_adv = self.student_model(x_adv) # Forward pass for adversarial input
loss_adv = F.cross_entropy(logits_adv, y) # Cross-entropy loss for adversarial input
# Total loss is a weighted combination of clean and adversarial losses
loss_total = self.joint_lambda * loss_clean + (1 - self.joint_lambda) * loss_adv
return loss_total
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment