mirror of https://github.com/JDAI-CV/fast-reid.git
34 lines
829 B
Python
34 lines
829 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: l1aoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
from ...utils.registry import Registry
|
|
|
|
# Global registry of loss factories, keyed by the names listed in
# cfg.MODEL.LOSSES.NAME (see build_criterion below).
LOSS_REGISTRY = Registry("LOSS")
LOSS_REGISTRY.__doc__ = """
Registry for loss functions.

The registered object must be a callable that accepts a single argument,
``cfg`` (the config object), and returns a loss callable. ``build_criterion``
calls each loss callable with the criterion's positional arguments and merges
the mapping it returns into a single loss dict.
"""
|
|
|
|
|
|
def build_criterion(cfg):
    """
    Build a combined loss criterion from `cfg.MODEL.LOSSES.NAME`.

    Each name in `cfg.MODEL.LOSSES.NAME` is looked up in LOSS_REGISTRY, and the
    registered factory is called with `cfg` to produce a loss callable. A single
    name given as a plain string is treated as a one-element sequence.

    Args:
        cfg: config object. Only `cfg.MODEL.LOSSES.NAME` is read here; the whole
            object is forwarded to every registered loss factory.

    Returns:
        A function `criterion(*args)` that calls every loss callable with the
        same positional arguments and returns a dict merging their outputs.
        If two losses produce the same key, the later one (in
        `cfg.MODEL.LOSSES.NAME` order) overwrites the earlier.
    """
    loss_names = cfg.MODEL.LOSSES.NAME
    # A bare string would otherwise be iterated character by character,
    # looking up single letters in the registry.
    if isinstance(loss_names, str):
        loss_names = (loss_names,)

    # Instantiate the losses once at build time, not on every criterion call.
    loss_funcs = [LOSS_REGISTRY.get(loss_name)(cfg) for loss_name in loss_names]

    def criterion(*args):
        loss_dict = {}
        for loss_func in loss_funcs:
            # NOTE: dict.update means a later loss silently overwrites an
            # earlier one that returns the same key.
            loss_dict.update(loss_func(*args))
        return loss_dict

    return criterion
|