Skip to content
Snippets Groups Projects
Commit 953b4e4c authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Upload New File

parent 212f5df2
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# FeatureExtractionDefense class implements feature extraction for regularizing the student model
class FeatureExtractionDefense(BaseDefense):
    """Defense that regularizes a student model by penalizing feature drift.

    Adds an MSE penalty between the student's intermediate features on clean
    inputs and on adversarially perturbed inputs, on top of the standard
    cross-entropy classification loss.

    The student model must expose a ``forward_features(x)`` method returning
    the intermediate representation to be matched.
    """

    def __init__(self, student_model, teacher_model=None, feat_lambda=1.0, attack=None):
        """Initialize the defense.

        Parameters
        ----------
        student_model : model being trained; must define ``forward_features(x)``.
        teacher_model : optional teacher passed through to ``BaseDefense``
            (not used directly by this defense).
        feat_lambda : float, weight of the feature-consistency (MSE) term.
        attack : optional adversarial attack object with a
            ``generate(x, y)`` method producing perturbed inputs. If ``None``,
            the feature term is skipped and the loss reduces to cross-entropy.

        Raises
        ------
        ValueError : if the student model lacks ``forward_features``.
        """
        super().__init__(student_model, teacher_model)
        self.feat_lambda = feat_lambda
        self.attack = attack  # attack object used to craft perturbed inputs (may be None)
        # Fail fast: the feature-consistency term is the whole point of this
        # defense, so the student must expose its intermediate features.
        if not hasattr(student_model, "forward_features"):
            raise ValueError("Student model must define forward_features(x) for feature extraction defense.")

    def loss_function(self, x, y, **kwargs):
        """Compute cross-entropy plus feature-consistency loss.

        Parameters
        ----------
        x : input batch tensor.
        y : ground-truth label tensor.

        Returns
        -------
        Scalar tensor: ``CE(student(x), y) + feat_lambda * MSE(feats_clean, feats_adv)``.
        When ``self.attack`` is None, only the cross-entropy term is returned.
        """
        logits_clean = self.student_model(x)
        loss_ce = F.cross_entropy(logits_clean, y)

        # BUGFIX: `attack` defaults to None in __init__, but the original code
        # called self.attack.generate(x, y) unconditionally, crashing with an
        # AttributeError. With no attack configured, fall back to plain CE.
        if self.attack is None:
            return loss_ce

        x_noisy = self.attack.generate(x, y)

        # Match intermediate features between clean and perturbed inputs so
        # the representation is stable under the attack.
        feats_clean = self.student_model.forward_features(x)
        feats_noisy = self.student_model.forward_features(x_noisy)
        loss_feat = F.mse_loss(feats_clean, feats_noisy)

        return loss_ce + self.feat_lambda * loss_feat
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment