diff --git a/train.py b/train.py index b20b7dbb2..510377e11 100644 --- a/train.py +++ b/train.py @@ -352,6 +352,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn) + if callbacks.stop_training: + return # end batch ------------------------------------------------------------------------------------------------ # Scheduler diff --git a/utils/callbacks.py b/utils/callbacks.py index 13d82ebc2..c51c268f2 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -35,6 +35,7 @@ class Callbacks: 'on_params_update': [], 'teardown': [], } + self.stop_training = False # set True to interrupt training def register_action(self, hook, name='', callback=None): """