import torch
import torch.nn.functional as F

from .base_defense import BaseDefense


def _as_constant(value, like):
    """Return *value* as a gradient-free constant on the same device as *like*.

    Tensors are detached (so no gradient flows into the stored reference) and
    moved to ``like.device``. Non-tensor values, e.g. plain Python floats, are
    returned unchanged so they keep working as scalar weights.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().to(like.device)
    return value


class EWCDefense(BaseDefense):
    """Defense that adds an Elastic Weight Consolidation (EWC) penalty to the loss.

    The training loss is the cross-entropy of the student model's predictions
    plus ``lambda_ewc`` times the Fisher-weighted squared distance between the
    current parameters and a stored reference copy of them.

    Attributes:
        lambda_ewc: Multiplier applied to the EWC penalty term.
        fisher_dict: Maps parameter name -> Fisher information weight,
            multiplied element-wise with the squared parameter difference.
        old_params: Maps parameter name -> reference parameter value that the
            penalty pulls the current parameter towards.
    """

    def __init__(self, student_model, teacher_model=None, lambda_ewc=100.0,
                 fisher_dict=None, old_params=None):
        """Set up the defense.

        Args:
            student_model: Model being trained; forwarded to ``BaseDefense``.
            teacher_model: Optional teacher model; forwarded to ``BaseDefense``.
            lambda_ewc: Multiplier applied to the EWC penalty term.
            fisher_dict: Optional mapping of parameter name to Fisher
                information. Defaults to an empty mapping (no penalty).
            old_params: Optional mapping of parameter name to reference
                parameter values. Defaults to an empty mapping (no penalty).
        """
        super().__init__(student_model, teacher_model)
        self.lambda_ewc = lambda_ewc
        # ``None`` defaults give every instance its own fresh dict instead of
        # sharing one mutable default across instances.
        self.fisher_dict = fisher_dict if fisher_dict is not None else {}
        self.old_params = old_params if old_params is not None else {}

    def loss_function(self, x, y, **kwargs):
        """Return cross-entropy loss plus the weighted EWC penalty.

        Only parameters whose name appears in BOTH ``fisher_dict`` and
        ``old_params`` contribute to the penalty; all others are skipped.

        Args:
            x: Input batch passed to ``self.student_model``.
            y: Targets passed to ``F.cross_entropy`` together with the logits.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``ce_loss + lambda_ewc * sum_i(F_i * (theta_i - theta_i_old) ** 2)``.
        """
        logits = self.student_model(x)
        ce_loss = F.cross_entropy(logits, y)

        ewc_loss = 0.0
        for name, param in self.student_model.named_parameters():
            if name not in self.fisher_dict or name not in self.old_params:
                continue
            # The stored references must act as constants: detaching keeps the
            # penalty's gradient flowing only through the live parameter (a
            # snapshot taken with ``clone()`` but no ``detach()`` would
            # otherwise cancel it), and the device move avoids a mismatch when
            # the snapshot was taken before the model was moved.
            anchor = _as_constant(self.old_params[name], param)
            fisher_val = _as_constant(self.fisher_dict[name], param)
            ewc_loss = ewc_loss + (fisher_val * (param - anchor).pow(2)).sum()

        return ce_loss + self.lambda_ewc * ewc_loss

    def update_ewc(self, fisher_dict, old_params):
        """Replace the stored Fisher information and reference parameters.

        The mappings are stored by reference, not copied.

        NOTE(review): if ``old_params`` holds the model's live parameter
        tensors rather than copies, the penalty is always zero. Callers are
        presumably expected to pass cloned snapshots; confirm at call sites.
        """
        self.fisher_dict = fisher_dict
        self.old_params = old_params