import torch
import torch.nn.functional as F

from .base_defense import BaseDefense


class JointTrainingDefense(BaseDefense):
    """Joint training on clean and adversarial examples.

    The training objective is a convex combination of the cross-entropy
    loss on clean inputs and the cross-entropy loss on adversarial inputs:

        loss = joint_lambda * CE(f(x), y) + (1 - joint_lambda) * CE(f(x_adv), y)

    where ``f`` is the student model. When no adversarial batch is given,
    the plain clean loss is used (see ``loss_function``).
    """

    def __init__(self, student_model, teacher_model=None, joint_lambda=0.5):
        """Initialize the defense.

        :param student_model: model being trained (callable producing logits).
        :param teacher_model: optional teacher, forwarded to the base class
            (unused here; kept for interface consistency with other defenses).
        :param joint_lambda: weight in [0, 1] placed on the clean-example loss;
            the adversarial loss receives ``1 - joint_lambda``.
        """
        super().__init__(student_model, teacher_model)
        self.joint_lambda = joint_lambda

    def loss_function(self, x, y, x_adv=None, **kwargs):
        """Compute the joint clean/adversarial cross-entropy loss.

        :param x: batch of clean inputs.
        :param y: ground-truth labels for both clean and adversarial inputs.
        :param x_adv: optional batch of adversarial inputs aligned with ``y``.
        :param kwargs: ignored; accepted for a uniform defense interface.
        :returns: scalar loss tensor.
        """
        logits_clean = self.student_model(x)
        loss_clean = F.cross_entropy(logits_clean, y)

        # BUGFIX: the original returned joint_lambda * loss_clean when
        # x_adv was None, silently down-scaling the clean loss (and its
        # gradients) by the mixing weight on purely-clean steps. With no
        # adversarial term there is nothing to mix — use the full clean loss.
        if x_adv is None:
            return loss_clean

        logits_adv = self.student_model(x_adv)
        loss_adv = F.cross_entropy(logits_adv, y)

        # Weighted combination of the two objectives.
        return self.joint_lambda * loss_clean + (1.0 - self.joint_lambda) * loss_adv