import torch
import torch.nn.functional as F
from .base_defense import BaseDefense


# EWCDefense class extends BaseDefense to apply Elastic Weight Consolidation (EWC) defense
class EWCDefense(BaseDefense):
    """Defense that regularizes the student with Elastic Weight Consolidation.

    The training loss is cross-entropy on the student's logits plus a
    quadratic penalty that discourages each tracked parameter from moving
    away from a stored anchor value, weighted per element by a Fisher
    estimate::

        loss = CE(student(x), y) + lambda_ewc * sum_i sum(F_i * (theta_i - anchor_i) ** 2)

    Unlike the formulation in Kirkpatrick et al. (2017), no factor of 1/2 is
    applied; ``lambda_ewc`` absorbs it.

    Only parameters whose name (as yielded by ``named_parameters()``) appears
    in BOTH ``fisher_dict`` and ``old_params`` are penalized. The mappings are
    stored by reference, not copied. Anchors should be snapshots of the
    parameters (e.g. ``p.detach().clone()``): an anchor that is the live
    parameter object itself always yields a zero penalty.

    Args:
        student_model: Model being trained; forwarded to ``BaseDefense``.
            This class reads it back as ``self.student_model`` (presumably set
            by the base class; confirm against ``BaseDefense``).
        teacher_model: Optional model forwarded unchanged to ``BaseDefense``;
            not used by this class.
        lambda_ewc: Weight of the EWC penalty relative to cross-entropy.
        fisher_dict: Mapping from parameter name to a Fisher-information value
            that must broadcast against that parameter. ``None`` means empty.
        old_params: Mapping from parameter name to that parameter's anchor
            value. ``None`` means empty.
    """

    def __init__(self, student_model, teacher_model=None, lambda_ewc=100.0, fisher_dict=None, old_params=None):
        super().__init__(student_model, teacher_model)  # Initialize the base defense class
        self.lambda_ewc = lambda_ewc  # Weight for the EWC penalty
        self.fisher_dict = fisher_dict if fisher_dict is not None else {}  # name -> Fisher information
        self.old_params = old_params if old_params is not None else {}  # name -> anchor parameter value

    @staticmethod
    def _as_constant(value, like):
        """Return ``value`` as a graph-free constant on ``like``'s device.

        Tensors are detached and moved to ``like.device`` (a no-op when they
        already live there). Non-tensor values such as Python floats are
        returned unchanged, so they keep broadcasting as scalars.
        """
        if torch.is_tensor(value):
            return value.detach().to(like.device)
        return value

    def loss_function(self, x, y, **kwargs):
        """Return cross-entropy on ``(x, y)`` plus the weighted EWC penalty.

        Args:
            x: Input batch fed to ``self.student_model``.
            y: Targets in whatever form ``F.cross_entropy`` accepts for the
                student's logits.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``ce_loss + lambda_ewc * ewc_loss`` as a tensor.
        """
        logits = self.student_model(x)
        ce_loss = F.cross_entropy(logits, y)

        # Stays a Python float when no parameter is tracked; adding it to the
        # cross-entropy tensor still yields a tensor.
        ewc_loss = 0.0
        for name, param in self.student_model.named_parameters():
            if name not in self.fisher_dict or name not in self.old_params:
                continue
            # Anchor and Fisher are constants of the penalty. Detaching them
            # keeps gradients out of any graph they still carry. An anchor
            # made with ``p.clone()`` (no detach) would otherwise cancel the
            # penalty's gradient w.r.t. ``param`` entirely.
            anchor = self._as_constant(self.old_params[name], param)
            fisher_val = self._as_constant(self.fisher_dict[name], param)
            ewc_loss += (fisher_val * (param - anchor).pow(2)).sum()

        return ce_loss + self.lambda_ewc * ewc_loss

    def update_ewc(self, fisher_dict, old_params):
        """Replace the Fisher estimates and anchor parameters.

        ``None`` for either argument is treated as an empty mapping, matching
        ``__init__``. This stops a later ``loss_function`` call from failing
        on a membership test against ``None``. The mappings are stored by
        reference, not copied.
        """
        self.fisher_dict = fisher_dict if fisher_dict is not None else {}
        self.old_params = old_params if old_params is not None else {}