add support for unlabeled training (#2103)
parent 5cc6c50ce6
commit 9cdbdca4ee
ppcls/loss
@@ -236,8 +236,13 @@ class DistillationDKDLoss(DKDLoss):
                  temperature=1.0,
                  alpha=1.0,
                  beta=1.0,
+                 use_target_as_gt=False,
                  name="loss_dkd"):
-        super().__init__(temperature=temperature, alpha=alpha, beta=beta)
+        super().__init__(
+            temperature=temperature,
+            alpha=alpha,
+            beta=beta,
+            use_target_as_gt=use_target_as_gt)
         self.key = key
         self.model_name_pairs = model_name_pairs
         self.name = name
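For context, a hedged usage sketch of the new flag as it threads through the wrapper class; the import path, model names, and key value are assumptions for illustration, not taken from this commit:

# Assumed import path; adjust to where DistillationDKDLoss lives in your
# PaddleClas checkout.
from ppcls.loss.distillationloss import DistillationDKDLoss

# With use_target_as_gt=True the loss ignores dataset labels and treats the
# teacher's argmax prediction as ground truth, which is what enables
# training on unlabeled data.
loss_func = DistillationDKDLoss(
    model_name_pairs=[["Student", "Teacher"]],  # assumed model names
    key="logits",                               # assumed output-dict key
    temperature=1.0,
    alpha=1.0,
    beta=1.0,
    use_target_as_gt=True,
    name="loss_dkd")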
@@ -10,13 +10,20 @@ class DKDLoss(nn.Layer):
     Code was heavily based on https://github.com/megvii-research/mdistiller
     """

-    def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
+    def __init__(self,
+                 temperature=1.0,
+                 alpha=1.0,
+                 beta=1.0,
+                 use_target_as_gt=False):
         super().__init__()
         self.temperature = temperature
         self.alpha = alpha
         self.beta = beta
+        self.use_target_as_gt = use_target_as_gt

-    def forward(self, logits_student, logits_teacher, target):
+    def forward(self, logits_student, logits_teacher, target=None):
+        if target is None or self.use_target_as_gt:
+            target = logits_teacher.argmax(axis=-1)
         gt_mask = _get_gt_mask(logits_student, target)
         other_mask = 1 - gt_mask
         pred_student = F.softmax(logits_student / self.temperature, axis=1)
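The DKDLoss.forward change is the core of the unlabeled-training support: when no target is passed (or use_target_as_gt is set), the teacher's argmax becomes a pseudo ground truth, so the target/non-target decomposition that DKD relies on stays well defined without labels. A minimal standalone Paddle sketch of that path follows; the one-hot mask is a stand-in for the repository's _get_gt_mask helper, whose exact implementation is not shown in this commit:

import paddle
import paddle.nn.functional as F

# Toy batch: 4 samples, 10 classes.
logits_student = paddle.randn([4, 10])
logits_teacher = paddle.randn([4, 10])

# Unlabeled path added by this commit: no ground-truth target is available,
# so the teacher's most confident class serves as a pseudo-label.
target = logits_teacher.argmax(axis=-1)

# Stand-in for _get_gt_mask: one-hot mask over the (pseudo) target class.
gt_mask = F.one_hot(target, num_classes=logits_student.shape[1])
other_mask = 1 - gt_mask

# Student distribution at temperature 1.0, as in the diffed code.
pred_student = F.softmax(logits_student / 1.0, axis=1)
print(gt_mask.sum(axis=1))  # exactly one target slot per sample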