Skip to content
Snippets Groups Projects
Commit 97204d5d authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Delete feat_extraction.py

parent 37aaf62b
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# FeatureExtractionDefense class implements feature extraction for regularizing the student model
class FeatureExtractionDefense(BaseDefense):
def __init__(self, student_model, teacher_model=None,
feat_lambda=1.0,
noise_std=0.01):
super().__init__(student_model, teacher_model) # Initialize base class
self.feat_lambda = feat_lambda # Regularization weight for feature extraction loss
self.noise_std = noise_std # Standard deviation for noise added to inputs
# Check if the student model has the required method for feature extraction
if not hasattr(student_model, "forward_features"):
raise ValueError("Student model must define forward_features(x) for feature extraction defense.")
# Loss function combining cross-entropy loss and feature extraction loss
def loss_function(self, x, y, **kwargs):
logits_clean = self.student_model(x) # Get logits from student model for clean inputs
loss_ce = F.cross_entropy(logits_clean, y) # Compute cross-entropy loss
noise = self.noise_std * torch.randn_like(x) # Generate random noise
x_noisy = torch.clamp(x + noise, 0.0, 1.0) # Add noise to input (clamp values to [0,1])
# Extract features for both clean and noisy inputs
feats_clean = self.student_model.forward_features(x)
feats_noisy = self.student_model.forward_features(x_noisy)
# Compute feature extraction loss (MSE between clean and noisy features)
loss_feat = F.mse_loss(feats_clean, feats_noisy)
total_loss = loss_ce + self.feat_lambda * loss_feat # Combine the losses
return total_loss
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment