style(backbone): make parameters loading logging more elegant

pull/49/head
liaoxingyu 2020-05-08 12:22:06 +08:00
parent 0b15ac4e03
commit 8ab0bc2455
2 changed files with 23 additions and 6 deletions

View File

@ -16,6 +16,8 @@ from fastreid.layers import (
get_norm,
)
from fastreid.utils.checkpoint import get_unexpected_parameters_message, get_missing_parameters_message
from .build import BACKBONE_REGISTRY
_url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth'
@ -396,8 +398,14 @@ def build_resnest_backbone(cfg):
# if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
# new_state_dict[new_k] = state_dict[k]
# state_dict = new_state_dict
res = model.load_state_dict(state_dict, strict=False)
incompatible = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger(__name__)
logger.info('missing keys is {}'.format(res.missing_keys))
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
if incompatible.missing_keys:
logger.info(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
return model

View File

@ -18,6 +18,8 @@ from fastreid.layers import (
get_norm,
)
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY
model_urls = {
@ -229,8 +231,15 @@ def build_resnet_backbone(cfg):
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict
res = model.load_state_dict(state_dict, strict=False)
incompatible = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger(__name__)
logger.info('missing keys is {}'.format(res.missing_keys))
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
if incompatible.missing_keys:
logger.info(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
return model