mirror of https://github.com/JDAI-CV/fast-reid.git
style(backbone): make parameters loading logging more elegant
parent
0b15ac4e03
commit
8ab0bc2455
|
@ -16,6 +16,8 @@ from fastreid.layers import (
|
|||
get_norm,
|
||||
)
|
||||
|
||||
from fastreid.utils.checkpoint import get_unexpected_parameters_message, get_missing_parameters_message
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
|
||||
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
|
||||
|
@ -396,8 +398,14 @@ def build_resnest_backbone(cfg):
|
|||
# if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
|
||||
# new_state_dict[new_k] = state_dict[k]
|
||||
# state_dict = new_state_dict
|
||||
res = model.load_state_dict(state_dict, strict=False)
|
||||
incompatible = model.load_state_dict(state_dict, strict=False)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info('missing keys is {}'.format(res.missing_keys))
|
||||
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
|
||||
if incompatible.missing_keys:
|
||||
logger.info(
|
||||
get_missing_parameters_message(incompatible.missing_keys)
|
||||
)
|
||||
if incompatible.unexpected_keys:
|
||||
logger.info(
|
||||
get_unexpected_parameters_message(incompatible.unexpected_keys)
|
||||
)
|
||||
return model
|
||||
|
|
|
@ -18,6 +18,8 @@ from fastreid.layers import (
|
|||
get_norm,
|
||||
)
|
||||
|
||||
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
|
||||
model_urls = {
|
||||
|
@ -229,8 +231,15 @@ def build_resnet_backbone(cfg):
|
|||
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
|
||||
new_state_dict[new_k] = state_dict[k]
|
||||
state_dict = new_state_dict
|
||||
res = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
incompatible = model.load_state_dict(state_dict, strict=False)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info('missing keys is {}'.format(res.missing_keys))
|
||||
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
|
||||
if incompatible.missing_keys:
|
||||
logger.info(
|
||||
get_missing_parameters_message(incompatible.missing_keys)
|
||||
)
|
||||
if incompatible.unexpected_keys:
|
||||
logger.info(
|
||||
get_unexpected_parameters_message(incompatible.unexpected_keys)
|
||||
)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue