diff --git a/src/defenses/joint_training.py b/src/defenses/joint_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9d3bcdad2d11c52b52d6ab9cb03e0c3ef233dc1
--- /dev/null
+++ b/src/defenses/joint_training.py
@@ -0,0 +1,23 @@
+import torch
+import torch.nn.functional as F
+from .base_defense import BaseDefense
+
+# Joint-training defense: optimize a weighted sum of clean and adversarial loss.
+class JointTrainingDefense(BaseDefense):
+    def __init__(self, student_model, teacher_model=None, joint_lambda=0.5):
+        super().__init__(student_model, teacher_model)  # base class stores the models
+        self.joint_lambda = joint_lambda                # weight on the clean-loss term, expected in [0, 1]
+
+    # Cross-entropy on clean inputs, blended with adversarial cross-entropy when x_adv is given.
+    def loss_function(self, x, y, x_adv=None, **kwargs):
+        logits_clean = self.student_model(x)           # forward pass on clean batch
+        loss_clean = F.cross_entropy(logits_clean, y)
+
+        # No adversarial examples: return the plain clean loss. Weighting by
+        # joint_lambda here would silently scale gradients (halving the
+        # effective learning rate at the default 0.5) for clean-only batches.
+        if x_adv is None:
+            return loss_clean
+
+        logits_adv = self.student_model(x_adv)         # forward pass on adversarial batch
+        loss_adv = F.cross_entropy(logits_adv, y)
+        return self.joint_lambda * loss_clean + (1 - self.joint_lambda) * loss_adv