import torch
import torch.nn.functional as F

from .base_defense import BaseDefense


class FeatureExtractionDefense(BaseDefense):
    """Regularizes the student model with a feature-consistency loss.

    In addition to the usual cross-entropy objective, the student is penalized
    (via MSE) for producing different intermediate features on clean inputs and
    on Gaussian-noise-perturbed copies of those inputs.
    """

    def __init__(self, student_model, teacher_model=None, feat_lambda=1.0, noise_std=0.01):
        super().__init__(student_model, teacher_model)  # Initialize the base class
        self.feat_lambda = feat_lambda  # Weight of the feature-consistency loss
        self.noise_std = noise_std  # Standard deviation of the input noise
        # The defense needs intermediate features, so the student must expose
        # them through a forward_features(x) method.
        if not hasattr(student_model, "forward_features"):
            raise ValueError(
                "Student model must define forward_features(x) for feature extraction defense."
            )

    def loss_function(self, x, y, **kwargs):
        """Cross-entropy loss plus the weighted feature-consistency loss."""
        logits_clean = self.student_model(x)  # Student logits on clean inputs
        loss_ce = F.cross_entropy(logits_clean, y)  # Standard classification loss

        noise = self.noise_std * torch.randn_like(x)  # Gaussian input noise
        x_noisy = torch.clamp(x + noise, 0.0, 1.0)  # Keep pixel values in [0, 1]

        # Extract intermediate features for both the clean and noisy inputs.
        feats_clean = self.student_model.forward_features(x)
        feats_noisy = self.student_model.forward_features(x_noisy)

        # Feature-consistency loss: MSE between clean and noisy features.
        loss_feat = F.mse_loss(feats_clean, feats_noisy)

        return loss_ce + self.feat_lambda * loss_feat
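

# --- Usage sketch ----------------------------------------------------------
# A minimal, illustrative example of driving the defense for one training
# step. TinyStudent, the data shapes, and the optimizer settings are
# hypothetical stand-ins, and the sketch assumes BaseDefense stores
# student_model as self.student_model (consistent with its use above). Any
# model works as long as it defines forward_features(x). Run this module with
# `python -m <package>.<module>` so the relative import above resolves.
if __name__ == "__main__":
    import torch.nn as nn

    class TinyStudent(nn.Module):
        """Toy classifier exposing forward_features(x), as the defense requires."""

        def __init__(self, num_classes=10):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
            )
            self.classifier = nn.Linear(16, num_classes)

        def forward_features(self, x):
            return self.features(x)

        def forward(self, x):
            return self.classifier(self.forward_features(x))

    student = TinyStudent()
    defense = FeatureExtractionDefense(student, feat_lambda=1.0, noise_std=0.01)
    optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

    # One optimization step on a random batch of 32 RGB images in [0, 1].
    x = torch.rand(32, 3, 32, 32)
    y = torch.randint(0, 10, (32,))
    loss = defense.loss_function(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"combined loss: {loss.item():.4f}")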