fix bug that stuck at evaluation (#53)

* fix bug that stuck at evaluation

* remove mode variable
pull/54/head
Cao Yuhang 2019-04-15 13:28:05 +08:00 committed by Kai Chen
parent bef7c13a95
commit b85136772d
1 changed files with 12 additions and 7 deletions

View File

@ -16,6 +16,15 @@ class TextLoggerHook(LoggerHook):
super(TextLoggerHook, self).before_run(runner)
self.start_iter = runner.iter
def _get_max_memory(self, runner):
mem = torch.cuda.max_memory_allocated()
mem_mb = torch.tensor([mem / (1024 * 1024)],
dtype=torch.int,
device=torch.device('cuda'))
if runner.world_size > 1:
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
return mem_mb
def log(self, runner):
if runner.mode == 'train':
lr_str = ', '.join(
@ -38,13 +47,9 @@ class TextLoggerHook(LoggerHook):
'time: {log[time]:.3f}, data_time: {log[data_time]:.3f}, '.
format(log=runner.log_buffer.output))
# statistic memory
if runner.mode == 'train' and torch.cuda.is_available():
mem = torch.cuda.max_memory_allocated()
mem_mb = torch.tensor([mem / (1024 * 1024)],
dtype=torch.int,
device=torch.device('cuda'))
if runner.world_size > 1:
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
# training mode if the output contains the key "time"
if 'time' in runner.log_buffer.output and torch.cuda.is_available():
mem_mb = self._get_max_memory(runner)
log_str += 'memory: {}, '.format(mem_mb.item())
log_items = []
for name, val in runner.log_buffer.output.items():