Remove test MESA support, no signal that it's helpful so far

This commit is contained in:
Ross Wightman 2024-02-10 14:38:01 -08:00
parent c7ac37693d
commit 5a58f4d3dc

View File

@ -1005,27 +1005,6 @@ def train_one_epoch(
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
if num_updates / num_updates_total > 0.25:
with torch.no_grad():
output_mesa = model_ema(input)
# loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
# output,
# torch.sigmoid(output_mesa).detach(),
# reduction='none',
# ).mean()
# loss_mesa = loss_fn(
# output, torch.sigmoid(output_mesa).detach())
loss_mesa = torch.nn.functional.kl_div(
(output / 5).log_softmax(-1),
(output_mesa / 5).log_softmax(-1).detach(),
log_target=True,
reduction='none').sum(-1).mean()
loss += 10 * loss_mesa
if accum_steps > 1:
loss /= accum_steps
return loss