import torch
import torch.nn.functional as F

from .base_defense import BaseDefense


class FeatureExtractionDefense(BaseDefense):
    """Defense that regularizes a student model via feature consistency.

    The training loss is cross-entropy on clean inputs plus ``feat_lambda``
    times the MSE between the student's intermediate features on clean
    inputs and on inputs perturbed by the configured ``attack``.
    """

    def __init__(self, student_model, teacher_model=None, feat_lambda=1.0, attack=None):
        """Initialize the defense.

        Args:
            student_model: Model being trained; must define
                ``forward_features(x)`` returning intermediate features.
            teacher_model: Optional teacher model, forwarded to ``BaseDefense``.
            feat_lambda: Weight of the feature-consistency (MSE) loss term.
            attack: Object exposing ``generate(x, y)`` that produces perturbed
                inputs. Required before ``loss_function`` is called.

        Raises:
            ValueError: If ``student_model`` does not define
                ``forward_features``.
        """
        super().__init__(student_model, teacher_model)
        self.feat_lambda = feat_lambda
        # NOTE: this is an adversarial/perturbation generator with a
        # generate(x, y) method — not a noise standard deviation.
        self.attack = attack

        # Fail fast: loss_function depends on this hook existing.
        if not hasattr(student_model, "forward_features"):
            raise ValueError("Student model must define forward_features(x) for feature extraction defense.")

    def loss_function(self, x, y, **kwargs):
        """Compute the combined cross-entropy + feature-consistency loss.

        Args:
            x: Batch of clean inputs.
            y: Ground-truth labels for ``x``.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Scalar tensor: ``CE(student(x), y) + feat_lambda * MSE(feats_clean, feats_noisy)``.

        Raises:
            ValueError: If no ``attack`` was provided at construction time.
        """
        # Guard against the attack=None default: the original code would
        # otherwise fail with an opaque AttributeError on .generate below.
        if self.attack is None:
            raise ValueError("FeatureExtractionDefense.loss_function requires an 'attack' with a generate(x, y) method; got None.")

        logits_clean = self.student_model(x)
        loss_ce = F.cross_entropy(logits_clean, y)

        # Perturbed inputs produced by the configured attack.
        x_noisy = self.attack.generate(x, y)

        # Feature consistency: penalize divergence between clean and
        # perturbed intermediate representations.
        feats_clean = self.student_model.forward_features(x)
        feats_noisy = self.student_model.forward_features(x_noisy)
        loss_feat = F.mse_loss(feats_clean, feats_noisy)

        return loss_ce + self.feat_lambda * loss_feat