support loading various pretrained weights

Summary: Support loading pretrained model by custom path. With this function, we can load infoMin weights.
pull/63/head
liaoxingyu 2020-05-26 14:33:18 +08:00
parent 5d4758125d
commit 5982f90920
1 changed files with 30 additions and 15 deletions

View File

@ -17,11 +17,10 @@ from fastreid.layers import (
Non_local,
get_norm,
)
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY
logger = logging.getLogger(__name__)
model_urls = {
18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
@ -44,8 +43,10 @@ class BasicBlock(nn.Module):
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = get_norm(bn_norm, planes, num_splits)
self.relu = nn.ReLU(inplace=True)
if with_se: self.se = SELayer(planes, reduction)
else: self.se = nn.Identity()
if with_se:
self.se = SELayer(planes, reduction)
else:
self.se = nn.Identity()
self.downsample = downsample
self.stride = stride
@ -75,16 +76,20 @@ class Bottleneck(nn.Module):
stride=1, downsample=None, reduction=16):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
if with_ibn: self.bn1 = IBN(planes, bn_norm, num_splits)
else: self.bn1 = get_norm(bn_norm, planes, num_splits)
if with_ibn:
self.bn1 = IBN(planes, bn_norm, num_splits)
else:
self.bn1 = get_norm(bn_norm, planes, num_splits)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = get_norm(bn_norm, planes, num_splits)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
self.relu = nn.ReLU(inplace=True)
if with_se: self.se = SELayer(planes * 4, reduction)
else: self.se = nn.Identity()
if with_se:
self.se = SELayer(planes * 4, reduction)
else:
self.se = nn.Identity()
self.downsample = downsample
self.stride = stride
@ -250,21 +255,31 @@ def build_resnet_backbone(cfg):
num_blocks_per_stage, nl_layers_per_stage)
if pretrain:
if not with_ibn:
# original resnet
state_dict = model_zoo.load_url(model_urls[depth])
try:
state_dict = torch.load(pretrain_path)['model']
# Remove module.encoder in name
new_state_dict = {}
for k in state_dict:
new_k = '.'.join(k.split('.')[2:])
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict
logger.info(f"Loading pretrained model from {pretrain_path}")
except FileNotFoundError or KeyError:
# original resnet
state_dict = model_zoo.load_url(model_urls[depth])
logger.info("Loading pretrained model from torchvision")
else:
# ibn resnet
state_dict = torch.load(pretrain_path)['state_dict']
# remove module in name
state_dict = torch.load(pretrain_path)['state_dict'] # ibn-net
# Remove module in name
new_state_dict = {}
for k in state_dict:
new_k = '.'.join(k.split('.')[1:])
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict
logger.info(f"Loading pretrained model from {pretrain_path}")
incompatible = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger(__name__)
if incompatible.missing_keys:
logger.info(
get_missing_parameters_message(incompatible.missing_keys)