fix mem cal bug (#398)

pull/401/head
Cao Yuhang 2020-07-08 19:02:36 +08:00 committed by GitHub
parent d5cbf7eed1
commit d678986289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 2 deletions

View File

@@ -50,10 +50,11 @@ class TextLoggerHook(LoggerHook):
             self._dump_log(runner.meta, runner)
 
     def _get_max_memory(self, runner):
-        mem = torch.cuda.max_memory_allocated()
+        device = runner.model.output_device
+        mem = torch.cuda.max_memory_allocated(device=device)
         mem_mb = torch.tensor([mem / (1024 * 1024)],
                               dtype=torch.int,
-                              device=torch.device('cuda'))
+                              device=device)
         if runner.world_size > 1:
             dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
         return mem_mb.item()