[Fix] Fix dist training infinite waiting issue (#1035)
* [#1034] fix dist training infinite waiting issue
* print log_vars keys in assertion msg
* linting issue
parent a3574192c8
commit f8ed148fb4
@@ -188,6 +188,17 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
         loss = sum(_value for _key, _value in log_vars.items()
                    if 'loss' in _key)
 
+        # If the loss_vars has different length, raise assertion error
+        # to prevent GPUs from infinite waiting.
+        if dist.is_available() and dist.is_initialized():
+            log_var_length = torch.tensor(len(log_vars), device=loss.device)
+            dist.all_reduce(log_var_length)
+            message = (f'rank {dist.get_rank()}' +
+                       f' len(log_vars): {len(log_vars)}' + ' keys: ' +
+                       ','.join(log_vars.keys()) + '\n')
+            assert log_var_length == len(log_vars) * dist.get_world_size(), \
+                'loss log variables are different across GPUs!\n' + message
+
         log_vars['loss'] = loss
         for loss_name, loss_value in log_vars.items():
             # reduce loss when distributed training
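For context, the added check can be read as a standalone helper. The sketch below is not part of the commit: the function name check_log_vars_consistent and its device argument are hypothetical, and it assumes torch.distributed has already been initialized by the training launcher (e.g. via dist.init_process_group). The all_reduce/assert logic mirrors the diff above.

import torch
import torch.distributed as dist


def check_log_vars_consistent(log_vars, device):
    """Assert that every rank logs the same number of loss keys.

    If one GPU produces an extra or missing entry in ``log_vars``, the
    subsequent per-key ``dist.all_reduce`` calls are mismatched and all
    ranks block forever, so fail fast with a readable message instead.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return  # single-GPU / CPU runs need no cross-rank check
    # Sum the per-rank key counts; consistent ranks sum to len * world_size.
    log_var_length = torch.tensor(len(log_vars), device=device)
    dist.all_reduce(log_var_length)
    message = (f'rank {dist.get_rank()} len(log_vars): {len(log_vars)} '
               'keys: ' + ','.join(log_vars.keys()) + '\n')
    assert log_var_length == len(log_vars) * dist.get_world_size(), \
        'loss log variables are different across GPUs!\n' + message

Summing the counts with a single all_reduce is collective and cheap, so when the counts disagree every rank hits the assertion before any of the later per-key loss reductions can deadlock, and each rank prints its own key list so the divergent loss term can be spotted in the logs.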