2025-01-15 21:13:57 +08:00
|
|
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
|
|
|
2024-01-08 08:29:14 +08:00
|
|
|
"""Callback utils."""
|
2021-08-15 03:17:51 +08:00
|
|
|
|
2022-08-21 09:47:37 +08:00
|
|
|
import threading
|
|
|
|
|
2021-08-01 06:18:07 +08:00
|
|
|
|
|
|
|
class Callbacks:
    """Handles all registered callbacks for YOLOv5 Hooks."""

    def __init__(self):
        """Initializes a Callbacks object to manage registered YOLOv5 training event hooks."""
        # Each hook name maps to an ordered list of {"name": ..., "callback": ...} action dicts.
        # Only the hook names listed here are accepted by register_action() and run().
        self._callbacks = {
            "on_pretrain_routine_start": [],
            "on_pretrain_routine_end": [],
            "on_train_start": [],
            "on_train_epoch_start": [],
            "on_train_batch_start": [],
            "optimizer_step": [],
            "on_before_zero_grad": [],
            "on_train_batch_end": [],
            "on_train_epoch_end": [],
            "on_val_start": [],
            "on_val_batch_start": [],
            "on_val_image_end": [],
            "on_val_batch_end": [],
            "on_val_end": [],
            "on_fit_epoch_end": [],  # fit = train + val
            "on_model_save": [],
            "on_train_end": [],
            "on_params_update": [],
            "teardown": [],
        }
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name="", callback=None):
        """
        Register a new action to a callback hook.

        Actions are appended, so they fire in registration order when the hook is run.

        Args:
            hook: The callback hook name to register the action to; must be one of the hooks created in __init__.
            name: The name of the action for later reference.
            callback: The callable to fire when the hook is run.

        Raises:
            AssertionError: If ``hook`` is not a known hook or ``callback`` is not callable
                (these checks are skipped when Python runs with -O).
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({"name": name, "callback": callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook.

        Args:
            hook: The name of the hook to check. If falsy (the default ``None``, or ``""``), every hook is returned.

        Returns:
            The list of action dicts for ``hook``, or the whole hook -> actions dict when ``hook`` is falsy.
            These are the live internal containers, not copies, so mutating them changes the registry.
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Fire every action registered to a hook, in registration order.

        Return values of the callbacks are discarded.

        Args:
            hook: The name of the hook to run (required); must be one of the hooks created in __init__.
            args: Positional arguments forwarded to each callback.
            thread: (boolean) If True, start each callback in its own daemon thread and return without
                waiting for it; otherwise call each callback synchronously in the calling thread.
            kwargs: Keyword arguments forwarded to each callback.

        Raises:
            AssertionError: If ``hook`` is not a known hook (skipped when Python runs with -O).
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for action in self._callbacks[hook]:
            if thread:
                # Fire-and-forget: daemon threads are never joined and will not keep the process alive.
                threading.Thread(target=action["callback"], args=args, kwargs=kwargs, daemon=True).start()
            else:
                action["callback"](*args, **kwargs)
|