mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
automaticaly update iter and epoch in message_hub (#168)
* automatic update iter and epoch in message_hub * add docstring * Update comment and docstring * Fix as comment * Fix docstring and comment * refine comments
This commit is contained in:
parent
53101a1ab1
commit
45567b1d1c
@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
self.run_iter(idx, data_batch)
|
self.run_iter(idx, data_batch)
|
||||||
|
|
||||||
self.runner.call_hook('after_train_epoch')
|
self.runner.call_hook('after_train_epoch')
|
||||||
self.runner._epoch += 1
|
self.runner.epoch += 1
|
||||||
|
|
||||||
def run_iter(self, idx,
|
def run_iter(self, idx,
|
||||||
data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
|
data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
|
||||||
@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
data_batch=data_batch,
|
data_batch=data_batch,
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
|
|
||||||
self.runner._iter += 1
|
self.runner.iter += 1
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
batch_idx=self.runner._iter,
|
batch_idx=self.runner._iter,
|
||||||
data_batch=data_batch,
|
data_batch=data_batch,
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
self.runner._iter += 1
|
self.runner.iter += 1
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
|
@ -311,7 +311,15 @@ class Runner:
|
|||||||
self._experiment_name = self.timestamp
|
self._experiment_name = self.timestamp
|
||||||
|
|
||||||
self.logger = self.build_logger(log_level=log_level)
|
self.logger = self.build_logger(log_level=log_level)
|
||||||
# message hub used for component interaction
|
# Build `message_hub` for communication among components.
|
||||||
|
# `message_hub` can store log scalars (loss, learning rate) and
|
||||||
|
# runtime information (iter and epoch). Those components that do not
|
||||||
|
# have access to the runner can get iteration or epoch information
|
||||||
|
# from `message_hub`. For example, models can get the latest created
|
||||||
|
# `message_hub` by
|
||||||
|
# `self.message_hub=MessageHub.get_current_instance()` and then get
|
||||||
|
# current epoch by `cur_epoch = self.message_hub.get_info('epoch')`.
|
||||||
|
# See `MessageHub` and `ManagerMixin` for more details.
|
||||||
self.message_hub = self.build_message_hub()
|
self.message_hub = self.build_message_hub()
|
||||||
# writer used for writing log or visualizing all kinds of data
|
# writer used for writing log or visualizing all kinds of data
|
||||||
self.writer = self.build_writer(writer)
|
self.writer = self.build_writer(writer)
|
||||||
@ -407,11 +415,26 @@ class Runner:
|
|||||||
"""int: Current epoch."""
|
"""int: Current epoch."""
|
||||||
return self._epoch
|
return self._epoch
|
||||||
|
|
||||||
|
@epoch.setter
|
||||||
|
def epoch(self, epoch: int):
|
||||||
|
"""Update epoch and synchronize epoch in :attr:`message_hub`."""
|
||||||
|
self._epoch = epoch
|
||||||
|
# To allow components that cannot access runner to get current epoch.
|
||||||
|
self.message_hub.update_info('epoch', epoch)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def iter(self):
|
def iter(self):
|
||||||
"""int: Current epoch."""
|
"""int: Current iteration."""
|
||||||
return self._iter
|
return self._iter
|
||||||
|
|
||||||
|
@iter.setter
|
||||||
|
def iter(self, iter: int):
|
||||||
|
"""Update iter and synchronize iter in :attr:`message_hub`."""
|
||||||
|
self._iter = iter
|
||||||
|
# To allow components that cannot access runner to get current
|
||||||
|
# iteration.
|
||||||
|
self.message_hub.update_info('iter', iter)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def launcher(self):
|
def launcher(self):
|
||||||
"""str: Way to launcher multi processes."""
|
"""str: Way to launcher multi processes."""
|
||||||
|
@ -754,6 +754,9 @@ class TestRunner(TestCase):
|
|||||||
|
|
||||||
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
|
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
|
||||||
|
|
||||||
|
assert runner.iter == runner.message_hub.get_info('iter')
|
||||||
|
assert runner.epoch == runner.message_hub.get_info('epoch')
|
||||||
|
|
||||||
for result, target, in zip(epoch_results, epoch_targets):
|
for result, target, in zip(epoch_results, epoch_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(iter_results, iter_targets):
|
for result, target, in zip(iter_results, iter_targets):
|
||||||
@ -786,6 +789,7 @@ class TestRunner(TestCase):
|
|||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
||||||
|
assert runner.iter == runner.message_hub.get_info('iter')
|
||||||
|
|
||||||
self.assertEqual(len(epoch_results), 1)
|
self.assertEqual(len(epoch_results), 1)
|
||||||
self.assertEqual(epoch_results[0], 0)
|
self.assertEqual(epoch_results[0], 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user