Newer
Older
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# EWCDefense class extends BaseDefense to apply Elastic Weight Consolidation (EWC) defense
class EWCDefense(BaseDefense):
    """Defense that regularizes training with Elastic Weight Consolidation (EWC).

    The training objective is cross-entropy on the student's predictions plus a
    quadratic penalty that discourages parameters from drifting away from a
    stored snapshot, weighted element-wise by Fisher information estimates::

        loss = CE(student(x), y)
               + lambda_ewc * sum_name sum(F[name] * (theta[name] - theta_old[name]) ** 2)

    No 1/2 factor is applied to the penalty; fold it into ``lambda_ewc`` if the
    textbook formulation is wanted.

    Args:
        student_model: Model being trained. Called as ``student_model(x)`` to
            produce logits; its ``named_parameters()`` are penalized.
        teacher_model: Forwarded unchanged to ``BaseDefense``; not used by this
            class directly.
        lambda_ewc: Weight of the EWC penalty relative to the cross-entropy loss.
        fisher_dict: Mapping from parameter name to its Fisher information
            (presumably a tensor shaped like the parameter -- confirm with the
            code that builds it). ``None`` means an empty mapping.
        old_params: Mapping from parameter name to the reference (snapshot)
            value of that parameter. ``None`` means an empty mapping.
    """

    def __init__(self, student_model, teacher_model=None, lambda_ewc=100.0,
                 fisher_dict=None, old_params=None):
        # NOTE(review): self.student_model is read below but never assigned in
        # this class; presumably BaseDefense.__init__ stores it -- confirm.
        super().__init__(student_model, teacher_model)
        self.lambda_ewc = lambda_ewc
        # Fresh empty dicts per instance when nothing is supplied; with empty
        # dicts no parameter is penalized.
        self.fisher_dict = fisher_dict if fisher_dict is not None else {}
        self.old_params = old_params if old_params is not None else {}

    @staticmethod
    def _as_constant(value, device):
        """Detach a tensor and move it to ``device``; return non-tensors unchanged."""
        if isinstance(value, torch.Tensor):
            return value.detach().to(device)
        return value

    def _ewc_penalty(self):
        """Return the unweighted EWC penalty summed over all matching parameters.

        Only parameters whose name is present in both ``fisher_dict`` and
        ``old_params`` contribute. When none match, the float ``0.0`` is
        returned.
        """
        penalty = 0.0
        for name, param in self.student_model.named_parameters():
            if name not in self.fisher_dict or name not in self.old_params:
                continue
            # Treat the snapshot and Fisher values as constants so gradients
            # flow only into the live parameter. A snapshot saved as
            # ``p.clone()`` without ``.detach()`` is still tied to the same leaf
            # parameter, and its gradient would exactly cancel the penalty's.
            # ``.to(device)`` is a no-op when the tensor is already on it.
            old = self._as_constant(self.old_params[name], param.device)
            fisher = self._as_constant(self.fisher_dict[name], param.device)
            penalty = penalty + (fisher * (param - old).pow(2)).sum()
        return penalty

    def loss_function(self, x, y, **kwargs):
        """Return cross-entropy on ``student_model(x)`` plus the weighted EWC penalty.

        Args:
            x: Input batch, passed directly to ``self.student_model``.
            y: Targets, passed to ``F.cross_entropy`` together with the logits.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``ce_loss + lambda_ewc * penalty`` as a tensor (a scalar under
            ``F.cross_entropy``'s default mean reduction).
        """
        logits = self.student_model(x)
        ce_loss = F.cross_entropy(logits, y)
        return ce_loss + self.lambda_ewc * self._ewc_penalty()

    def update_ewc(self, fisher_dict, old_params):
        """Replace the Fisher information and reference parameters.

        The given mappings are stored by reference (not copied). ``None`` is
        treated as an empty mapping, matching ``__init__``, so the penalty is
        disabled instead of raising on the next ``loss_function`` call.
        """
        self.fisher_dict = fisher_dict if fisher_dict is not None else {}
        self.old_params = old_params if old_params is not None else {}