Commit 192b6bc1 authored by Mina Moshfegh's avatar Mina Moshfegh

Upload New File

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_defense import BaseDefense


# LFLDefense implements Less-Forgetting Learning (LFL): the student is trained
# with cross-entropy while its features are kept close to the frozen teacher's
# features, and the classifier layers can optionally be frozen.
class LFLDefense(BaseDefense):
    def __init__(self, student_model, teacher_model,
                 lambda_lfl=1.0,
                 freeze_classifier=True,
                 feature_lambda=1.0):
        super().__init__(student_model, teacher_model)  # Initialize base class
        self.lambda_lfl = lambda_lfl                    # Regularization weight for LFL
        self.freeze_classifier = freeze_classifier      # Whether to freeze classifier layers
        self.feature_lambda = feature_lambda            # Weight for the feature-consistency loss

        # Optionally freeze the student's classifier layers
        if freeze_classifier:
            for name, param in self.student_model.named_parameters():
                if "fc" in name or "linear" in name:
                    param.requires_grad = False

        # Both models must expose a feature-extraction method
        if not hasattr(student_model, "forward_features"):
            raise ValueError("Student model must define forward_features(x) to extract features for LFL.")
        if teacher_model is not None and not hasattr(teacher_model, "forward_features"):
            raise ValueError("Teacher model must define forward_features(x) to extract features for LFL.")

    # Loss combining cross-entropy with a feature-consistency term
    def loss_function(self, x, y, **kwargs):
        student_logits = self.student_model(x)        # Student forward pass
        ce_loss = F.cross_entropy(student_logits, y)  # Cross-entropy loss on student logits
        loss_total = ce_loss

        if self.teacher_model is not None:
            with torch.no_grad():
                teacher_feats = self.teacher_model.forward_features(x)  # Teacher features (no gradients)
            student_feats = self.student_model.forward_features(x)      # Student features (with gradients)
            feat_dist = F.mse_loss(student_feats, teacher_feats)        # Feature-consistency (L2) loss
            loss_total = ce_loss + self.feature_lambda * feat_dist      # Combine the two terms

        return loss_total
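
Below is a minimal, hypothetical usage sketch of LFLDefense, not part of the uploaded file. It assumes a toy model class (here called TinyNet) that exposes forward_features(x), and that BaseDefense's constructor stores the two models as self.student_model and self.teacher_model, as the class above implies; the model architecture, batch shapes, and optimizer settings are illustrative only.

# Hypothetical toy model exposing forward_features(x), as LFLDefense requires.
class TinyNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU())
        self.fc = nn.Linear(64, num_classes)

    def forward_features(self, x):
        return self.backbone(x)  # Penultimate features compared against the teacher

    def forward(self, x):
        return self.fc(self.forward_features(x))

student, teacher = TinyNet(), TinyNet()
teacher.eval()  # Teacher stays fixed during training
defense = LFLDefense(student, teacher, feature_lambda=1.0)

# Only optimize parameters that were not frozen by freeze_classifier=True.
optimizer = torch.optim.SGD((p for p in student.parameters() if p.requires_grad), lr=0.01)

x = torch.randn(8, 1, 28, 28)   # Dummy batch standing in for real data
y = torch.randint(0, 10, (8,))
optimizer.zero_grad()
loss = defense.loss_function(x, y)
loss.backward()
optimizer.step()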