mirror of https://github.com/JDAI-CV/fast-reid.git
Fix minor bug in build criterion, it will replace by multiple call
Refactor resnet pretrainpull/43/head
parent
b020c7f0ae
commit
bab602dfd2
|
@ -167,16 +167,11 @@ def build_resnet_backbone(cfg):
|
||||||
if pretrain:
|
if pretrain:
|
||||||
if not with_ibn:
|
if not with_ibn:
|
||||||
# original resnet
|
# original resnet
|
||||||
|
# state_dict = torch.load(pretrain_path)['model_ema']
|
||||||
state_dict = model_zoo.load_url(model_urls[depth])
|
state_dict = model_zoo.load_url(model_urls[depth])
|
||||||
# remove fully-connected-layers
|
|
||||||
state_dict.pop('fc.weight')
|
|
||||||
state_dict.pop('fc.bias')
|
|
||||||
else:
|
else:
|
||||||
# ibn resnet
|
# ibn resnet
|
||||||
state_dict = torch.load(pretrain_path)['state_dict']
|
state_dict = torch.load(pretrain_path)['state_dict']
|
||||||
# remove fully-connected-layers
|
|
||||||
state_dict.pop('module.fc.weight')
|
|
||||||
state_dict.pop('module.fc.bias')
|
|
||||||
# remove module in name
|
# remove module in name
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
|
|
|
@ -23,9 +23,9 @@ def build_criterion(cfg):
|
||||||
|
|
||||||
loss_names = cfg.MODEL.LOSSES.NAME
|
loss_names = cfg.MODEL.LOSSES.NAME
|
||||||
loss_funcs = [LOSS_REGISTRY.get(loss_name)(cfg) for loss_name in loss_names]
|
loss_funcs = [LOSS_REGISTRY.get(loss_name)(cfg) for loss_name in loss_names]
|
||||||
loss_dict = {}
|
|
||||||
|
|
||||||
def criterion(*args):
|
def criterion(*args):
|
||||||
|
loss_dict = {}
|
||||||
for loss_func in loss_funcs:
|
for loss_func in loss_funcs:
|
||||||
loss = loss_func(*args)
|
loss = loss_func(*args)
|
||||||
loss_dict.update(loss)
|
loss_dict.update(loss)
|
||||||
|
|
Loading…
Reference in New Issue