[Fix] Fix publish multiple checkpoints when using multiple GPUs (#1059) (#1070)

This commit is contained in:
Junwei Zheng 2023-04-12 04:38:48 +02:00 committed by GitHub
parent 9207e84aa0
commit d41906fa15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,7 @@ from math import inf
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union from typing import Callable, Dict, List, Optional, Sequence, Union
from mmengine.dist import is_main_process from mmengine.dist import is_main_process, master_only
from mmengine.fileio import FileClient, get_file_backend from mmengine.fileio import FileClient, get_file_backend
from mmengine.logging import print_log from mmengine.logging import print_log
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
@ -347,6 +347,7 @@ class CheckpointHook(Hook):
for key, best_ckpt in self.best_ckpt_path_dict.items(): for key, best_ckpt in self.best_ckpt_path_dict.items():
self._publish_model(runner, best_ckpt) self._publish_model(runner, best_ckpt)
@master_only
def _publish_model(self, runner, ckpt_path: str) -> None: def _publish_model(self, runner, ckpt_path: str) -> None:
"""Remove unnecessary keys from ckpt_path and save the new checkpoint. """Remove unnecessary keys from ckpt_path and save the new checkpoint.