Merge pull request #47 from yhcao6/memory_statistic

reduce max memory in dist training
Kai Chen 2019-04-10 22:38:45 -07:00 committed by GitHub
commit bef7c13a95
1 changed file with 7 additions and 3 deletions


@@ -1,6 +1,7 @@
 import datetime
 
 import torch
+import torch.distributed as dist
 
 from .base import LoggerHook
 
@@ -39,9 +40,12 @@ class TextLoggerHook(LoggerHook):
         # statistic memory
         if runner.mode == 'train' and torch.cuda.is_available():
             mem = torch.cuda.max_memory_allocated()
-            mem_mb = int(mem / (1024 * 1024))
-            mem_str = 'memory: {}, '.format(mem_mb)
-            log_str += mem_str
+            mem_mb = torch.tensor([mem / (1024 * 1024)],
+                                  dtype=torch.int,
+                                  device=torch.device('cuda'))
+            if runner.world_size > 1:
+                dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+            log_str += 'memory: {}, '.format(mem_mb.item())
         log_items = []
         for name, val in runner.log_buffer.output.items():
             if name in ['time', 'data_time']:
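
The substance of the change: the per-GPU peak memory from torch.cuda.max_memory_allocated() is wrapped in a CUDA tensor so it can be reduced across ranks with dist.reduce and ReduceOp.MAX, letting rank 0 log the worst-case memory over all workers instead of only its own. Below is a minimal sketch of the same pattern outside the hook, assuming an already-initialized torch.distributed process group and one CUDA device per process; log_max_memory_mb and its world_size argument are illustrative names, not mmcv API.

import torch
import torch.distributed as dist


def log_max_memory_mb(world_size):
    # Peak bytes ever allocated by this process's CUDA caching allocator.
    mem = torch.cuda.max_memory_allocated()
    # Wrap the value in a CUDA tensor so the NCCL backend can reduce it.
    mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
                          dtype=torch.int,
                          device=torch.device('cuda'))
    if world_size > 1:
        # After this call rank 0 holds the max over all ranks; the tensors
        # on the other ranks are not guaranteed to be meaningful.
        dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
    return mem_mb.item()

Since dist.reduce only guarantees the result on the destination rank, the reduced value is meaningful on rank 0, which matches a text logger that reports from the master process.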