mirror of https://github.com/JDAI-CV/fast-reid.git
Fix minor bug in build criterion, it will replace by multiple call
Refactor resnet pretrainpull/43/head
parent
b020c7f0ae
commit
bab602dfd2
|
@ -167,16 +167,11 @@ def build_resnet_backbone(cfg):
|
|||
if pretrain:
|
||||
if not with_ibn:
|
||||
# original resnet
|
||||
# state_dict = torch.load(pretrain_path)['model_ema']
|
||||
state_dict = model_zoo.load_url(model_urls[depth])
|
||||
# remove fully-connected-layers
|
||||
state_dict.pop('fc.weight')
|
||||
state_dict.pop('fc.bias')
|
||||
else:
|
||||
# ibn resnet
|
||||
state_dict = torch.load(pretrain_path)['state_dict']
|
||||
# remove fully-connected-layers
|
||||
state_dict.pop('module.fc.weight')
|
||||
state_dict.pop('module.fc.bias')
|
||||
# remove module in name
|
||||
new_state_dict = {}
|
||||
for k in state_dict:
|
||||
|
|
|
@ -23,9 +23,9 @@ def build_criterion(cfg):
|
|||
|
||||
loss_names = cfg.MODEL.LOSSES.NAME
|
||||
loss_funcs = [LOSS_REGISTRY.get(loss_name)(cfg) for loss_name in loss_names]
|
||||
loss_dict = {}
|
||||
|
||||
def criterion(*args):
|
||||
loss_dict = {}
|
||||
for loss_func in loss_funcs:
|
||||
loss = loss_func(*args)
|
||||
loss_dict.update(loss)
|
||||
|
|
Loading…
Reference in New Issue