# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
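"""Convert Caffe2/Detectron pickled ResNet weights to maskrcnn_benchmark naming.

Source file mirrored from https://github.com/YifanXu74/MQ-Det.git.

``load_c2_format(cfg, f)`` is the entry point: it loads a Detectron-style ``.pkl``
checkpoint, renames every blob to the matching maskrcnn_benchmark parameter name,
and returns ``dict(model=state_dict)``. Dispatch over backbone variants goes
through the ``C2_FORMAT_LOADER`` registry defined below.
"""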
import logging
import pickle
from collections import OrderedDict

import torch

from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.registry import Registry

def _rename_basic_resnet_weights(layer_keys):
    layer_keys = [k.replace("_", ".") for k in layer_keys]
    layer_keys = [k.replace(".w", ".weight") for k in layer_keys]
    layer_keys = [k.replace(".bn", "_bn") for k in layer_keys]
    layer_keys = [k.replace(".b", ".bias") for k in layer_keys]
    layer_keys = [k.replace("_bn.s", "_bn.scale") for k in layer_keys]
    layer_keys = [k.replace(".biasranch", ".branch") for k in layer_keys]
    layer_keys = [k.replace("bbox.pred", "bbox_pred") for k in layer_keys]
    layer_keys = [k.replace("cls.score", "cls_score") for k in layer_keys]
    layer_keys = [k.replace("res.conv1_", "conv1_") for k in layer_keys]

    # RPN / Faster RCNN
    layer_keys = [k.replace(".biasbox", ".bbox") for k in layer_keys]
    layer_keys = [k.replace("conv.rpn", "rpn.conv") for k in layer_keys]
    layer_keys = [k.replace("rpn.bbox.pred", "rpn.bbox_pred") for k in layer_keys]
    layer_keys = [k.replace("rpn.cls.logits", "rpn.cls_logits") for k in layer_keys]

    # Affine-Channel -> BatchNorm renaming
    layer_keys = [k.replace("_bn.scale", "_bn.weight") for k in layer_keys]

    # Make torchvision-compatible
    layer_keys = [k.replace("conv1_bn.", "bn1.") for k in layer_keys]

    layer_keys = [k.replace("res2.", "layer1.") for k in layer_keys]
    layer_keys = [k.replace("res3.", "layer2.") for k in layer_keys]
    layer_keys = [k.replace("res4.", "layer3.") for k in layer_keys]
    layer_keys = [k.replace("res5.", "layer4.") for k in layer_keys]

    layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
    layer_keys = [k.replace(".branch2a_bn.", ".bn1.") for k in layer_keys]
    layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
    layer_keys = [k.replace(".branch2b_bn.", ".bn2.") for k in layer_keys]
    layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
    layer_keys = [k.replace(".branch2c_bn.", ".bn3.") for k in layer_keys]

    layer_keys = [k.replace(".branch1.", ".downsample.0.") for k in layer_keys]
    layer_keys = [k.replace(".branch1_bn.", ".downsample.1.") for k in layer_keys]

    # GroupNorm
    layer_keys = [k.replace("conv1.gn.s", "bn1.weight") for k in layer_keys]
    layer_keys = [k.replace("conv1.gn.bias", "bn1.bias") for k in layer_keys]
    layer_keys = [k.replace("conv2.gn.s", "bn2.weight") for k in layer_keys]
    layer_keys = [k.replace("conv2.gn.bias", "bn2.bias") for k in layer_keys]
    layer_keys = [k.replace("conv3.gn.s", "bn3.weight") for k in layer_keys]
    layer_keys = [k.replace("conv3.gn.bias", "bn3.bias") for k in layer_keys]
    layer_keys = [k.replace("downsample.0.gn.s", "downsample.1.weight")
                  for k in layer_keys]
    layer_keys = [k.replace("downsample.0.gn.bias", "downsample.1.bias")
                  for k in layer_keys]

    return layer_keys
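
# Example mappings produced by _rename_basic_resnet_weights, traced through the
# replacements above (illustrative Detectron-style blob names):
#   "res2_0_branch2a_w"    -> "layer1.0.conv1.weight"
#   "res2_0_branch2a_bn_s" -> "layer1.0.bn1.weight"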

def _rename_fpn_weights(layer_keys, stage_names):
    for mapped_idx, stage_name in enumerate(stage_names, 1):
        suffix = ""
        if mapped_idx < 4:
            suffix = ".lateral"
        layer_keys = [
            k.replace(
                "fpn.inner.layer{}.sum{}".format(stage_name, suffix),
                "fpn_inner{}".format(mapped_idx),
            )
            for k in layer_keys
        ]
        layer_keys = [
            k.replace(
                "fpn.layer{}.sum".format(stage_name),
                "fpn_layer{}".format(mapped_idx),
            )
            for k in layer_keys
        ]

    layer_keys = [k.replace("rpn.conv.fpn2", "rpn.conv") for k in layer_keys]
    layer_keys = [k.replace("rpn.bbox_pred.fpn2", "rpn.bbox_pred") for k in layer_keys]
    layer_keys = [
        k.replace("rpn.cls_logits.fpn2", "rpn.cls_logits") for k in layer_keys
    ]

    return layer_keys
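
# Example FPN mappings (basic renaming followed by _rename_fpn_weights, using the
# R-50 stage names defined in _C2_STAGE_NAMES below):
#   "fpn_inner_res2_2_sum_lateral_w" -> "fpn_inner1.weight"
#   "fpn_res2_2_sum_w"               -> "fpn_layer1.weight"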

def _rename_weights_for_resnet(weights, stage_names):
    """Build an OrderedDict of torch tensors keyed by maskrcnn_benchmark names."""
    original_keys = sorted(weights.keys())
    layer_keys = sorted(weights.keys())

    # for X-101, rename output to fc1000 to avoid conflicts afterwards
    layer_keys = [k if k != "pred_b" else "fc1000_b" for k in layer_keys]
    layer_keys = [k if k != "pred_w" else "fc1000_w" for k in layer_keys]

    # performs basic renaming: _ -> . , etc
    layer_keys = _rename_basic_resnet_weights(layer_keys)

    # FPN
    layer_keys = _rename_fpn_weights(layer_keys, stage_names)

    # Mask R-CNN
    layer_keys = [k.replace("mask.fcn.logits", "mask_fcn_logits") for k in layer_keys]
    layer_keys = [k.replace(".[mask].fcn", "mask_fcn") for k in layer_keys]
    layer_keys = [k.replace("conv5.mask", "conv5_mask") for k in layer_keys]

    # Keypoint R-CNN
    layer_keys = [k.replace("kps.score.lowres", "kps_score_lowres") for k in layer_keys]
    layer_keys = [k.replace("kps.score", "kps_score") for k in layer_keys]
    layer_keys = [k.replace("conv.fcn", "conv_fcn") for k in layer_keys]

    # Rename for our RPN structure
    layer_keys = [k.replace("rpn.", "rpn.head.") for k in layer_keys]

    key_map = {k: v for k, v in zip(original_keys, layer_keys)}

    logger = logging.getLogger(__name__)
    logger.info("Remapping C2 weights")
    max_c2_key_size = max([len(k) for k in original_keys if "_momentum" not in k])

    new_weights = OrderedDict()
    for k in original_keys:
        v = weights[k]
        if "_momentum" in k:
            continue
        if 'weight_order' in k:
            continue
        # if 'fc1000' in k:
        #     continue
        w = torch.from_numpy(v)
        # if "bn" in k:
        #     w = w.view(1, -1, 1, 1)
        logger.info("C2 name: {: <{}} mapped name: {}".format(k, max_c2_key_size, key_map[k]))
        new_weights[key_map[k]] = w

    return new_weights

def _load_c2_pickled_weights(file_path):
    """Load a Detectron/Caffe2 pickled checkpoint and return its blob dict."""
    with open(file_path, "rb") as f:
        # Detectron checkpoints were pickled under Python 2; "latin1" lets
        # Python 3 decode the embedded numpy arrays. (The original Python 2
        # branch relied on torch._six, which recent PyTorch no longer ships.)
        data = pickle.load(f, encoding="latin1")
    if "blobs" in data:
        weights = data["blobs"]
    else:
        weights = data
    return weights

def _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg):
    """Rename conv2 parameters to conv2.conv for stages configured with DCN."""
    import re
    logger = logging.getLogger(__name__)
    logger.info("Remapping conv weights for deformable conv weights")
    layer_keys = sorted(state_dict.keys())
    for ix, stage_with_dcn in enumerate(cfg.MODEL.RESNETS.STAGE_WITH_DCN, 1):
        if not stage_with_dcn:
            continue
        for old_key in layer_keys:
            pattern = ".*layer{}.*conv2.*".format(ix)
            r = re.match(pattern, old_key)
            if r is None:
                continue
            for param in ["weight", "bias"]:
                if param not in old_key:
                    continue
                new_key = old_key.replace(
                    "conv2.{}".format(param), "conv2.conv.{}".format(param)
                )
                logger.info("pattern: {}, old_key: {}, new_key: {}".format(
                    pattern, old_key, new_key
                ))
                state_dict[new_key] = state_dict[old_key]
                del state_dict[old_key]
    return state_dict
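
# Per-architecture suffixes of the last residual block in each ResNet stage, as the
# names appear after basic renaming (e.g. "1.2" -> "layer1.2"). These are the blocks
# to which Detectron attaches the FPN lateral/output sums, so _rename_fpn_weights
# uses them to locate keys such as "fpn.inner.layer1.2.sum.lateral.weight".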
_C2_STAGE_NAMES = {
    "R-50": ["1.2", "2.3", "3.5", "4.2"],
    "R-101": ["1.2", "2.3", "3.22", "4.2"],
}

C2_FORMAT_LOADER = Registry()

@C2_FORMAT_LOADER.register("R-50-C4")
@C2_FORMAT_LOADER.register("R-50-C5")
@C2_FORMAT_LOADER.register("R-101-C4")
@C2_FORMAT_LOADER.register("R-101-C5")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-50-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-50-FPN-FCOS")
@C2_FORMAT_LOADER.register("R-101-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN-RETINANET")
@C2_FORMAT_LOADER.register("R-101-FPN-FCOS")
def load_resnet_c2_format(cfg, f):
    state_dict = _load_c2_pickled_weights(f)
    conv_body = cfg.MODEL.BACKBONE.CONV_BODY
    arch = (
        conv_body.replace("-C4", "")
        .replace("-C5", "")
        .replace("-FPN", "")
        .replace("-RETINANET", "")
        .replace("-FCOS", "")
    )
    stages = _C2_STAGE_NAMES[arch]
    state_dict = _rename_weights_for_resnet(state_dict, stages)
    # ***********************************
    # for deformable convolutional layer
    state_dict = _rename_conv_weights_for_deformable_conv_layers(state_dict, cfg)
    # ***********************************
    return dict(model=state_dict)

def load_c2_format(cfg, f):
    return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f)
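
# Usage sketch (assumptions: "path/to/config.yaml" and "path/to/R-50.pkl" are
# hypothetical paths to a maskrcnn_benchmark config and a Detectron-style pickled
# checkpoint; in practice this loader is normally invoked by the checkpointer when
# a ".pkl" weight file is given, rather than called by hand):
#
#     from maskrcnn_benchmark.config import cfg
#     cfg.merge_from_file("path/to/config.yaml")
#     checkpoint = load_c2_format(cfg, "path/to/R-50.pkl")
#     # checkpoint["model"] maps maskrcnn_benchmark parameter names to torch tensors.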