diff --git a/ppcls/loss/emlloss.py b/ppcls/loss/emlloss.py index 973570389..38b707fe1 100644 --- a/ppcls/loss/emlloss.py +++ b/ppcls/loss/emlloss.py @@ -23,6 +23,11 @@ from .comfunc import rerange_index class EmlLoss(paddle.nn.Layer): + """Ensemble Metric Learning Loss + paper: [Large Scale Strongly Supervised Ensemble Metric Learning, with Applications to Face Verification and Retrieval](https://arxiv.org/pdf/1212.6094.pdf) + code reference: https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/metric_learning/losses/emlloss.py + """ + def __init__(self, batch_size=40, samples_each_class=2): super(EmlLoss, self).__init__() assert (batch_size % samples_each_class == 0)