# feat_extraction.py -- last committed by Mina Moshfegh
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense


# FeatureExtractionDefense class implements feature extraction for regularizing the student model
class FeatureExtractionDefense(BaseDefense):
    """Defense that regularizes the student with a feature-consistency term.

    The training loss is the cross-entropy on the clean inputs plus
    ``feat_lambda`` times the mean-squared error between the student's
    features on the clean inputs and on the inputs returned by
    ``attack.generate(x, y)``.
    """

    def __init__(self, student_model, teacher_model=None, feat_lambda=1.0, attack=None):
        """Store the loss weight and the attack, and validate the student.

        Args:
            student_model: Model being trained; must expose
                ``forward_features(x)``.
            teacher_model: Forwarded unchanged to ``BaseDefense``; not used
                by this class directly.
            feat_lambda: Weight applied to the feature-consistency loss.
            attack: Object exposing ``generate(x, y)`` that returns the
                perturbed counterpart of ``x``. It may be left as ``None``
                at construction time, but must be set before
                ``loss_function`` is called.

        Raises:
            ValueError: If ``student_model`` has no ``forward_features``
                attribute.
        """
        super().__init__(student_model, teacher_model)  # Initialize base class
        self.feat_lambda = feat_lambda  # Regularization weight for feature extraction loss
        self.attack = attack  # Generator of perturbed inputs via attack.generate(x, y)

        # Check if the student model has the required method for feature extraction
        if not hasattr(student_model, "forward_features"):
            raise ValueError("Student model must define forward_features(x) for feature extraction defense.")

    def loss_function(self, x, y, **kwargs):
        """Return cross-entropy plus the weighted feature-consistency loss.

        Args:
            x: Batch of clean inputs fed to the student model.
            y: Targets passed to ``F.cross_entropy`` and to the attack.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            ``loss_ce + feat_lambda * loss_feat`` as returned by the torch
            loss functions.

        Raises:
            ValueError: If no attack has been configured.
        """
        # Fail fast with a clear message instead of an AttributeError on
        # ``None.generate`` after the clean forward pass has already run.
        if self.attack is None:
            raise ValueError(
                "FeatureExtractionDefense requires an attack providing generate(x, y); got attack=None."
            )

        # NOTE(review): self.student_model is presumably assigned by
        # BaseDefense.__init__ -- confirm against the base class.
        logits_clean = self.student_model(x)  # Logits for the clean inputs
        loss_ce = F.cross_entropy(logits_clean, y)  # Classification loss on clean inputs

        # Perturbed counterpart of the batch, produced by the configured attack
        x_noisy = self.attack.generate(x, y)

        # Extract features for both clean and perturbed inputs
        feats_clean = self.student_model.forward_features(x)
        feats_noisy = self.student_model.forward_features(x_noisy)

        # Feature-consistency loss (MSE between clean and perturbed features)
        loss_feat = F.mse_loss(feats_clean, feats_noisy)

        total_loss = loss_ce + self.feat_lambda * loss_feat  # Combine the losses
        return total_loss