mirror of https://github.com/facebookresearch/deit
Add comments in the code
parent 806fb71d37
commit ae4dba9b45
@@ -53,10 +53,16 @@ class DistillationLoss(torch.nn.Module):
             # with slight modifications
             distillation_loss = F.kl_div(
                 F.log_softmax(outputs_kd / T, dim=1),
+                # We provide the teacher's targets as log probabilities because we use log_target=True
+                # (as recommended in PyTorch: https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719),
+                # but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
                 F.log_softmax(teacher_outputs / T, dim=1),
                 reduction='sum',
                 log_target=True
             ) * (T * T) / outputs_kd.numel()
+            # We divide by outputs_kd.numel() to keep the legacy PyTorch behavior,
+            # but we also experimented with dividing by outputs_kd.size(0);
+            # see issue #61 (https://github.com/facebookresearch/deit/issues/61) for more details.
         elif self.distillation_type == 'hard':
             distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
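Both points made in the added comments can be verified in isolation: F.kl_div accepts the teacher targets either as log probabilities (log_target=True) or as plain probabilities (log_target=False), and the two normalizations discussed (dividing the summed loss by outputs_kd.numel() versus outputs_kd.size(0)) correspond to the built-in reduction='mean' and reduction='batchmean' modes. Below is a minimal, self-contained sketch of these checks; it is not part of the commit, and the batch size, class count, and temperature are illustrative assumptions.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
outputs_kd = torch.randn(4, 10)       # hypothetical student logits (batch=4, classes=10)
teacher_outputs = torch.randn(4, 10)  # hypothetical teacher logits
T = 3.0                               # illustrative distillation temperature

log_p_student = F.log_softmax(outputs_kd / T, dim=1)
log_p_teacher = F.log_softmax(teacher_outputs / T, dim=1)

# Teacher targets passed as log probabilities, as in the commit.
loss_sum = F.kl_div(log_p_student, log_p_teacher, reduction='sum', log_target=True)
# The same quantity with plain probabilities and log_target=False.
loss_sum_prob = F.kl_div(log_p_student, log_p_teacher.exp(), reduction='sum', log_target=False)
assert torch.allclose(loss_sum, loss_sum_prob, atol=1e-6)

# Normalization kept in the commit: divide the summed loss by batch * classes,
# which is what the legacy reduction='mean' does.
per_element = loss_sum / outputs_kd.numel()
assert torch.allclose(per_element,
                      F.kl_div(log_p_student, log_p_teacher, reduction='mean', log_target=True))

# Alternative discussed in issue 61: divide by the batch size only,
# which is what reduction='batchmean' does.
per_sample = loss_sum / outputs_kd.size(0)
assert torch.allclose(per_sample,
                      F.kl_div(log_p_student, log_p_teacher, reduction='batchmean', log_target=True))

PyTorch documents 'batchmean' as the reduction that matches the mathematical definition of KL divergence per sample; the commit keeps the per-element normalization for backward compatibility, as discussed in issue 61.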