diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py index 866016470..d78227299 100644 --- a/mmcv/runner/hooks/logger/text.py +++ b/mmcv/runner/hooks/logger/text.py @@ -50,10 +50,11 @@ class TextLoggerHook(LoggerHook): self._dump_log(runner.meta, runner) def _get_max_memory(self, runner): - mem = torch.cuda.max_memory_allocated() + device = runner.model.output_device + mem = torch.cuda.max_memory_allocated(device=device) mem_mb = torch.tensor([mem / (1024 * 1024)], dtype=torch.int, - device=torch.device('cuda')) + device=device) if runner.world_size > 1: dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX) return mem_mb.item()