diff --git a/train.py b/train.py index ff11622b..95b4a985 100755 --- a/train.py +++ b/train.py @@ -1042,8 +1042,7 @@ def train_one_epoch( loss = _forward() _backward(loss) - if not args.distributed: - losses_m.update(loss.item() * accum_steps, input.size(0)) + losses_m.update(loss.item() * accum_steps, input.size(0)) update_sample_count += input.size(0) if not need_update: @@ -1068,16 +1067,18 @@ def train_one_epoch( lrl = [param_group['lr'] for param_group in optimizer.param_groups] lr = sum(lrl) / len(lrl) + loss_avg, loss_now = losses_m.avg, losses_m.val if args.distributed: - reduced_loss = utils.reduce_tensor(loss.data, args.world_size) - losses_m.update(reduced_loss.item() * accum_steps, input.size(0)) + # synchronize current step and avg loss, each process keeps its own running avg + loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item() + loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item() update_sample_count *= args.world_size if utils.is_primary(args): _logger.info( f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} ' f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] ' - f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) ' + f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) ' f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s ' f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) ' f'LR: {lr:.3e} ' @@ -1106,7 +1107,12 @@ def train_one_epoch( if hasattr(optimizer, 'sync_lookahead'): optimizer.sync_lookahead() - return OrderedDict([('loss', losses_m.avg)]) + loss_avg = losses_m.avg + if args.distributed: + # synchronize avg loss, each process keeps its own running avg + loss_avg = torch.tensor([loss_avg], device=device, dtype=torch.float32) + loss_avg = utils.reduce_tensor(loss_avg, args.world_size).item() + return OrderedDict([('loss', loss_avg)]) def validate(