automaticaly update iter and epoch in message_hub (#168)

* automatic update iter and epoch in message_hub

* add docstring

* Update comment and docstring

* Fix as comment

* Fix docstring and comment

* refine comments
This commit is contained in:
Mashiro 2022-04-21 11:45:03 +08:00 committed by GitHub
parent 53101a1ab1
commit 45567b1d1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 5 deletions

View File

@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop):
self.run_iter(idx, data_batch) self.run_iter(idx, data_batch)
self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train_epoch')
self.runner._epoch += 1 self.runner.epoch += 1
def run_iter(self, idx, def run_iter(self, idx,
data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop):
data_batch=data_batch, data_batch=data_batch,
outputs=self.runner.outputs) outputs=self.runner.outputs)
self.runner._iter += 1 self.runner.iter += 1
@LOOPS.register_module() @LOOPS.register_module()
@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop):
batch_idx=self.runner._iter, batch_idx=self.runner._iter,
data_batch=data_batch, data_batch=data_batch,
outputs=self.runner.outputs) outputs=self.runner.outputs)
self.runner._iter += 1 self.runner.iter += 1
@LOOPS.register_module() @LOOPS.register_module()

View File

@ -311,7 +311,15 @@ class Runner:
self._experiment_name = self.timestamp self._experiment_name = self.timestamp
self.logger = self.build_logger(log_level=log_level) self.logger = self.build_logger(log_level=log_level)
# message hub used for component interaction # Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not
# have access to the runner can get iteration or epoch information
# from `message_hub`. For example, models can get the latest created
# `message_hub` by
# `self.message_hub=MessageHub.get_current_instance()` and then get
# current epoch by `cur_epoch = self.message_hub.get_info('epoch')`.
# See `MessageHub` and `ManagerMixin` for more details.
self.message_hub = self.build_message_hub() self.message_hub = self.build_message_hub()
# writer used for writing log or visualizing all kinds of data # writer used for writing log or visualizing all kinds of data
self.writer = self.build_writer(writer) self.writer = self.build_writer(writer)
@ -407,11 +415,26 @@ class Runner:
"""int: Current epoch.""" """int: Current epoch."""
return self._epoch return self._epoch
@epoch.setter
def epoch(self, epoch: int):
"""Update epoch and synchronize epoch in :attr:`message_hub`."""
self._epoch = epoch
# To allow components that cannot access runner to get current epoch.
self.message_hub.update_info('epoch', epoch)
@property @property
def iter(self): def iter(self):
"""int: Current epoch.""" """int: Current iteration."""
return self._iter return self._iter
@iter.setter
def iter(self, iter: int):
"""Update iter and synchronize iter in :attr:`message_hub`."""
self._iter = iter
# To allow components that cannot access runner to get current
# iteration.
self.message_hub.update_info('iter', iter)
@property @property
def launcher(self): def launcher(self):
"""str: Way to launcher multi processes.""" """str: Way to launcher multi processes."""

View File

@ -754,6 +754,9 @@ class TestRunner(TestCase):
assert isinstance(runner.train_loop, EpochBasedTrainLoop) assert isinstance(runner.train_loop, EpochBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
assert runner.epoch == runner.message_hub.get_info('epoch')
for result, target, in zip(epoch_results, epoch_targets): for result, target, in zip(epoch_results, epoch_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
for result, target, in zip(iter_results, iter_targets): for result, target, in zip(iter_results, iter_targets):
@ -786,6 +789,7 @@ class TestRunner(TestCase):
runner.train() runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop) assert isinstance(runner.train_loop, IterBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
self.assertEqual(len(epoch_results), 1) self.assertEqual(len(epoch_results), 1)
self.assertEqual(epoch_results[0], 0) self.assertEqual(epoch_results[0], 0)