# File: yolov5/utils/callbacks.py (77 lines, 2.6 KiB, Python)

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Callback utils
"""
import threading
class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks

    Each hook name maps to an ordered list of actions, where an action is a
    dict of the form {'name': <str>, 'callback': <callable>}. Actions fire in
    the order they were registered.
    """

    def __init__(self):
        # Define the available callbacks; only these hook names are accepted
        # by register_action() and run()
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],
            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],
            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],
            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],
            'on_params_update': [],
            'teardown': [],
        }
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire

        Raises:
            AssertionError: If hook is not a known hook name or callback is not callable
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook: The name of the hook to check, defaults to all

        Returns:
            The list of action dicts for hook, or the full hook -> actions dict
            when hook is falsy. The internal containers are returned, not copies.
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Loop through the registered actions of a hook and fire all callbacks

        Callbacks run synchronously in the calling thread by default. With
        thread=True each callback is started in its own daemon thread and is
        not joined, so run() returns without waiting for them.

        Args:
            hook: The name of the hook whose callbacks should fire
            args: Arguments to receive from YOLOv5
            thread: (boolean) Run callbacks in daemon thread
            kwargs: Keyword Arguments to receive from YOLOv5

        Raises:
            AssertionError: If hook is not a known hook name
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for logger in self._callbacks[hook]:
            if thread:
                threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
            else:
                logger['callback'](*args, **kwargs)