mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix LoggerHook saving scalars from multiple ranks in the same json file. (#124)
* use master_only to decorate _log_train and _log_val * fix resolved TODO fix resolved TODO fix resolved TODO * fix raise error typo * ensure log item is python scalar
This commit is contained in:
parent
a7961407e4
commit
4d49de7d81
@ -10,6 +10,7 @@ from typing import Any, Optional, Sequence, Tuple, Union
|
||||
import torch
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.registry import HOOKS
|
||||
@ -232,6 +233,7 @@ class LoggerHook(Hook):
|
||||
runner.logger.info((f'{local_filepath} was removed due to the '
|
||||
'`self.keep_local=False`'))
|
||||
|
||||
@master_only
|
||||
def _log_train(self, runner) -> None:
|
||||
"""Collect and record training logs which start named with "train/*".
|
||||
|
||||
@ -295,6 +297,7 @@ class LoggerHook(Hook):
|
||||
runner.writer.add_scalars(
|
||||
tag, step=runner.iter + 1, file_path=self.json_log_path)
|
||||
|
||||
@master_only
|
||||
def _log_val(self, runner) -> None:
|
||||
"""Collect and record training logs which start named with "val/*".
|
||||
|
||||
@ -402,7 +405,7 @@ class LoggerHook(Hook):
|
||||
for cfg in log_cfg:
|
||||
log_name = cfg.get('log_name', None)
|
||||
if log_name in log_names:
|
||||
raise KeyError(f'{cfg["log_name"]} cannot be Redefined in '
|
||||
raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
|
||||
'log_key')
|
||||
if log_name is not None:
|
||||
log_names.add(log_name)
|
||||
@ -418,7 +421,7 @@ class LoggerHook(Hook):
|
||||
name = log_cfg.pop('log_name')
|
||||
else:
|
||||
name = log_key
|
||||
tag[name] = log_buffers[log_key].statistics(**log_cfg)
|
||||
tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
|
||||
else:
|
||||
raise ValueError('The structure of `LoggerHook.custom key` is '
|
||||
'wrong, please make sure the type of each key is '
|
||||
@ -435,7 +438,6 @@ class LoggerHook(Hook):
|
||||
The maximum GPU memory occupied by tensors in megabytes for a given
|
||||
device.
|
||||
"""
|
||||
# TODO use `mmengine.dist.max_memory_allocated` to count mem_mb
|
||||
device = getattr(runner.model, 'output_device', None)
|
||||
mem = torch.cuda.max_memory_allocated(device=device)
|
||||
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user