From b85136772d4a78a3cdc7779301f984ad0dd21235 Mon Sep 17 00:00:00 2001
From: Cao Yuhang
Date: Mon, 15 Apr 2019 13:28:05 +0800
Subject: [PATCH] fix bug that stuck at evaluation (#53)

* fix bug that stuck at evaluation

* remove mode variable
---
 mmcv/runner/hooks/logger/text.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/mmcv/runner/hooks/logger/text.py b/mmcv/runner/hooks/logger/text.py
index 832ffee95..eba0d2a53 100644
--- a/mmcv/runner/hooks/logger/text.py
+++ b/mmcv/runner/hooks/logger/text.py
@@ -16,6 +16,15 @@ class TextLoggerHook(LoggerHook):
         super(TextLoggerHook, self).before_run(runner)
         self.start_iter = runner.iter
 
+    def _get_max_memory(self, runner):
+        mem = torch.cuda.max_memory_allocated()
+        mem_mb = torch.tensor([mem / (1024 * 1024)],
+                              dtype=torch.int,
+                              device=torch.device('cuda'))
+        if runner.world_size > 1:
+            dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+        return mem_mb
+
     def log(self, runner):
         if runner.mode == 'train':
             lr_str = ', '.join(
@@ -38,13 +47,9 @@ class TextLoggerHook(LoggerHook):
                 'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
                 format(log=runner.log_buffer.output))
         # statistic memory
-        if runner.mode == 'train' and torch.cuda.is_available():
-            mem = torch.cuda.max_memory_allocated()
-            mem_mb = torch.tensor([mem / (1024 * 1024)],
-                                  dtype=torch.int,
-                                  device=torch.device('cuda'))
-            if runner.world_size > 1:
-                dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+        # training mode if the output contains the key "time"
+        if 'time' in runner.log_buffer.output and torch.cuda.is_available():
+            mem_mb = self._get_max_memory(runner)
             log_str += 'memory: {}, '.format(mem_mb.item())
         log_items = []
         for name, val in runner.log_buffer.output.items():
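
Note on the fix: the hang this patch addresses is a mismatched collective call. dist.reduce blocks until every rank in the process group joins it, and the removed runner.mode == 'train' guard could apparently hold on some ranks but not others around evaluation, so only a subset of ranks entered the reduce and the job stalled. Keying the guard on 'time' in runner.log_buffer.output (with the collective factored into _get_max_memory) keeps the condition consistent across ranks, so either all ranks enter the reduce or none do. Below is a minimal standalone sketch of that pattern, assuming a hypothetical two-process gloo group on localhost; the run_worker helper, the port, and the fake per-rank values are illustrative, not mmcv code:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    # Hypothetical local process group for demonstration (not mmcv code).
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Stand-in for the per-rank statistic, e.g. peak GPU memory in MB.
    mem_mb = torch.tensor([100 * (rank + 1)], dtype=torch.int)

    # Correct pattern: the condition is identical on every rank, so all
    # ranks reach the collective together and dist.reduce cannot hang.
    # Guarding it with a condition that differs across ranks (as the
    # removed runner.mode == 'train' check could) would deadlock.
    if world_size > 1:
        dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)

    if rank == 0:
        # Rank 0 now holds the element-wise maximum across all ranks.
        print('max memory across ranks: {} MB'.format(mem_mb.item()))
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)

Running this prints "max memory across ranks: 200 MB" on rank 0, since MAX keeps the larger of the two per-rank values; the same reduce with the guard flipped on one rank would hang instead of erroring, which matches the "stuck at evaluation" symptom in the subject line.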