import torch
import torch.nn.functional as F

from .base_defense import BaseDefense


# EWCDefense extends BaseDefense to apply an Elastic Weight Consolidation (EWC)
# penalty on top of the standard classification loss.
class EWCDefense(BaseDefense):
    def __init__(self, student_model, teacher_model=None, lambda_ewc=100.0,
                 fisher_dict=None, old_params=None):
        super().__init__(student_model, teacher_model)  # Initialize the base defense class
        self.lambda_ewc = lambda_ewc  # Weight of the EWC penalty term
        self.fisher_dict = fisher_dict if fisher_dict is not None else {}  # Per-parameter Fisher information
        self.old_params = old_params if old_params is not None else {}  # Parameter snapshot for comparison

    # Loss function that combines cross-entropy with the EWC penalty
    def loss_function(self, x, y, **kwargs):
        logits = self.student_model(x)  # Get predictions
        ce_loss = F.cross_entropy(logits, y)  # Compute cross-entropy loss

        # Accumulate the EWC penalty over all parameters with stored Fisher/old values
        ewc_loss = 0.0
        for name, param in self.student_model.named_parameters():
            if name in self.fisher_dict and name in self.old_params:
                diff = param - self.old_params[name]  # Deviation from the old parameters
                fisher_val = self.fisher_dict[name]  # Fisher information for this parameter
                ewc_loss += (fisher_val * diff.pow(2)).sum()  # Quadratic penalty weighted by Fisher

        total_loss = ce_loss + self.lambda_ewc * ewc_loss  # Combine the losses
        return total_loss
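
    # The penalty above is the standard EWC regularizer:
    #   L_total = L_CE + lambda_ewc * sum_i F_i * (theta_i - theta_i_old)^2
    # where F_i is the diagonal Fisher information estimated on the previous
    # task and theta_i_old is the corresponding stored parameter value.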

    # Update the EWC statistics (Fisher information and old parameter snapshot)
    def update_ewc(self, fisher_dict, old_params):
        self.fisher_dict = fisher_dict
        self.old_params = old_params
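

# ---------------------------------------------------------------------------
# Example (a minimal sketch, not part of the class above): one common way to
# estimate the diagonal Fisher information and snapshot the current parameters
# so they can be passed to EWCDefense / update_ewc. The helper name
# `compute_fisher_and_params`, the data loader, and the device argument are
# illustrative assumptions, not an existing API.
def compute_fisher_and_params(model, data_loader, device="cpu"):
    model.eval()
    fisher_dict = {
        name: torch.zeros_like(p)
        for name, p in model.named_parameters() if p.requires_grad
    }
    old_params = {
        name: p.detach().clone()
        for name, p in model.named_parameters() if p.requires_grad
    }

    n_samples = 0
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        model.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                # Squared gradients approximate the diagonal of the Fisher matrix
                fisher_dict[name] += p.grad.detach().pow(2) * x.size(0)
        n_samples += x.size(0)

    for name in fisher_dict:
        fisher_dict[name] /= max(n_samples, 1)
    return fisher_dict, old_params


# Sketch of how the pieces might fit together (model, loader, and optimizer
# names are hypothetical):
#   defense = EWCDefense(student_model, lambda_ewc=100.0)
#   fisher, params = compute_fisher_and_params(student_model, old_task_loader)
#   defense.update_ewc(fisher, params)
#   loss = defense.loss_function(x_batch, y_batch)
#   loss.backward(); optimizer.step()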