mirror of https://github.com/JDAI-CV/fast-reid.git
update pairwise circle loss
Summary: add param of pairwise circle loss to config, and update pairwise circle loss versionpull/150/head
parent
96b9ad2d99
commit
cbdc01a1c3
|
@ -97,6 +97,11 @@ _C.MODEL.LOSSES.TRI.HARD_MINING = True
|
|||
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = False
|
||||
_C.MODEL.LOSSES.TRI.SCALE = 1.0
|
||||
|
||||
# Circle Loss options
|
||||
_C.MODEL.LOSSES.CIRCLE = CN()
|
||||
_C.MODEL.LOSSES.CIRCLE.MARGIN = 0.25
|
||||
_C.MODEL.LOSSES.CIRCLE.SCALE = 128
|
||||
|
||||
# Focal Loss options
|
||||
_C.MODEL.LOSSES.FL = CN()
|
||||
_C.MODEL.LOSSES.FL.ALPHA = 0.25
|
||||
|
|
|
@ -164,10 +164,10 @@ class TripletLoss(object):
|
|||
|
||||
class CircleLoss(object):
|
||||
def __init__(self, cfg):
|
||||
self._scale = cfg.MODEL.LOSSES.SCALE_TRI
|
||||
self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
|
||||
|
||||
self.m = 0.25
|
||||
self.s = 128
|
||||
self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
|
||||
self.s = cfg.MODEL.LOSSES.CIRCLE.SCALE
|
||||
|
||||
def __call__(self, _, global_features, targets):
|
||||
global_features = normalize(global_features, axis=-1)
|
||||
|
|
Loading…
Reference in New Issue