Simplify callbacks (#4289)
parent
771ac6c53d
commit
4103ce9ad0
|
@ -58,12 +58,11 @@ class Callbacks:
|
|||
else:
|
||||
return self._callbacks
|
||||
|
||||
@staticmethod
|
||||
def run_callbacks(register, *args, **kwargs):
|
||||
def run_callbacks(self, hook, *args, **kwargs):
|
||||
"""
|
||||
Loop through the registered actions and fire all callbacks
|
||||
"""
|
||||
for logger in register:
|
||||
for logger in self._callbacks[hook]:
|
||||
# print(f"Running callbacks.{logger['callback'].__name__}()")
|
||||
logger['callback'](*args, **kwargs)
|
||||
|
||||
|
@ -71,106 +70,106 @@ class Callbacks:
|
|||
"""
|
||||
Fires all registered callbacks at the start of each pretraining routine
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
|
||||
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
|
||||
|
||||
def on_pretrain_routine_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of each pretraining routine
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
|
||||
|
||||
def on_train_start(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the start of each training
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
|
||||
self.run_callbacks('on_train_start', *args, **kwargs)
|
||||
|
||||
def on_train_epoch_start(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the start of each training epoch
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
|
||||
self.run_callbacks('on_train_epoch_start', *args, **kwargs)
|
||||
|
||||
def on_train_batch_start(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the start of each training batch
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
|
||||
self.run_callbacks('on_train_batch_start', *args, **kwargs)
|
||||
|
||||
def optimizer_step(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks on each optimizer step
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
|
||||
self.run_callbacks('optimizer_step', *args, **kwargs)
|
||||
|
||||
def on_before_zero_grad(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks before zero grad
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
|
||||
self.run_callbacks('on_before_zero_grad', *args, **kwargs)
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of each training batch
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_train_batch_end', *args, **kwargs)
|
||||
|
||||
def on_train_epoch_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of each training epoch
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_train_epoch_end', *args, **kwargs)
|
||||
|
||||
def on_val_start(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the start of the validation
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
|
||||
self.run_callbacks('on_val_start', *args, **kwargs)
|
||||
|
||||
def on_val_batch_start(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the start of each validation batch
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
|
||||
self.run_callbacks('on_val_batch_start', *args, **kwargs)
|
||||
|
||||
def on_val_image_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of each val image
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_val_image_end', *args, **kwargs)
|
||||
|
||||
def on_val_batch_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of each validation batch
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_val_batch_end', *args, **kwargs)
|
||||
|
||||
def on_val_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of the validation
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_val_end', *args, **kwargs)
|
||||
|
||||
def on_fit_epoch_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of each fit (train+val) epoch
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
|
||||
|
||||
def on_model_save(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks after each model save
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
|
||||
self.run_callbacks('on_model_save', *args, **kwargs)
|
||||
|
||||
def on_train_end(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks at the end of training
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
|
||||
self.run_callbacks('on_train_end', *args, **kwargs)
|
||||
|
||||
def teardown(self, *args, **kwargs):
|
||||
"""
|
||||
Fires all registered callbacks before teardown
|
||||
"""
|
||||
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
|
||||
self.run_callbacks('teardown', *args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue