mirror of https://github.com/JDAI-CV/fast-reid.git
support loading various pretrained weights
Summary: Support loading a pretrained model from a custom path. With this function, we can load InfoMin weights.
pull/63/head
parent 5d4758125d
commit 5982f90920
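For context, the new code path is exercised by pointing the backbone config at a local checkpoint instead of the torchvision URL. A minimal usage sketch, assuming the usual fastreid config keys MODEL.BACKBONE.PRETRAIN and MODEL.BACKBONE.PRETRAIN_PATH (check fastreid/config/defaults.py for the exact names) and a hypothetical checkpoint path:

    from fastreid.config import get_cfg

    cfg = get_cfg()
    cfg.MODEL.BACKBONE.PRETRAIN = True
    # Hypothetical local file, e.g. an InfoMin checkpoint whose keys start with 'module.encoder.'
    cfg.MODEL.BACKBONE.PRETRAIN_PATH = "/path/to/infomin_checkpoint.pth"
    # build_resnet_backbone(cfg) will then try torch.load(pretrain_path)['model'] first
    # and fall back to the torchvision model zoo if the file or the 'model' key is missing.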
@@ -17,11 +17,10 @@ from fastreid.layers import (
     Non_local,
     get_norm,
 )
 from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
 from .build import BACKBONE_REGISTRY
 logger = logging.getLogger(__name__)
 model_urls = {
     18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
     34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
@@ -44,8 +43,10 @@ class BasicBlock(nn.Module):
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn2 = get_norm(bn_norm, planes, num_splits)
         self.relu = nn.ReLU(inplace=True)
-        if with_se: self.se = SELayer(planes, reduction)
-        else:       self.se = nn.Identity()
+        if with_se:
+            self.se = SELayer(planes, reduction)
+        else:
+            self.se = nn.Identity()
         self.downsample = downsample
         self.stride = stride
@@ -75,16 +76,20 @@ class Bottleneck(nn.Module):
                  stride=1, downsample=None, reduction=16):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
-        if with_ibn: self.bn1 = IBN(planes, bn_norm, num_splits)
-        else:        self.bn1 = get_norm(bn_norm, planes, num_splits)
+        if with_ibn:
+            self.bn1 = IBN(planes, bn_norm, num_splits)
+        else:
+            self.bn1 = get_norm(bn_norm, planes, num_splits)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
         self.bn2 = get_norm(bn_norm, planes, num_splits)
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
         self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
         self.relu = nn.ReLU(inplace=True)
-        if with_se: self.se = SELayer(planes * 4, reduction)
-        else:       self.se = nn.Identity()
+        if with_se:
+            self.se = SELayer(planes * 4, reduction)
+        else:
+            self.se = nn.Identity()
         self.downsample = downsample
         self.stride = stride
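The with_se branches above gate the residual output with a squeeze-and-excitation module. For reference, a generic SE block looks roughly like the following; this is a sketch of the standard technique, not necessarily fastreid's exact SELayer implementation, though the channel and reduction arguments match how it is called above:

    import torch.nn as nn

    class SqueezeExcite(nn.Module):
        # Generic squeeze-and-excitation gate: global-average-pool the feature map,
        # squeeze channels by `reduction`, re-expand, and rescale the input channels.
        def __init__(self, channels, reduction=16):
            super().__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Sequential(
                nn.Linear(channels, channels // reduction, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channels // reduction, channels, bias=False),
                nn.Sigmoid(),
            )

        def forward(self, x):
            b, c, _, _ = x.size()
            y = self.avg_pool(x).view(b, c)
            y = self.fc(y).view(b, c, 1, 1)
            return x * y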
@@ -250,21 +255,31 @@ def build_resnet_backbone(cfg):
                     num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
         if not with_ibn:
-            # original resnet
-            state_dict = model_zoo.load_url(model_urls[depth])
+            try:
+                state_dict = torch.load(pretrain_path)['model']
+                # Remove module.encoder in name
+                new_state_dict = {}
+                for k in state_dict:
+                    new_k = '.'.join(k.split('.')[2:])
+                    if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
+                        new_state_dict[new_k] = state_dict[k]
+                state_dict = new_state_dict
+                logger.info(f"Loading pretrained model from {pretrain_path}")
+            except (FileNotFoundError, KeyError):
+                # original resnet
+                state_dict = model_zoo.load_url(model_urls[depth])
+                logger.info("Loading pretrained model from torchvision")
         else:
-            # ibn resnet
-            state_dict = torch.load(pretrain_path)['state_dict']
-            # remove module in name
+            state_dict = torch.load(pretrain_path)['state_dict']  # ibn-net
+            # Remove module in name
             new_state_dict = {}
             for k in state_dict:
                 new_k = '.'.join(k.split('.')[1:])
                 if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                     new_state_dict[new_k] = state_dict[k]
             state_dict = new_state_dict
-        logger.info(f"Loading pretrained model from {pretrain_path}")
         incompatible = model.load_state_dict(state_dict, strict=False)
         logger = logging.getLogger(__name__)
         if incompatible.missing_keys:
             logger.info(
                 get_missing_parameters_message(incompatible.missing_keys)
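Both remapping loops in the hunk above do the same thing at different prefix depths: strip the leading 'module.encoder.' components (InfoMin-style checkpoints, depth 2) or the leading 'module.' component (IBN-Net checkpoints, depth 1) from every key, and keep only tensors whose shapes match the target model. A standalone sketch of that remapping; the helper name is made up for illustration:

    def remap_pretrained_keys(state_dict, model, prefix_depth):
        """Strip the first `prefix_depth` dot-separated components from each key and
        keep only parameters whose shapes match the target model's state dict."""
        target = model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():
            new_k = '.'.join(k.split('.')[prefix_depth:])
            if new_k in target and target[new_k].shape == v.shape:
                new_state_dict[new_k] = v
        return new_state_dict

    # Example: an InfoMin-style checkpoint key becomes a plain backbone key.
    # 'module.encoder.layer1.0.conv1.weight' -> 'layer1.0.conv1.weight'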