fast-reid/fastreid/modeling/backbones/regnet/effnet.py

282 lines
9.6 KiB
Python

# !/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""EfficientNet models."""
import logging
import torch
import torch.nn as nn
from fastreid.layers import *
from fastreid.modeling.backbones.build import BACKBONE_REGISTRY
from fastreid.utils import comm
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .config import cfg as effnet_cfg
from .regnet import drop_connect, init_weights
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Download locations of the pretrained checkpoints, keyed by EfficientNet
# variant. Every URL follows one pattern and differs only in the run id and
# the upper-cased variant name, so the table is generated from those pairs.
# NOTE(review): b1, b2, b3, b5, b6 and b7 all carry the same run id
# (161304979) -- confirm these against the published model list.
model_urls = {
    variant: (
        "https://dl.fbaipublicfiles.com/pycls/dds_baselines/"
        f"{run_id}/EN-{variant.upper()}_dds_8gpu.pyth"
    )
    for variant, run_id in (
        ("b0", 161305613),
        ("b1", 161304979),
        ("b2", 161304979),
        ("b3", 161304979),
        ("b4", 161305098),
        ("b5", 161304979),
        ("b6", 161304979),
        ("b7", 161304979),
    )
}
class EffHead(nn.Module):
    """EfficientNet head as built here: 1x1 conv, norm layer, Swish.

    No pooling, dropout or fully-connected layer is created in this class;
    the output is the activated feature map.

    Args:
        w_in: number of input channels.
        w_out: number of output channels of the 1x1 convolution.
        bn_norm: norm specification forwarded to ``get_norm``.
    """

    def __init__(self, w_in, w_out, bn_norm):
        super().__init__()
        # Pointwise projection; bias is omitted because a norm layer follows.
        self.conv = nn.Conv2d(w_in, w_out, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_bn = get_norm(bn_norm, w_out)
        self.conv_swish = Swish()

    def forward(self, x):
        """Apply conv -> norm -> Swish and return the result."""
        projected = self.conv(x)
        normed = self.conv_bn(projected)
        return self.conv_swish(normed)
class Swish(nn.Module):
    """Element-wise Swish activation: ``x * sigmoid(x)``.

    The module holds no parameters or buffers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` multiplied by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block with a Swish non-linearity.

    Pipeline: adaptive average pool to 1x1, 1x1 conv (``w_in -> w_se``),
    Swish, 1x1 conv (``w_se -> w_in``), Sigmoid. The resulting gate is
    multiplied with the input.

    Args:
        w_in: number of channels of the input (and of the gate).
        w_se: number of channels in the squeezed bottleneck.
    """

    def __init__(self, w_in, w_se):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Both 1x1 convolutions keep their bias: no norm layer follows them.
        squeeze = nn.Conv2d(w_in, w_se, kernel_size=1, bias=True)
        excite = nn.Conv2d(w_se, w_in, kernel_size=1, bias=True)
        self.f_ex = nn.Sequential(squeeze, Swish(), excite, nn.Sigmoid())

    def forward(self, x):
        """Scale ``x`` by the gate computed from its spatial average."""
        pooled = self.avg_pool(x)
        gate = self.f_ex(pooled)
        return x * gate
class MBConv(nn.Module):
    """Mobile inverted bottleneck block with SE (MBConv).

    Layer order: optional 1x1 expansion (conv, norm, Swish), depthwise
    ``kernel`` x ``kernel`` conv (conv, norm, Swish), SE, 1x1 linear
    projection (conv, norm), and a residual connection when shapes allow.

    Args:
        w_in: input channels.
        exp_r: expansion ratio; the expanded width is ``int(w_in * exp_r)``.
        kernel: depthwise kernel size; padding is ``(kernel - 1) // 2``.
        stride: stride of the depthwise convolution.
        se_r: SE ratio; the SE bottleneck width is ``int(w_in * se_r)``
            (computed from the block input width, not the expanded width).
        w_out: output channels.
        bn_norm: norm specification forwarded to ``get_norm``.
    """

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm):
        super().__init__()
        w_exp = int(w_in * exp_r)

        # The expansion layers exist only when they actually change the width.
        self.exp = None
        if w_exp != w_in:
            self.exp = nn.Conv2d(w_in, w_exp, kernel_size=1, stride=1, padding=0, bias=False)
            self.exp_bn = get_norm(bn_norm, w_exp)
            self.exp_swish = Swish()

        # Depthwise convolution: one group per channel.
        self.dwise = nn.Conv2d(
            w_exp,
            w_exp,
            kernel,
            stride=stride,
            groups=w_exp,
            padding=(kernel - 1) // 2,
            bias=False,
        )
        self.dwise_bn = get_norm(bn_norm, w_exp)
        self.dwise_swish = Swish()

        self.se = SE(w_exp, int(w_in * se_r))

        # Linear projection: no activation after the norm layer.
        self.lin_proj = nn.Conv2d(w_exp, w_out, kernel_size=1, stride=1, padding=0, bias=False)
        self.lin_proj_bn = get_norm(bn_norm, w_out)

        # Skip connection if in and out shapes are the same (MN-V2 style)
        self.has_skip = stride == 1 and w_in == w_out

    def forward(self, x):
        """Run the block; add the input back when ``has_skip`` is set."""
        out = x
        if self.exp is not None:
            out = self.exp(out)
            out = self.exp_bn(out)
            out = self.exp_swish(out)

        out = self.dwise(out)
        out = self.dwise_bn(out)
        out = self.dwise_swish(out)

        out = self.se(out)

        out = self.lin_proj(out)
        out = self.lin_proj_bn(out)

        if not self.has_skip:
            return out

        # Drop-connect is applied to the residual branch during training only,
        # and only when the configured ratio is positive.
        if self.training and effnet_cfg.EN.DC_RATIO > 0.0:
            out = drop_connect(out, effnet_cfg.EN.DC_RATIO)
        return x + out
class EffStage(nn.Module):
    """EfficientNet stage: ``d`` MBConv blocks applied in sequence.

    Only the first block receives ``w_in`` and ``stride``; every following
    block maps ``w_out -> w_out`` with stride 1. Blocks are registered as
    ``b1`` .. ``b<d>``.

    Args:
        w_in: input channels of the stage.
        exp_r: expansion ratio passed to every block.
        kernel: depthwise kernel size passed to every block.
        stride: stride of the first block.
        se_r: SE ratio passed to every block.
        w_out: output channels of every block.
        d: number of blocks.
        bn_norm: norm specification passed to every block.
    """

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm):
        super().__init__()
        for idx in range(d):
            is_first = idx == 0
            block = MBConv(
                w_in if is_first else w_out,
                exp_r,
                kernel,
                stride if is_first else 1,
                se_r,
                w_out,
                bn_norm,
            )
            self.add_module(f"b{idx + 1}", block)

    def forward(self, x):
        """Feed ``x`` through the registered blocks in registration order."""
        out = x
        for mbconv in self.children():
            out = mbconv(out)
        return out
class StemIN(nn.Module):
    """EfficientNet stem for ImageNet: 3x3 stride-2 conv, norm layer, Swish.

    Args:
        w_in: input channels.
        w_out: output channels of the convolution.
        bn_norm: norm specification forwarded to ``get_norm``.
    """

    def __init__(self, w_in, w_out, bn_norm):
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn = get_norm(bn_norm, w_out)
        self.swish = Swish()

    def forward(self, x):
        """Apply every registered child module in registration order."""
        out = x
        # Iterating over children (rather than naming them) keeps the forward
        # pass tied to whatever submodules were actually registered above.
        for child in self.children():
            out = child(out)
        return out
class EffNet(nn.Module):
    """EfficientNet model: stem, a sequence of stages, and a head.

    Stages are registered as ``s1``, ``s2``, ... in the order given by the
    per-stage parameter lists. The forward pass simply applies all child
    modules in registration order.
    """

    @staticmethod
    def get_args():
        """Collect the constructor arguments from the ``EN`` config node."""
        en = effnet_cfg.EN
        return {
            "stem_w": en.STEM_W,
            "ds": en.DEPTHS,
            "ws": en.WIDTHS,
            "exp_rs": en.EXP_RATIOS,
            "se_r": en.SE_R,
            "ss": en.STRIDES,
            "ks": en.KERNELS,
            "head_w": en.HEAD_W,
        }

    def __init__(self, last_stride, bn_norm, **kwargs):
        """Build the network and initialise its weights.

        Args:
            last_stride: stride used for the stage at index 5 (the sixth
                stage), overriding the configured value.
            bn_norm: norm specification forwarded to every sub-block.
            **kwargs: architecture arguments as returned by ``get_args``;
                when none are given, ``get_args()`` supplies them.
        """
        super().__init__()
        if not kwargs:
            kwargs = self.get_args()
        self._construct(last_stride=last_stride, bn_norm=bn_norm, **kwargs)
        self.apply(init_weights)

    def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm):
        """Create stem, stages and head from the per-stage parameter lists."""
        per_stage = list(zip(ds, ws, exp_rs, ss, ks))

        self.stem = StemIN(3, stem_w, bn_norm)

        w_prev = stem_w
        for idx, (depth, width, exp_r, stride, kernel) in enumerate(per_stage):
            # The stage at index 5 takes its stride from ``last_stride``.
            if idx == 5:
                stride = last_stride
            stage = EffStage(w_prev, exp_r, kernel, stride, se_r, width, depth, bn_norm)
            self.add_module(f"s{idx + 1}", stage)
            w_prev = width

        self.head = EffHead(w_prev, head_w, bn_norm)

    def forward(self, x):
        """Run ``x`` through stem, stages and head in registration order."""
        out = x
        for child in self.children():
            out = child(out)
        return out
def init_pretrained_weights(key):
    """Fetch the pretrained checkpoint for ``key`` and return its state dict.

    The checkpoint is cached under ``<torch home>/checkpoints``, where the
    torch home is ``$TORCH_HOME`` if set, otherwise ``$XDG_CACHE_HOME/torch``,
    otherwise ``~/.cache/torch``. When the cached file is missing, only the
    main process downloads it; all processes then synchronise before loading.

    This function does not touch any model: the caller is responsible for
    loading the returned state dict.

    Args:
        key: a key of ``model_urls`` (for example ``'b0'``).

    Returns:
        The ``'model_state'`` entry of the loaded checkpoint, mapped to CPU.

    Raises:
        KeyError: if ``key`` is not in ``model_urls`` or the checkpoint has
            no ``'model_state'`` entry.
        OSError: if the cache directory cannot be created (including the
            case where a non-directory already occupies that path).
    """
    import os

    import gdown

    torch_home = os.path.expanduser(
        os.getenv(
            'TORCH_HOME',
            os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'),
        )
    )
    model_dir = os.path.join(torch_home, 'checkpoints')
    # exist_ok tolerates an existing directory (and concurrent creation by
    # sibling processes) while still failing on any other error.
    os.makedirs(model_dir, exist_ok=True)

    url = model_urls[key]
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        # Only one process downloads; the others wait at the barrier below.
        if comm.is_main_process():
            gdown.download(url, cached_file, quiet=False)

    # Every process reaches this barrier, whether or not a download happened.
    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # from a trusted source.
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))['model_state']

    return state_dict
@BACKBONE_REGISTRY.register()
def build_effnet_backbone(cfg):
    """Build an EfficientNet backbone from the project config.

    Reads ``PRETRAIN``, ``PRETRAIN_PATH``, ``LAST_STRIDE``, ``NORM`` and
    ``DEPTH`` from ``cfg.MODEL.BACKBONE``, merges the architecture YAML that
    matches ``DEPTH`` into the module-level ``effnet_cfg`` (a global side
    effect), constructs the network and optionally loads pretrained weights
    non-strictly, logging any missing or unexpected keys.

    Args:
        cfg: project config node exposing ``MODEL.BACKBONE``.

    Returns:
        The constructed ``EffNet`` instance.

    Raises:
        KeyError: if ``DEPTH`` is not one of the supported variants.
        FileNotFoundError: if ``PRETRAIN_PATH`` is set but does not exist.
    """
    import os

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    # NOTE(review): model_urls also lists b6/b7, but no architecture file is
    # registered for them here -- confirm whether that is intentional.
    cfg_files = {
        'b0': 'fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml',
        'b1': 'fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml',
        'b2': 'fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml',
        'b3': 'fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml',
        'b4': 'fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml',
        'b5': 'fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml',
    }
    if depth not in cfg_files:
        raise KeyError(
            f"Unsupported EfficientNet depth {depth!r}; expected one of {sorted(cfg_files)}"
        )
    cfg_file = cfg_files[depth]

    # The registered paths are relative to the working directory. When that
    # lookup fails, fall back to an ``effnet/`` directory next to this module.
    # NOTE(review): assumes the YAML files are shipped beside this module,
    # mirroring the relative paths above -- TODO confirm packaging.
    if not os.path.exists(cfg_file):
        bundled = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            'effnet',
            os.path.basename(cfg_file),
        )
        if os.path.exists(bundled):
            cfg_file = bundled

    effnet_cfg.merge_from_file(cfg_file)

    model = EffNet(last_stride, bn_norm)

    if pretrain:
        # An explicit checkpoint path takes precedence over the download.
        if pretrain_path:
            # NOTE(review): this branch uses the loaded object as the state
            # dict directly, whereas the download path reads its
            # 'model_state' entry -- confirm the expected checkpoint layout.
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise
            except KeyError:
                logger.info("State dict keys error! Please check the state dict.")
                raise
        else:
            state_dict = init_pretrained_weights(depth)

        # Non-strict load: mismatches are reported instead of raising.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model