mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Call SyncBufferHook before validation in IterBasedTrainLoop (#982)
* [Fix] Call SyncBufferHook before validation in IterBasedTrainLoop * Add before_val_epoch in SyncBuffersHook * Fix white space format * Add comments for SyncBuffersHook * Add comments for SyncBuffersHook Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Add comments for SyncBuffersHook Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Fix white space format * Add before_test_epoch * Remove before_test_epoch --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
0e5f9da68b
commit
6ebb6f838a
@ -13,6 +13,24 @@ class SyncBuffersHook(Hook):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.distributed = is_distributed()
|
||||
# A flag to mark whether synchronization has been done in
|
||||
# after_train_epoch
|
||||
self.called_in_train = False
|
||||
|
||||
def before_val_epoch(self, runner) -> None:
|
||||
"""All-reduce model buffers before each validation epoch.
|
||||
|
||||
Synchronize the buffers before each validation if they have not been
|
||||
synchronized at the end of the previous training epoch. This method
|
||||
will be called when using IterBasedTrainLoop.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
if self.distributed:
|
||||
if not self.called_in_train:
|
||||
all_reduce_params(runner.model.buffers(), op='mean')
|
||||
self.called_in_train = False
|
||||
|
||||
def after_train_epoch(self, runner) -> None:
|
||||
"""All-reduce model buffers at the end of each epoch.
|
||||
@ -22,3 +40,4 @@ class SyncBuffersHook(Hook):
|
||||
"""
|
||||
if self.distributed:
|
||||
all_reduce_params(runner.model.buffers(), op='mean')
|
||||
self.called_in_train = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user