mirror of https://github.com/open-mmlab/mmcv.git
fix mem cal bug (#398)
parent
d5cbf7eed1
commit
d678986289
|
@ -50,10 +50,11 @@ class TextLoggerHook(LoggerHook):
|
|||
self._dump_log(runner.meta, runner)
|
||||
|
||||
def _get_max_memory(self, runner):
|
||||
mem = torch.cuda.max_memory_allocated()
|
||||
device = runner.model.output_device
|
||||
mem = torch.cuda.max_memory_allocated(device=device)
|
||||
mem_mb = torch.tensor([mem / (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=torch.device('cuda'))
|
||||
device=device)
|
||||
if runner.world_size > 1:
|
||||
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
|
||||
return mem_mb.item()
|
||||
|
|
Loading…
Reference in New Issue