mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove test MESA support, no signal that it's helpful so far
This commit is contained in:
parent
c7ac37693d
commit
5a58f4d3dc
21
train.py
21
train.py
@ -1005,27 +1005,6 @@ def train_one_epoch(
|
||||
with amp_autocast():
|
||||
output = model(input)
|
||||
loss = loss_fn(output, target)
|
||||
|
||||
if num_updates / num_updates_total > 0.25:
|
||||
with torch.no_grad():
|
||||
output_mesa = model_ema(input)
|
||||
|
||||
# loss_mesa = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
# output,
|
||||
# torch.sigmoid(output_mesa).detach(),
|
||||
# reduction='none',
|
||||
# ).mean()
|
||||
|
||||
# loss_mesa = loss_fn(
|
||||
# output, torch.sigmoid(output_mesa).detach())
|
||||
|
||||
loss_mesa = torch.nn.functional.kl_div(
|
||||
(output / 5).log_softmax(-1),
|
||||
(output_mesa / 5).log_softmax(-1).detach(),
|
||||
log_target=True,
|
||||
reduction='none').sum(-1).mean()
|
||||
loss += 10 * loss_mesa
|
||||
|
||||
if accum_steps > 1:
|
||||
loss /= accum_steps
|
||||
return loss
|
||||
|
Loading…
x
Reference in New Issue
Block a user