Skip to content
Snippets Groups Projects
Commit 37aaf62b authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Delete ewc_defense.py

parent 8717bd14
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense
# Function to compute Fisher Information
def compute_fisher_information(model, data_loader, device="cuda", sample_size=1024):
fisher_dict = {}
model.eval() # Set model to evaluation mode (no gradients will be calculated)
count = 0
# Loop through data in the loader to calculate Fisher Information
for x, y in data_loader:
x, y = x.to(device), y.to(device)
model.zero_grad() # Zero out gradients
output = model(x) # Forward pass
loss = F.log_softmax(output, dim=1) # Softmax loss
loss = loss[range(len(y)), y] # Select correct class
loss = loss.mean() # Mean loss over the batch
loss.backward() # Backpropagate to compute gradients
# Calculate Fisher information for each parameter in the model
for name, param in model.named_parameters():
if param.grad is not None:
grad = param.grad.detach().clone() # Detach and clone gradient
grad_sq = grad.pow(2) # Square the gradients to compute Fisher Information
if name not in fisher_dict:
fisher_dict[name] = grad_sq # Initialize Fisher information if not present
else:
fisher_dict[name] += grad_sq # Accumulate Fisher information
count += 1
if count * x.size(0) >= sample_size:
break # Stop once we've processed enough samples
# Normalize the Fisher information
for name in fisher_dict:
fisher_dict[name] /= float(count)
return fisher_dict # Return Fisher information for each parameter
# Function to copy model parameters
def copy_params(model):
old_params = {}
for name, param in model.named_parameters():
old_params[name] = param.detach().clone() # Detach and clone model parameters
return old_params
# EWCDefense class extends BaseDefense to apply Elastic Weight Consolidation (EWC) defense
class EWCDefense(BaseDefense):
def __init__(self, student_model, teacher_model=None, lambda_ewc=100.0, fisher_dict=None, old_params=None):
super().__init__(student_model, teacher_model) # Initialize the base defense class
self.lambda_ewc = lambda_ewc # Set the weight for EWC loss
self.fisher_dict = fisher_dict if fisher_dict is not None else {} # Fisher information for parameters
self.old_params = old_params if old_params is not None else {} # Old parameters for comparison
# Loss function that combines cross-entropy and EWC loss
def loss_function(self, x, y, **kwargs):
logits = self.student_model(x) # Get predictions
ce_loss = F.cross_entropy(logits, y) # Compute cross-entropy loss
ewc_loss = 0.0
# Compute EWC loss for each parameter
for name, param in self.student_model.named_parameters():
if name in self.fisher_dict and name in self.old_params:
diff = param - self.old_params[name] # Difference from old parameters
fisher_val = self.fisher_dict[name] # Fisher information for the parameter
ewc_loss += (fisher_val * diff.pow(2)).sum() # EWC loss
total_loss = ce_loss + self.lambda_ewc * ewc_loss # Combine the losses
return total_loss
# Update EWC parameters (Fisher information and old parameters)
def update_ewc(self, fisher_dict, old_params):
self.fisher_dict = fisher_dict
self.old_params = old_params
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment