mirror of https://github.com/open-mmlab/mmcv.git
commit
bef7c13a95
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .base import LoggerHook
|
||||
|
||||
|
@ -39,9 +40,12 @@ class TextLoggerHook(LoggerHook):
|
|||
# statistic memory
|
||||
if runner.mode == 'train' and torch.cuda.is_available():
|
||||
mem = torch.cuda.max_memory_allocated()
|
||||
mem_mb = int(mem / (1024 * 1024))
|
||||
mem_str = 'memory: {}, '.format(mem_mb)
|
||||
log_str += mem_str
|
||||
mem_mb = torch.tensor([mem / (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=torch.device('cuda'))
|
||||
if runner.world_size > 1:
|
||||
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
|
||||
log_str += 'memory: {}, '.format(mem_mb.item())
|
||||
log_items = []
|
||||
for name, val in runner.log_buffer.output.items():
|
||||
if name in ['time', 'data_time']:
|
||||
|
|
Loading…
Reference in New Issue