mirror of https://github.com/JDAI-CV/fast-reid.git
282 lines
9.6 KiB
Python
282 lines
9.6 KiB
Python
|
# !/usr/bin/env python3
|
||
|
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
#
|
||
|
# This source code is licensed under the MIT license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
"""EfficientNet models."""
|
||
|
|
||
|
import logging
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from fastreid.layers import *
|
||
|
from fastreid.modeling.backbones.build import BACKBONE_REGISTRY
|
||
|
from fastreid.utils import comm
|
||
|
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||
|
from .config import cfg as effnet_cfg
|
||
|
from .regnet import drop_connect, init_weights
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
# URLs of ImageNet-pretrained EfficientNet-B0..B7 checkpoints released with
# Facebook's pycls model zoo; keys match the `depth` config value ('b0'..'b7').
model_urls = {
    'b0': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305613/EN-B0_dds_8gpu.pyth',
    'b1': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B1_dds_8gpu.pyth',
    'b2': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B2_dds_8gpu.pyth',
    'b3': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B3_dds_8gpu.pyth',
    'b4': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305098/EN-B4_dds_8gpu.pyth',
    'b5': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B5_dds_8gpu.pyth',
    'b6': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B6_dds_8gpu.pyth',
    'b7': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B7_dds_8gpu.pyth',
}
|
||
|
|
||
|
|
||
|
class EffHead(nn.Module):
    """EfficientNet head as used by this backbone: 1x1 conv, BN, Swish.

    NOTE(review): the upstream classification head also applies AvgPool,
    Dropout and FC after the activation; this variant stops at the Swish,
    presumably so a downstream head can consume the spatial feature map —
    confirm against the callers.
    """

    def __init__(self, w_in, w_out, bn_norm):
        super(EffHead, self).__init__()
        # 1x1 conv maps the last stage's width w_in to the head width w_out.
        self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
        self.conv_bn = get_norm(bn_norm, w_out)
        self.conv_swish = Swish()

    def forward(self, x):
        x = self.conv_swish(self.conv_bn(self.conv(x)))
        return x
|
||
|
|
||
|
|
||
|
class Swish(nn.Module):
    """Swish activation function: ``f(x) = x * sigmoid(x)``."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        # Self-gating: the sigmoid of the input scales the input itself.
        gate = torch.sigmoid(x)
        return gate * x
|
||
|
|
||
|
|
||
|
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        # "Squeeze": global spatial average pooling to a 1x1 map per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # "Excitation": bottleneck to w_se channels, expand back to w_in,
        # then squash to a per-channel gate in (0, 1).
        excitation = [
            nn.Conv2d(w_in, w_se, 1, bias=True),
            Swish(),
            nn.Conv2d(w_se, w_in, 1, bias=True),
            nn.Sigmoid(),
        ]
        self.f_ex = nn.Sequential(*excitation)

    def forward(self, x):
        gate = self.f_ex(self.avg_pool(x))
        # Rescale each channel of the input by its learned gate.
        return x * gate
|
||
|
|
||
|
|
||
|
class MBConv(nn.Module):
    """Mobile inverted bottleneck block w/ SE (MBConv).

    Pipeline: 1x1 expansion, kxk depthwise conv, BN, Swish, SE, 1x1 linear
    projection, BN, and an optional identity skip connection.
    """

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm):
        super(MBConv, self).__init__()
        w_exp = int(w_in * exp_r)
        # The pointwise expansion is only created when the expansion ratio
        # actually changes the channel width.
        self.exp = None
        if w_in != w_exp:
            self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
            self.exp_bn = get_norm(bn_norm, w_exp)
            self.exp_swish = Swish()
        # Depthwise conv: one group per channel, "same" padding for odd kernels.
        pad = (kernel - 1) // 2
        self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride,
                               groups=w_exp, padding=pad, bias=False)
        self.dwise_bn = get_norm(bn_norm, w_exp)
        self.dwise_swish = Swish()
        # SE reduction width is derived from the block INPUT width, not w_exp.
        self.se = SE(w_exp, int(w_in * se_r))
        # Linear (no activation) projection back down to w_out.
        self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
        self.lin_proj_bn = get_norm(bn_norm, w_out)
        # Skip connection only when spatial size and width are preserved (MN-V2 style).
        self.has_skip = stride == 1 and w_in == w_out

    def forward(self, x):
        out = x
        if self.exp is not None:
            out = self.exp_swish(self.exp_bn(self.exp(out)))
        out = self.dwise_swish(self.dwise_bn(self.dwise(out)))
        out = self.se(out)
        out = self.lin_proj_bn(self.lin_proj(out))
        if self.has_skip:
            # Drop-connect regularization applies only during training.
            if self.training and effnet_cfg.EN.DC_RATIO > 0.0:
                out = drop_connect(out, effnet_cfg.EN.DC_RATIO)
            out = x + out
        return out
|
||
|
|
||
|
|
||
|
class EffStage(nn.Module):
    """EfficientNet stage: a sequence of ``d`` MBConv blocks named b1..bd."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm):
        super(EffStage, self).__init__()
        for idx in range(d):
            # Only the first block applies the stage stride and consumes w_in;
            # the rest run at stride 1 on w_out channels.
            block_stride = stride if idx == 0 else 1
            block_w_in = w_in if idx == 0 else w_out
            self.add_module(
                "b{}".format(idx + 1),
                MBConv(block_w_in, exp_r, kernel, block_stride, se_r, w_out, bn_norm),
            )

    def forward(self, x):
        # Children iterate in registration order, i.e. b1, b2, ..., bd.
        for block in self.children():
            x = block(x)
        return x
|
||
|
|
||
|
|
||
|
class StemIN(nn.Module):
    """EfficientNet stem for ImageNet: 3x3 conv (stride 2), BN, Swish."""

    def __init__(self, w_in, w_out, bn_norm):
        super(StemIN, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = get_norm(bn_norm, w_out)
        self.swish = Swish()

    def forward(self, x):
        # Explicit conv -> bn -> swish chain (same order the modules were registered in).
        return self.swish(self.bn(self.conv(x)))
|
||
|
|
||
|
|
||
|
class EffNet(nn.Module):
    """EfficientNet backbone: stem, stages s1..sN of MBConv blocks, conv head."""

    @staticmethod
    def get_args():
        """Collect default construction kwargs from the module-level effnet config."""
        return {
            "stem_w": effnet_cfg.EN.STEM_W,
            "ds": effnet_cfg.EN.DEPTHS,
            "ws": effnet_cfg.EN.WIDTHS,
            "exp_rs": effnet_cfg.EN.EXP_RATIOS,
            "se_r": effnet_cfg.EN.SE_R,
            "ss": effnet_cfg.EN.STRIDES,
            "ks": effnet_cfg.EN.KERNELS,
            "head_w": effnet_cfg.EN.HEAD_W,
        }

    def __init__(self, last_stride, bn_norm, **kwargs):
        super(EffNet, self).__init__()
        # Fall back to the global effnet config when no explicit kwargs are given.
        if not kwargs:
            kwargs = self.get_args()
        self._construct(**kwargs, last_stride=last_stride, bn_norm=bn_norm)
        self.apply(init_weights)

    def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm):
        """Build stem, stages, and head from per-stage hyper-parameter lists.

        Args:
            stem_w: stem output width.
            ds/ws/exp_rs/ss/ks: per-stage depths, widths, expansion ratios,
                strides, and kernel sizes (parallel lists).
            head_w: head output width.
            last_stride: stride override for stage index 5 (the final stage),
                so re-id models can keep a larger feature map.
            bn_norm: normalization layer name passed to get_norm.
        """
        self.stem = StemIN(3, stem_w, bn_norm)
        prev_w = stem_w
        for i, (d, w, exp_r, stride, kernel) in enumerate(zip(ds, ws, exp_rs, ss, ks)):
            # Stage 6 (index 5) uses the configurable last_stride.
            if i == 5:
                stride = last_stride
            self.add_module(
                "s{}".format(i + 1),
                EffStage(prev_w, exp_r, kernel, stride, se_r, w, d, bn_norm),
            )
            prev_w = w
        self.head = EffHead(prev_w, head_w, bn_norm)

    def forward(self, x):
        # Children iterate in registration order: stem, s1..sN, head.
        for stage in self.children():
            x = stage(x)
        return x
|
||
|
|
||
|
|
||
|
def init_pretrained_weights(key):
    """Download (if needed) and load ImageNet-pretrained EfficientNet weights.

    Args:
        key (str): model identifier, a key of ``model_urls`` (e.g. 'b0').

    Returns:
        dict: the ``model_state`` state dict loaded from the cached checkpoint.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    import os
    import gdown

    def _get_torch_home():
        # Resolve the cache root the same way torch.hub does:
        # $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch.
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    # exist_ok=True replaces the manual try/except-errno.EEXIST idiom.
    os.makedirs(model_dir, exist_ok=True)

    filename = model_urls[key].split('/')[-1]
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        # Only the main process downloads; other ranks wait at the barrier below.
        if comm.is_main_process():
            gdown.download(model_urls[key], cached_file, quiet=False)

    # All processes block here until the checkpoint exists on disk.
    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))['model_state']

    return state_dict
|
||
|
|
||
|
|
||
|
@BACKBONE_REGISTRY.register()
def build_effnet_backbone(cfg):
    """Create an EfficientNet backbone instance from a fastreid config node.

    Args:
        cfg: fastreid config; reads MODEL.BACKBONE.{PRETRAIN, PRETRAIN_PATH,
            LAST_STRIDE, NORM, DEPTH}.

    Returns:
        EffNet: the (optionally pretrained) backbone module.
    """
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    # Architecture hyper-parameter files shipped with the repo, keyed by depth.
    cfg_files = {
        'b0': 'fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml',
        'b1': 'fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml',
        'b2': 'fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml',
        'b3': 'fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml',
        'b4': 'fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml',
        'b5': 'fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml',
    }[depth]

    # Populate the module-level effnet config that EffNet.get_args() reads.
    effnet_cfg.merge_from_file(cfg_files)
    model = EffNet(last_stride, bn_norm)

    if pretrain:
        if pretrain_path:
            # An explicit checkpoint path takes precedence over released weights.
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            # Otherwise pull the matching pycls checkpoint by depth key.
            state_dict = init_pretrained_weights(depth)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(get_missing_parameters_message(incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.info(get_unexpected_parameters_message(incompatible.unexpected_keys))

    return model
|