diff --git a/src/defenses/base_defense.py b/src/defenses/base_defense.py new file mode 100644 index 0000000000000000000000000000000000000000..574018f2ac0edee551eed88f758e231c396fc3c3 --- /dev/null +++ b/src/defenses/base_defense.py @@ -0,0 +1,16 @@ +import abc + + +# Base class for defense strategies in this framework. + +class BaseDefense(metaclass=abc.ABCMeta): + def __init__(self, student_model, teacher_model=None): + # student_model is the model being trained or adapted, + # teacher_model might be the model from the previous step/task. + self.student_model = student_model + self.teacher_model = teacher_model + + @abc.abstractmethod + def loss_function(self, x, y, **kwargs): + # Must return a scalar loss that can be backpropagated. + pass