diff --git a/utils/callbacks.py b/utils/callbacks.py index f23d57a6c04345b270702233a4569cbfd0729d02..a204ec1ceaaf50adc173d80c28861463aba22043 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -58,12 +58,11 @@ class Callbacks: else: return self._callbacks - @staticmethod - def run_callbacks(register, *args, **kwargs): + def run_callbacks(self, hook, *args, **kwargs): """ Loop through the registered actions and fire all callbacks """ - for logger in register: + for logger in self._callbacks[hook]: # print(f"Running callbacks.{logger['callback'].__name__}()") logger['callback'](*args, **kwargs) @@ -71,106 +70,106 @@ class Callbacks: """ Fires all registered callbacks at the start of each pretraining routine """ - self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs) + self.run_callbacks('on_pretrain_routine_start', *args, **kwargs) def on_pretrain_routine_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each pretraining routine """ - self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs) + self.run_callbacks('on_pretrain_routine_end', *args, **kwargs) def on_train_start(self, *args, **kwargs): """ Fires all registered callbacks at the start of each training """ - self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs) + self.run_callbacks('on_train_start', *args, **kwargs) def on_train_epoch_start(self, *args, **kwargs): """ Fires all registered callbacks at the start of each training epoch """ - self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs) + self.run_callbacks('on_train_epoch_start', *args, **kwargs) def on_train_batch_start(self, *args, **kwargs): """ Fires all registered callbacks at the start of each training batch """ - self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs) + self.run_callbacks('on_train_batch_start', *args, **kwargs) def optimizer_step(self, *args, **kwargs): """ Fires all registered callbacks on each optimizer step """ - self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs) + self.run_callbacks('optimizer_step', *args, **kwargs) def on_before_zero_grad(self, *args, **kwargs): """ Fires all registered callbacks before zero grad """ - self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs) + self.run_callbacks('on_before_zero_grad', *args, **kwargs) def on_train_batch_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each training batch """ - self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs) + self.run_callbacks('on_train_batch_end', *args, **kwargs) def on_train_epoch_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each training epoch """ - self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs) + self.run_callbacks('on_train_epoch_end', *args, **kwargs) def on_val_start(self, *args, **kwargs): """ Fires all registered callbacks at the start of the validation """ - self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs) + self.run_callbacks('on_val_start', *args, **kwargs) def on_val_batch_start(self, *args, **kwargs): """ Fires all registered callbacks at the start of each validation batch """ - self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs) + self.run_callbacks('on_val_batch_start', *args, **kwargs) def on_val_image_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each val image """ - self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs) + self.run_callbacks('on_val_image_end', *args, **kwargs) def on_val_batch_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each validation batch """ - self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs) + self.run_callbacks('on_val_batch_end', *args, **kwargs) def on_val_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of the validation """ - self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs) + self.run_callbacks('on_val_end', *args, **kwargs) def on_fit_epoch_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of each fit (train+val) epoch """ - self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs) + self.run_callbacks('on_fit_epoch_end', *args, **kwargs) def on_model_save(self, *args, **kwargs): """ Fires all registered callbacks after each model save """ - self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs) + self.run_callbacks('on_model_save', *args, **kwargs) def on_train_end(self, *args, **kwargs): """ Fires all registered callbacks at the end of training """ - self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs) + self.run_callbacks('on_train_end', *args, **kwargs) def teardown(self, *args, **kwargs): """ Fires all registered callbacks before teardown """ - self.run_callbacks(self._callbacks['teardown'], *args, **kwargs) + self.run_callbacks('teardown', *args, **kwargs)