Skip to content
Snippets Groups Projects
ewc_defense.py 3.63 KiB
Newer Older
Mina Moshfegh's avatar
Mina Moshfegh committed
import torch
import torch.nn.functional as F
from .base_defense import BaseDefense


# Function to compute Fisher Information
def compute_fisher_information(model, data_loader, device="cuda", sample_size=1024):
    """Estimate a diagonal Fisher information tensor for each model parameter.

    For every batch, the mean log-probability of the true labels is
    back-propagated, and the element-wise squared gradient of each parameter
    is accumulated. The sums are then divided by the number of batches used.
    The gradient is taken of the *batch-mean* log-likelihood before squaring,
    so this is a coarse approximation of the empirical Fisher, which squares
    per-sample gradients.

    The model's train/eval mode is restored on exit, and gradients left over
    from the estimation are cleared so they cannot leak into a later
    optimizer step.

    Args:
        model: module mapping a batch of inputs to logits indexed as
            ``[sample, class]`` (dim 1 is the class axis).
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        sample_size: stop after the first batch that brings the number of
            processed samples to at least this value.

    Returns:
        dict mapping names from ``model.named_parameters()`` to tensors of the
        parameter's shape holding the estimate. Parameters that never receive
        a gradient are omitted. The dict is empty if ``data_loader`` yields
        nothing.
    """
    was_training = model.training
    # eval() only switches layers like dropout/batch-norm to inference
    # behaviour; autograd still runs, which is required here.
    model.eval()
    fisher_dict = {}
    num_batches = 0
    samples_seen = 0
    try:
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            model.zero_grad()

            log_probs = F.log_softmax(model(x), dim=1)
            # Log-probability assigned to the true class of each sample,
            # averaged over the batch.
            rows = torch.arange(len(y), device=log_probs.device)
            true_class_ll = log_probs[rows, y].mean()
            true_class_ll.backward()

            for name, param in model.named_parameters():
                if param.grad is None:
                    continue
                # pow() yields a fresh tensor, so it never aliases param.grad.
                grad_sq = param.grad.detach().pow(2)
                if name in fisher_dict:
                    fisher_dict[name] += grad_sq
                else:
                    fisher_dict[name] = grad_sq

            num_batches += 1
            # Count samples actually seen; batch sizes may vary
            # (e.g. a short final batch).
            samples_seen += x.size(0)
            if samples_seen >= sample_size:
                break
    finally:
        # Do not leave estimation gradients or a changed mode behind.
        model.zero_grad()
        model.train(was_training)

    # Average over batches (fisher_dict is empty when num_batches == 0).
    for name in fisher_dict:
        fisher_dict[name] /= float(num_batches)
    return fisher_dict


# Function to copy model parameters
def copy_params(model):
    """Snapshot every named parameter of ``model``.

    Returns a dict from parameter name to a detached copy of that parameter's
    tensor. The snapshot does not track gradients and does not change when
    the model is later updated.
    """
    return {
        param_name: tensor.detach().clone()
        for param_name, tensor in model.named_parameters()
    }


# EWCDefense class extends BaseDefense to apply Elastic Weight Consolidation (EWC) defense
class EWCDefense(BaseDefense):
    """Defense that adds an Elastic Weight Consolidation (EWC) penalty to training.

    The training loss is the student's cross-entropy plus
    ``lambda_ewc * sum(F * (theta - theta_old) ** 2)``. The sum runs over
    every student parameter that has entries in both ``fisher_dict`` and
    ``old_params``.
    """

    def __init__(self, student_model, teacher_model=None, lambda_ewc=100.0, fisher_dict=None, old_params=None):
        super().__init__(student_model, teacher_model)
        # Weight of the EWC penalty relative to the cross-entropy term.
        self.lambda_ewc = lambda_ewc
        # When omitted, both maps start empty, which makes the penalty zero.
        self.fisher_dict = {} if fisher_dict is None else fisher_dict
        self.old_params = {} if old_params is None else old_params

    def loss_function(self, x, y, **kwargs):
        """Return cross-entropy on ``(x, y)`` plus the weighted EWC penalty.

        Extra keyword arguments are accepted and ignored.
        """
        ce_loss = F.cross_entropy(self.student_model(x), y)

        # Only parameters that have both a Fisher estimate and an anchor
        # value contribute. They are visited in named_parameters() order.
        anchored = (
            (name, param)
            for name, param in self.student_model.named_parameters()
            if name in self.fisher_dict and name in self.old_params
        )
        # Starting from 0.0 keeps the penalty a plain float when nothing
        # is anchored.
        penalty = sum(
            (
                (self.fisher_dict[name] * (param - self.old_params[name]).pow(2)).sum()
                for name, param in anchored
            ),
            0.0,
        )
        return ce_loss + self.lambda_ewc * penalty

    def update_ewc(self, fisher_dict, old_params):
        """Replace the stored Fisher estimates and anchor parameters."""
        self.fisher_dict, self.old_params = fisher_dict, old_params