# feat_extraction.py
# Committed by Mina Moshfegh.
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense

class FeatureExtractionDefense(BaseDefense):
    """Regularize a student model by making its features stable under input noise.

    The loss is cross-entropy on the clean batch plus a penalty on how far the
    student's features move when small Gaussian noise is added to the input::

        loss = CE(student(x), y)
               + feat_lambda * MSE(forward_features(x), forward_features(x_noisy))
    """

    def __init__(self, student_model, teacher_model=None,
                 feat_lambda=1.0,
                 noise_std=0.01,
                 clip_range=(0.0, 1.0)):
        """Configure the defense.

        Args:
            student_model: Model being trained. It is called on a batch to get
                logits and must also expose a callable ``forward_features(x)``.
            teacher_model: Passed through to ``BaseDefense``; this class's loss
                does not use it.
            feat_lambda: Weight of the feature-consistency (MSE) term.
            noise_std: Standard deviation of the zero-mean Gaussian noise added
                to the inputs.
            clip_range: ``(low, high)`` bounds that the noisy input is clamped to.
                The default ``(0.0, 1.0)`` matches the original behaviour and
                suits inputs in [0, 1]. For normalized data, pass the bounds of
                your input domain, or ``None`` to skip clamping entirely.

        Raises:
            ValueError: If ``student_model`` has no callable ``forward_features``,
                or if ``clip_range`` has its lower bound above its upper bound.
        """
        super().__init__(student_model, teacher_model)  # Initialize base class
        self.feat_lambda = feat_lambda                  # Weight of the feature-consistency loss
        self.noise_std = noise_std                      # Std-dev of Gaussian input noise

        if clip_range is not None:
            low, high = clip_range
            # torch.clamp with min > max sets every element to max, which would
            # quietly replace the whole batch with a constant, so reject it here.
            if low > high:
                raise ValueError(
                    f"clip_range lower bound {low} exceeds upper bound {high}."
                )
            clip_range = (low, high)
        self.clip_range = clip_range

        # Fail at construction time, not on the first training step. Require the
        # attribute to be callable, not merely present.
        if not callable(getattr(student_model, "forward_features", None)):
            raise ValueError("Student model must define forward_features(x) for feature extraction defense.")

    def loss_function(self, x, y, **kwargs):
        """Return clean cross-entropy plus the weighted feature-consistency term.

        Args:
            x: Input batch. NOTE(review): with the default ``clip_range`` the
                values are assumed to lie in [0, 1]; confirm against the data
                pipeline.
            y: Targets passed directly to ``F.cross_entropy``.
            **kwargs: Accepted (presumably to match the ``BaseDefense``
                interface) and ignored.

        Returns:
            Scalar tensor ``loss_ce + feat_lambda * loss_feat``.
        """
        # Supervised term on the clean batch.
        logits_clean = self.student_model(x)
        loss_ce = F.cross_entropy(logits_clean, y)

        # Perturb the input with zero-mean Gaussian noise, then optionally
        # clamp it back into the valid input domain.
        noise = self.noise_std * torch.randn_like(x)
        x_noisy = x + noise
        if self.clip_range is not None:
            low, high = self.clip_range
            x_noisy = torch.clamp(x_noisy, low, high)

        # Clean features come from a second forward pass on x, separate from the
        # logits pass above. If the model has BatchNorm layers in train mode,
        # every pass here updates their running statistics.
        feats_clean = self.student_model.forward_features(x)
        feats_noisy = self.student_model.forward_features(x_noisy)

        # Neither side is detached, so gradients flow through both the clean
        # and the noisy feature paths.
        loss_feat = F.mse_loss(feats_clean, feats_noisy)

        return loss_ce + self.feat_lambda * loss_feat