mirror of https://github.com/open-mmlab/mmcv.git
fix bug that stuck at evaluation (#53)
* fix bug that stuck at evaluation
* remove mode variable
pull/54/head
parent
bef7c13a95
commit
b85136772d
|
@ -16,6 +16,15 @@ class TextLoggerHook(LoggerHook):
|
|||
super(TextLoggerHook, self).before_run(runner)
|
||||
self.start_iter = runner.iter
|
||||
|
||||
def _get_max_memory(self, runner):
    """Return the peak allocated CUDA memory in MB as a one-element tensor.

    The peak byte count comes from ``torch.cuda.max_memory_allocated()`` and
    is divided by 1024 * 1024 before being stored in an int tensor on the
    CUDA device.

    Args:
        runner: Object exposing ``world_size``; when it is greater than 1
            the tensor is MAX-reduced with destination 0 via
            ``dist.reduce``.

    Returns:
        The one-element int tensor holding the memory figure. Callers read
        the scalar with ``.item()``.
    """
    bytes_per_mb = 1024 * 1024
    peak_bytes = torch.cuda.max_memory_allocated()
    peak_mb = torch.tensor(
        [peak_bytes / bytes_per_mb],
        dtype=torch.int,
        device=torch.device('cuda'),
    )
    multi_process = runner.world_size > 1
    if multi_process:
        # Collective call: combine the per-process peaks with MAX at dst 0.
        dist.reduce(peak_mb, 0, op=dist.ReduceOp.MAX)
    return peak_mb
|
||||
|
||||
def log(self, runner):
|
||||
if runner.mode == 'train':
|
||||
lr_str = ', '.join(
|
||||
|
@ -38,13 +47,9 @@ class TextLoggerHook(LoggerHook):
|
|||
'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
|
||||
format(log=runner.log_buffer.output))
|
||||
# statistic memory
|
||||
if runner.mode == 'train' and torch.cuda.is_available():
|
||||
mem = torch.cuda.max_memory_allocated()
|
||||
mem_mb = torch.tensor([mem / (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=torch.device('cuda'))
|
||||
if runner.world_size > 1:
|
||||
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
|
||||
# training mode if the output contains the key "time"
|
||||
if 'time' in runner.log_buffer.output and torch.cuda.is_available():
|
||||
mem_mb = self._get_max_memory(runner)
|
||||
log_str += 'memory: {}, '.format(mem_mb.item())
|
||||
log_items = []
|
||||
for name, val in runner.log_buffer.output.items():
|
||||
|
|
Loading…
Reference in New Issue