diff --git a/train.py b/train.py
index 39e889ed..539dff3d 100755
--- a/train.py
+++ b/train.py
@@ -1005,27 +1005,6 @@ def train_one_epoch(
             with amp_autocast():
                 output = model(input)
                 loss = loss_fn(output, target)
-
-                if num_updates / num_updates_total > 0.25:
-                    with torch.no_grad():
-                        output_mesa = model_ema(input)
-
-                    # loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
-                    #     output,
-                    #     torch.sigmoid(output_mesa).detach(),
-                    #     reduction='none',
-                    # ).mean()
-
-                    # loss_mesa = loss_fn(
-                    #     output, torch.sigmoid(output_mesa).detach())
-
-                    loss_mesa = torch.nn.functional.kl_div(
-                        (output / 5).log_softmax(-1),
-                        (output_mesa / 5).log_softmax(-1).detach(),
-                        log_target=True,
-                        reduction='none').sum(-1).mean()
-                    loss += 10 * loss_mesa
-
             if accum_steps > 1:
                 loss /= accum_steps
             return loss