From ddf4bb2b85babde784b48ffbd15c0c9489f99b2f Mon Sep 17 00:00:00 2001
From: Mina Moshfegh <mina.moshfegh@fau.de>
Date: Wed, 19 Feb 2025 15:54:50 +0000
Subject: [PATCH] Upload New File

---
 src/defenses/joint_training.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 src/defenses/joint_training.py

diff --git a/src/defenses/joint_training.py b/src/defenses/joint_training.py
new file mode 100644
index 0000000..d9d3bcd
--- /dev/null
+++ b/src/defenses/joint_training.py
@@ -0,0 +1,23 @@
+import torch
+import torch.nn.functional as F
+from .base_defense import BaseDefense
+
+# JointTrainingDefense class implements joint training with both clean and adversarial examples
+class JointTrainingDefense(BaseDefense):
+    def __init__(self, student_model, teacher_model=None, joint_lambda=0.5):
+        super().__init__(student_model, teacher_model)  # Initialize base class
+        self.joint_lambda = joint_lambda                # Weight for clean vs adversarial loss
+
+    # Loss function combining clean and adversarial cross-entropy loss
+    def loss_function(self, x, y, x_adv=None, **kwargs):
+        logits_clean = self.student_model(x)           # Forward pass for clean input
+        loss_clean = F.cross_entropy(logits_clean, y)  # Cross-entropy loss for clean input
+
+        loss_adv = 0.0
+        if x_adv is not None:
+            logits_adv = self.student_model(x_adv)     # Forward pass for adversarial input
+            loss_adv = F.cross_entropy(logits_adv, y)  # Cross-entropy loss for adversarial input
+
+        # Total loss is a weighted combination of clean and adversarial losses
+        loss_total = self.joint_lambda * loss_clean + (1 - self.joint_lambda) * loss_adv
+        return loss_total
-- 
GitLab