mirror of https://github.com/JDAI-CV/fast-reid.git
add efficientnet support
parent
296fbea989
commit
648198e6e5
|
@ -10,4 +10,4 @@ from .resnet import build_resnet_backbone
|
|||
from .osnet import build_osnet_backbone
|
||||
from .resnest import build_resnest_backbone
|
||||
from .resnext import build_resnext_backbone
|
||||
from .regnet import build_regnet_backbone
|
||||
from .regnet import build_regnet_backbone, build_effnet_backbone
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
|
||||
from .regnet import build_regnet_backbone
|
||||
from .effnet import build_effnet_backbone
|
||||
|
|
|
@ -1,19 +1,31 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Configuration file (powered by YACS)."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from yacs.config import CfgNode as CN
|
||||
import sys
|
||||
|
||||
from yacs.config import CfgNode as CfgNode
|
||||
|
||||
|
||||
# Global config object
|
||||
_C = CN()
|
||||
_C = CfgNode()
|
||||
|
||||
# Example usage:
|
||||
# from core.config import cfg
|
||||
regnet_cfg = _C
|
||||
cfg = _C
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Model options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.MODEL = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.MODEL = CfgNode()
|
||||
|
||||
# Model type
|
||||
_C.MODEL.TYPE = ""
|
||||
|
@ -28,10 +40,10 @@ _C.MODEL.NUM_CLASSES = 10
|
|||
_C.MODEL.LOSS_FUN = "cross_entropy"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# ResNet options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.RESNET = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.RESNET = CfgNode()
|
||||
|
||||
# Transformation function (see pycls/models/resnet.py for options)
|
||||
_C.RESNET.TRANS_FUN = "basic_transform"
|
||||
|
@ -46,19 +58,19 @@ _C.RESNET.WIDTH_PER_GROUP = 64
|
|||
_C.RESNET.STRIDE_1X1 = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# AnyNet options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.ANYNET = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.ANYNET = CfgNode()
|
||||
|
||||
# Stem type
|
||||
_C.ANYNET.STEM_TYPE = "plain_block"
|
||||
_C.ANYNET.STEM_TYPE = "simple_stem_in"
|
||||
|
||||
# Stem width
|
||||
_C.ANYNET.STEM_W = 32
|
||||
|
||||
# Block type
|
||||
_C.ANYNET.BLOCK_TYPE = "plain_block"
|
||||
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"
|
||||
|
||||
# Depth for each stage (number of blocks in the stage)
|
||||
_C.ANYNET.DEPTHS = []
|
||||
|
@ -81,41 +93,51 @@ _C.ANYNET.SE_ON = False
|
|||
# SE ratio
|
||||
_C.ANYNET.SE_R = 0.25
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# RegNet options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.REGNET = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.REGNET = CfgNode()
|
||||
|
||||
# Stem type
|
||||
_C.REGNET.STEM_TYPE = "simple_stem_in"
|
||||
|
||||
# Stem width
|
||||
_C.REGNET.STEM_W = 32
|
||||
|
||||
# Block type
|
||||
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
|
||||
|
||||
# Stride of each stage
|
||||
_C.REGNET.STRIDE = 2
|
||||
|
||||
# Squeeze-and-Excitation (RegNetY)
|
||||
_C.REGNET.SE_ON = False
|
||||
_C.REGNET.SE_R = 0.25
|
||||
|
||||
# Depth
|
||||
_C.REGNET.DEPTH = 10
|
||||
|
||||
# Initial width
|
||||
_C.REGNET.W0 = 32
|
||||
|
||||
# Slope
|
||||
_C.REGNET.WA = 5.0
|
||||
|
||||
# Quantization
|
||||
_C.REGNET.WM = 2.5
|
||||
|
||||
# Group width
|
||||
_C.REGNET.GROUP_W = 16
|
||||
|
||||
# Bottleneck multiplier (bm = 1 / b from the paper)
|
||||
_C.REGNET.BOT_MUL = 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# EfficientNet options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.EN = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.EN = CfgNode()
|
||||
|
||||
# Stem width
|
||||
_C.EN.STEM_W = 32
|
||||
|
@ -148,10 +170,10 @@ _C.EN.DC_RATIO = 0.0
|
|||
_C.EN.DROPOUT_RATIO = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Batch norm options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.BN = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.BN = CfgNode()
|
||||
|
||||
# BN epsilon
|
||||
_C.BN.EPS = 1e-5
|
||||
|
@ -160,8 +182,8 @@ _C.BN.EPS = 1e-5
|
|||
_C.BN.MOM = 0.1
|
||||
|
||||
# Precise BN stats
|
||||
_C.BN.USE_PRECISE_STATS = False
|
||||
_C.BN.NUM_SAMPLES_PRECISE = 1024
|
||||
_C.BN.USE_PRECISE_STATS = True
|
||||
_C.BN.NUM_SAMPLES_PRECISE = 8192
|
||||
|
||||
# Initialize the gamma of the final BN of each block to zero
|
||||
_C.BN.ZERO_INIT_FINAL_GAMMA = False
|
||||
|
@ -170,10 +192,11 @@ _C.BN.ZERO_INIT_FINAL_GAMMA = False
|
|||
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
|
||||
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Optimizer options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.OPTIM = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.OPTIM = CfgNode()
|
||||
|
||||
# Base learning rate
|
||||
_C.OPTIM.BASE_LR = 0.1
|
||||
|
@ -212,10 +235,10 @@ _C.OPTIM.WARMUP_FACTOR = 0.1
|
|||
_C.OPTIM.WARMUP_EPOCHS = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Training options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.TRAIN = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.TRAIN = CfgNode()
|
||||
|
||||
# Dataset and split
|
||||
_C.TRAIN.DATASET = ""
|
||||
|
@ -240,10 +263,10 @@ _C.TRAIN.AUTO_RESUME = True
|
|||
_C.TRAIN.WEIGHTS = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Testing options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.TEST = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.TEST = CfgNode()
|
||||
|
||||
# Dataset and split
|
||||
_C.TEST.DATASET = ""
|
||||
|
@ -259,31 +282,31 @@ _C.TEST.IM_SIZE = 256
|
|||
_C.TEST.WEIGHTS = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Common train/test data loader options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.DATA_LOADER = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.DATA_LOADER = CfgNode()
|
||||
|
||||
# Number of data loader workers per training process
|
||||
_C.DATA_LOADER.NUM_WORKERS = 4
|
||||
# Number of data loader workers per process
|
||||
_C.DATA_LOADER.NUM_WORKERS = 8
|
||||
|
||||
# Load data to pinned host memory
|
||||
_C.DATA_LOADER.PIN_MEMORY = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Memory options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.MEM = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.MEM = CfgNode()
|
||||
|
||||
# Perform ReLU inplace
|
||||
_C.MEM.RELU_INPLACE = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# CUDNN options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.CUDNN = CN()
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.CUDNN = CfgNode()
|
||||
|
||||
# Perform benchmarking to select the fastest CUDNN algorithms to use
|
||||
# Note that this may increase the memory usage and will likely not result
|
||||
|
@ -291,16 +314,10 @@ _C.CUDNN = CN()
|
|||
_C.CUDNN.BENCHMARK = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Precise timing options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.PREC_TIME = CN()
|
||||
|
||||
# Perform precise timing at the start of training
|
||||
_C.PREC_TIME.ENABLED = False
|
||||
|
||||
# Total mini-batch size
|
||||
_C.PREC_TIME.BATCH_SIZE = 128
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
_C.PREC_TIME = CfgNode()
|
||||
|
||||
# Number of iterations to warm up the caches
|
||||
_C.PREC_TIME.WARMUP_ITER = 3
|
||||
|
@ -309,9 +326,9 @@ _C.PREC_TIME.WARMUP_ITER = 3
|
|||
_C.PREC_TIME.NUM_ITER = 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Misc options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
|
||||
# Number of GPUs to use (applies to both training and testing)
|
||||
_C.NUM_GPUS = 1
|
||||
|
@ -335,45 +352,44 @@ _C.LOG_PERIOD = 10
|
|||
# Distributed backend
|
||||
_C.DIST_BACKEND = "nccl"
|
||||
|
||||
# Hostname and port for initializing multi-process groups
|
||||
# Hostname and port range for multi-process groups (actual port selected randomly)
|
||||
_C.HOST = "localhost"
|
||||
_C.PORT = 10001
|
||||
_C.PORT_RANGE = [10000, 65000]
|
||||
|
||||
# Models weights referred to by URL are downloaded to this local cache
|
||||
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
# Deprecated keys
|
||||
# ------------------------------------------------------------------------------------ #
|
||||
|
||||
_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
|
||||
_C.register_deprecated_key("PREC_TIME.ENABLED")
|
||||
_C.register_deprecated_key("PORT")
|
||||
|
||||
|
||||
def assert_and_infer_cfg(cache_urls=True):
|
||||
"""Checks config values invariants."""
|
||||
assert (
|
||||
not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0
|
||||
), "The first lr step must start at 0"
|
||||
assert _C.TRAIN.SPLIT in [
|
||||
"train",
|
||||
"val",
|
||||
"test",
|
||||
], "Train split '{}' not supported".format(_C.TRAIN.SPLIT)
|
||||
assert (
|
||||
_C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0
|
||||
), "Train mini-batch size should be a multiple of NUM_GPUS."
|
||||
assert _C.TEST.SPLIT in [
|
||||
"train",
|
||||
"val",
|
||||
"test",
|
||||
], "Test split '{}' not supported".format(_C.TEST.SPLIT)
|
||||
assert (
|
||||
_C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0
|
||||
), "Test mini-batch size should be a multiple of NUM_GPUS."
|
||||
assert (
|
||||
not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1
|
||||
), "Precise BN stats computation not verified for > 1 GPU"
|
||||
assert _C.LOG_DEST in [
|
||||
"stdout",
|
||||
"file",
|
||||
], "Log destination '{}' not supported".format(_C.LOG_DEST)
|
||||
assert (
|
||||
not _C.PREC_TIME.ENABLED or _C.NUM_GPUS == 1
|
||||
), "Precise iter time computation not verified for > 1 GPU"
|
||||
err_str = "The first lr step must start at 0"
|
||||
assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
|
||||
data_splits = ["train", "val", "test"]
|
||||
err_str = "Data split '{}' not supported"
|
||||
assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
|
||||
assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
|
||||
err_str = "Mini-batch size should be a multiple of NUM_GPUS."
|
||||
assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
|
||||
assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
|
||||
err_str = "Log destination '{}' not supported"
|
||||
assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
|
||||
if cache_urls:
|
||||
cache_cfg_urls()
|
||||
|
||||
|
||||
def cache_cfg_urls():
    """Route the TRAIN and TEST weight locations through ``cache_url``.

    Each ``WEIGHTS`` entry is replaced by whatever ``cache_url`` returns for
    it, using ``_C.DOWNLOAD_CACHE`` as the cache directory argument.
    """
    # NOTE(review): `cache_url` is not imported in the part of this file that
    # is visible here - confirm it is in scope before relying on this.
    for section in (_C.TRAIN, _C.TEST):
        section.WEIGHTS = cache_url(section.WEIGHTS, _C.DOWNLOAD_CACHE)
|
||||
|
||||
|
||||
def dump_cfg():
|
||||
|
@ -386,4 +402,19 @@ def dump_cfg():
|
|||
def load_cfg(out_dir, cfg_dest="config.yaml"):
|
||||
"""Loads config from specified output directory."""
|
||||
cfg_file = os.path.join(out_dir, cfg_dest)
|
||||
_C.merge_from_file(cfg_file)
|
||||
_C.merge_from_file(cfg_file)
|
||||
|
||||
|
||||
def load_cfg_fom_args(description="Config file options."):
    """Populate the global config from the command line.

    Expects a mandatory ``--cfg <file>`` option followed by an optional list
    of trailing overrides. The file is merged into the global config first,
    then the trailing overrides are merged on top of it.

    Args:
        description: Text shown at the top of the generated ``--help`` output.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--cfg",
        dest="cfg_file",
        help="Config file location",
        required=True,
        type=str,
    )
    parser.add_argument(
        "opts",
        help="See pycls/core/config.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )

    # Invoked without any arguments: print the full help and exit with status 1.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    parsed = parser.parse_args()
    _C.merge_from_file(parsed.cfg_file)
    _C.merge_from_list(parsed.opts)
|
|
@ -0,0 +1,281 @@
|
|||
# !/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""EfficientNet models."""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from fastreid.layers import *
|
||||
from fastreid.modeling.backbones.build import BACKBONE_REGISTRY
|
||||
from fastreid.utils import comm
|
||||
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||
from .config import cfg as effnet_cfg
|
||||
from .regnet import drop_connect, init_weights
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
model_urls = {
|
||||
'b0': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305613/EN-B0_dds_8gpu.pyth',
|
||||
'b1': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B1_dds_8gpu.pyth',
|
||||
'b2': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B2_dds_8gpu.pyth',
|
||||
'b3': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B3_dds_8gpu.pyth',
|
||||
'b4': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305098/EN-B4_dds_8gpu.pyth',
|
||||
'b5': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B5_dds_8gpu.pyth',
|
||||
'b6': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B6_dds_8gpu.pyth',
|
||||
'b7': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B7_dds_8gpu.pyth',
|
||||
}
|
||||
|
||||
|
||||
class EffHead(nn.Module):
    """EfficientNet head as used here: 1x1 conv, norm layer, Swish.

    No pooling, dropout or fully-connected layer is built by this module; the
    activated feature map is returned directly.
    """

    def __init__(self, w_in, w_out, bn_norm):
        """Build the head.

        Args:
            w_in: Number of input channels.
            w_out: Number of output channels.
            bn_norm: Norm specification forwarded to ``get_norm``.
        """
        super().__init__()
        # Pointwise (1x1) convolution without bias; a norm layer follows it.
        self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
        self.conv_bn = get_norm(bn_norm, w_out)
        self.conv_swish = Swish()

    def forward(self, x):
        """Apply conv -> norm -> Swish to ``x`` and return the result."""
        out = self.conv(x)
        out = self.conv_bn(out)
        return self.conv_swish(out)
|
||||
|
||||
|
||||
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate
|
||||
|
||||
|
||||
class SE(nn.Module):
    """Squeeze-and-Excitation gate: AvgPool, 1x1 conv, Swish, 1x1 conv, Sigmoid."""

    def __init__(self, w_in, w_se):
        """Build the gate.

        Args:
            w_in: Channel count of the tensor being gated.
            w_se: Channel count of the intermediate (squeezed) representation.
        """
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        reduce_conv = nn.Conv2d(w_in, w_se, 1, bias=True)
        expand_conv = nn.Conv2d(w_se, w_in, 1, bias=True)
        # Positions inside the Sequential are kept (conv at 0 and 2) so that
        # parameter names in the state dict stay the same.
        self.f_ex = nn.Sequential(reduce_conv, Swish(), expand_conv, nn.Sigmoid())

    def forward(self, x):
        """Scale ``x`` by the gate computed from its globally pooled features."""
        pooled = self.avg_pool(x)
        gate = self.f_ex(pooled)
        return x * gate
|
||||
|
||||
|
||||
class MBConv(nn.Module):
    """Mobile inverted bottleneck block with Squeeze-and-Excitation (MBConv).

    Layout: optional 1x1 expansion, depthwise conv, SE gate, 1x1 linear
    projection, and a residual connection when input and output shapes match.
    """

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm):
        """Build the block.

        Args:
            w_in: Input channel count.
            exp_r: Expansion ratio; the expanded width is ``int(w_in * exp_r)``.
            kernel: Kernel size of the depthwise convolution.
            stride: Stride of the depthwise convolution.
            se_r: SE ratio; the SE width is ``int(w_in * se_r)``.
            w_out: Output channel count.
            bn_norm: Norm specification forwarded to ``get_norm``.
        """
        super().__init__()
        w_exp = int(w_in * exp_r)

        # The expansion stage only exists when it actually changes the width.
        if w_exp != w_in:
            self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
            self.exp_bn = get_norm(bn_norm, w_exp)
            self.exp_swish = Swish()
        else:
            self.exp = None

        # Depthwise conv: one group per channel, padded by (kernel - 1) // 2.
        self.dwise = nn.Conv2d(
            w_exp,
            w_exp,
            kernel,
            stride=stride,
            groups=w_exp,
            padding=(kernel - 1) // 2,
            bias=False,
        )
        self.dwise_bn = get_norm(bn_norm, w_exp)
        self.dwise_swish = Swish()

        self.se = SE(w_exp, int(w_in * se_r))

        # Linear projection: no activation after the norm layer.
        self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
        self.lin_proj_bn = get_norm(bn_norm, w_out)

        # Skip connection if in and out shapes are the same (MN-V2 style)
        self.has_skip = stride == 1 and w_in == w_out

    def forward(self, x):
        """Run the block; adds ``x`` back to the output when ``has_skip`` is set."""
        out = x
        if self.exp is not None:
            out = self.exp_swish(self.exp_bn(self.exp(out)))
        out = self.dwise_swish(self.dwise_bn(self.dwise(out)))
        out = self.se(out)
        out = self.lin_proj_bn(self.lin_proj(out))

        if not self.has_skip:
            return out

        # Drop-connect is applied to the residual branch only while training
        # and only when the configured ratio is positive.
        if self.training and effnet_cfg.EN.DC_RATIO > 0.0:
            out = drop_connect(out, effnet_cfg.EN.DC_RATIO)
        return x + out
|
||||
|
||||
|
||||
class EffStage(nn.Module):
    """EfficientNet stage: ``d`` MBConv blocks registered as ``b1`` .. ``b<d>``."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm):
        """Build the stage.

        Args:
            w_in: Input channel count of the stage.
            exp_r: Expansion ratio passed to every block.
            kernel: Depthwise kernel size passed to every block.
            stride: Stride of the first block; later blocks use stride 1.
            se_r: SE ratio passed to every block.
            w_out: Output channel count of every block.
            d: Number of blocks in the stage.
            bn_norm: Norm specification passed to every block.
        """
        super().__init__()
        for idx in range(d):
            is_first = idx == 0
            # Only the first block sees the stage input width and stride.
            block = MBConv(
                w_in if is_first else w_out,
                exp_r,
                kernel,
                stride if is_first else 1,
                se_r,
                w_out,
                bn_norm,
            )
            self.add_module("b{}".format(idx + 1), block)

    def forward(self, x):
        """Feed ``x`` through the child blocks in registration order."""
        out = x
        for blk in self.children():
            out = blk(out)
        return out
|
||||
|
||||
|
||||
class StemIN(nn.Module):
    """EfficientNet stem for ImageNet: 3x3 stride-2 conv, norm layer, Swish."""

    def __init__(self, w_in, w_out, bn_norm):
        """Build the stem.

        Args:
            w_in: Input channel count.
            w_out: Output channel count.
            bn_norm: Norm specification forwarded to ``get_norm``.
        """
        super().__init__()
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = get_norm(bn_norm, w_out)
        self.swish = Swish()

    def forward(self, x):
        """Apply every registered child module to ``x`` in registration order."""
        out = x
        for child in self.children():
            out = child(out)
        return out
|
||||
|
||||
|
||||
class EffNet(nn.Module):
    """EfficientNet backbone: stem, stages ``s1`` .. ``sN``, then the head."""

    @staticmethod
    def get_args():
        """Collect the constructor arguments from the ``EN`` config section."""
        en = effnet_cfg.EN
        return {
            "stem_w": en.STEM_W,
            "ds": en.DEPTHS,
            "ws": en.WIDTHS,
            "exp_rs": en.EXP_RATIOS,
            "se_r": en.SE_R,
            "ss": en.STRIDES,
            "ks": en.KERNELS,
            "head_w": en.HEAD_W,
        }

    def __init__(self, last_stride, bn_norm, **kwargs):
        """Build the network and initialise its weights.

        Args:
            last_stride: Stride used for the sixth stage instead of the
                configured one.
            bn_norm: Norm specification passed to every sub-module.
            **kwargs: Architecture arguments (see ``get_args``); when none are
                given they are read from the global config.
        """
        super().__init__()
        if not kwargs:
            kwargs = self.get_args()
        self._construct(last_stride=last_stride, bn_norm=bn_norm, **kwargs)
        self.apply(init_weights)

    def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm):
        """Register stem, stages and head as child modules, in that order."""
        per_stage = zip(ds, ws, exp_rs, ss, ks)
        self.stem = StemIN(3, stem_w, bn_norm)
        w_prev = stem_w
        for idx, (depth, width, exp_r, stride, kernel) in enumerate(per_stage):
            # The stage at index 5 (named "s6") takes its stride from
            # `last_stride` rather than from the per-stage stride list.
            if idx == 5:
                stride = last_stride
            stage = EffStage(w_prev, exp_r, kernel, stride, se_r, width, depth, bn_norm)
            self.add_module("s{}".format(idx + 1), stage)
            w_prev = width
        self.head = EffHead(w_prev, head_w, bn_norm)

    def forward(self, x):
        """Run ``x`` through all child modules in registration order."""
        out = x
        for child in self.children():
            out = child(out)
        return out
|
||||
|
||||
|
||||
def init_pretrained_weights(key):
    """Fetch the pretrained checkpoint for ``key`` and return its state dict.

    The checkpoint is stored under ``<torch home>/checkpoints``, where the
    torch home is ``$TORCH_HOME`` if set, otherwise ``$XDG_CACHE_HOME/torch``,
    otherwise ``~/.cache/torch``. If the file is not cached yet, only the main
    process downloads it; every process then calls ``comm.synchronize()``
    before loading.

    This function does not touch any model; loading the returned weights
    into a network is left to the caller.

    Args:
        key: Entry of ``model_urls`` to fetch (e.g. ``'b0'``).

    Returns:
        The ``'model_state'`` entry of the loaded checkpoint, mapped to CPU.

    Raises:
        KeyError: If ``key`` is not in ``model_urls`` or the checkpoint has no
            ``'model_state'`` entry.
    """
    import os

    import gdown

    cache_root = os.getenv('XDG_CACHE_HOME', '~/.cache')
    torch_home = os.path.expanduser(
        os.getenv('TORCH_HOME', os.path.join(cache_root, 'torch'))
    )
    model_dir = os.path.join(torch_home, 'checkpoints')
    # exist_ok replaces the manual errno.EEXIST check; other OSErrors still propagate.
    os.makedirs(model_dir, exist_ok=True)

    url = model_urls[key]
    cached_file = os.path.join(model_dir, url.split('/')[-1])

    if not os.path.exists(cached_file):
        if comm.is_main_process():
            gdown.download(url, cached_file, quiet=False)

    # Presumably a barrier so that other processes wait for the download to
    # finish before reading the file - confirm against fastreid.utils.comm.
    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    # NOTE(review): torch.load unpickles the file; only use with trusted URLs.
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))['model_state']

    return state_dict
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
def build_effnet_backbone(cfg):
    """Build an EfficientNet backbone from a fastreid config.

    Reads ``PRETRAIN``, ``PRETRAIN_PATH``, ``LAST_STRIDE``, ``NORM`` and
    ``DEPTH`` from ``cfg.MODEL.BACKBONE``, merges the architecture YAML that
    matches ``DEPTH`` into the module-level ``effnet_cfg`` and constructs the
    network. When ``PRETRAIN`` is set, weights are loaded non-strictly from
    ``PRETRAIN_PATH`` if given, otherwise from the downloaded checkpoint for
    that depth; missing and unexpected keys are logged.

    Args:
        cfg: Config node exposing ``MODEL.BACKBONE``.

    Returns:
        The constructed ``EffNet`` module.

    Raises:
        KeyError: If ``DEPTH`` has no architecture YAML (only b0-b5 do).
        FileNotFoundError: If ``PRETRAIN_PATH`` does not exist.
    """
    import os

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    cfg_files = {
        'b0': 'fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml',
        'b1': 'fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml',
        'b2': 'fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml',
        'b3': 'fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml',
        'b4': 'fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml',
        'b5': 'fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml',
    }
    try:
        cfg_file = cfg_files[depth]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f"Unsupported EfficientNet depth {depth!r}; expected one of {sorted(cfg_files)}"
        ) from None

    # The paths above are relative to the repository root. When the process
    # runs from another working directory, fall back to a copy next to this
    # module (assumes this file sits in fastreid/modeling/backbones/regnet/
    # with the YAMLs in its `effnet/` sub-directory - TODO confirm).
    if not os.path.isfile(cfg_file):
        candidate = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            'effnet',
            os.path.basename(cfg_file),
        )
        if os.path.isfile(candidate):
            cfg_file = candidate

    # Mutates the shared module-level config that EffNet.get_args reads.
    effnet_cfg.merge_from_file(cfg_file)
    model = EffNet(last_stride, bn_norm)

    if pretrain:
        # Load from the explicit pretrain path if one is specified.
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise
            except KeyError:
                logger.info("State dict keys error! Please check the state dict.")
                raise
        else:
            state_dict = init_pretrained_weights(depth)

        # Non-strict load: mismatching keys are reported, not fatal.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model
|
|
@ -0,0 +1,27 @@
|
|||
MODEL:
|
||||
TYPE: effnet
|
||||
NUM_CLASSES: 1000
|
||||
EN:
|
||||
STEM_W: 32
|
||||
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||
DEPTHS: [1, 2, 2, 3, 3, 4, 1]
|
||||
WIDTHS: [16, 24, 40, 80, 112, 192, 320]
|
||||
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||
HEAD_W: 1280
|
||||
OPTIM:
|
||||
LR_POLICY: cos
|
||||
BASE_LR: 0.4
|
||||
MAX_EPOCH: 100
|
||||
MOMENTUM: 0.9
|
||||
WEIGHT_DECAY: 1e-5
|
||||
TRAIN:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 224
|
||||
BATCH_SIZE: 256
|
||||
TEST:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 256
|
||||
BATCH_SIZE: 200
|
||||
NUM_GPUS: 8
|
||||
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
|||
MODEL:
|
||||
TYPE: effnet
|
||||
NUM_CLASSES: 1000
|
||||
EN:
|
||||
STEM_W: 32
|
||||
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||
DEPTHS: [2, 3, 3, 4, 4, 5, 2]
|
||||
WIDTHS: [16, 24, 40, 80, 112, 192, 320]
|
||||
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||
HEAD_W: 1280
|
||||
OPTIM:
|
||||
LR_POLICY: cos
|
||||
BASE_LR: 0.4
|
||||
MAX_EPOCH: 100
|
||||
MOMENTUM: 0.9
|
||||
WEIGHT_DECAY: 1e-5
|
||||
TRAIN:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 240
|
||||
BATCH_SIZE: 256
|
||||
TEST:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 274
|
||||
BATCH_SIZE: 200
|
||||
NUM_GPUS: 8
|
||||
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
|||
MODEL:
|
||||
TYPE: effnet
|
||||
NUM_CLASSES: 1000
|
||||
EN:
|
||||
STEM_W: 32
|
||||
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||
DEPTHS: [2, 3, 3, 4, 4, 5, 2]
|
||||
WIDTHS: [16, 24, 48, 88, 120, 208, 352]
|
||||
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||
HEAD_W: 1408
|
||||
OPTIM:
|
||||
LR_POLICY: cos
|
||||
BASE_LR: 0.4
|
||||
MAX_EPOCH: 100
|
||||
MOMENTUM: 0.9
|
||||
WEIGHT_DECAY: 1e-5
|
||||
TRAIN:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 260
|
||||
BATCH_SIZE: 256
|
||||
TEST:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 298
|
||||
BATCH_SIZE: 200
|
||||
NUM_GPUS: 8
|
||||
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
|||
MODEL:
|
||||
TYPE: effnet
|
||||
NUM_CLASSES: 1000
|
||||
EN:
|
||||
STEM_W: 40
|
||||
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||
DEPTHS: [2, 3, 3, 5, 5, 6, 2]
|
||||
WIDTHS: [24, 32, 48, 96, 136, 232, 384]
|
||||
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||
HEAD_W: 1536
|
||||
OPTIM:
|
||||
LR_POLICY: cos
|
||||
BASE_LR: 0.4
|
||||
MAX_EPOCH: 100
|
||||
MOMENTUM: 0.9
|
||||
WEIGHT_DECAY: 1e-5
|
||||
TRAIN:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 300
|
||||
BATCH_SIZE: 256
|
||||
TEST:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 342
|
||||
BATCH_SIZE: 200
|
||||
NUM_GPUS: 8
|
||||
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
|||
MODEL:
|
||||
TYPE: effnet
|
||||
NUM_CLASSES: 1000
|
||||
EN:
|
||||
STEM_W: 48
|
||||
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||
DEPTHS: [2, 4, 4, 6, 6, 8, 2]
|
||||
WIDTHS: [24, 32, 56, 112, 160, 272, 448]
|
||||
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||
HEAD_W: 1792
|
||||
OPTIM:
|
||||
LR_POLICY: cos
|
||||
BASE_LR: 0.2
|
||||
MAX_EPOCH: 100
|
||||
MOMENTUM: 0.9
|
||||
WEIGHT_DECAY: 1e-5
|
||||
TRAIN:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 380
|
||||
BATCH_SIZE: 128
|
||||
TEST:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 434
|
||||
BATCH_SIZE: 104
|
||||
NUM_GPUS: 8
|
||||
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
|||
MODEL:
|
||||
TYPE: effnet
|
||||
NUM_CLASSES: 1000
|
||||
EN:
|
||||
STEM_W: 48
|
||||
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||
DEPTHS: [3, 5, 5, 7, 7, 9, 3]
|
||||
WIDTHS: [24, 40, 64, 128, 176, 304, 512]
|
||||
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||
HEAD_W: 2048
|
||||
OPTIM:
|
||||
LR_POLICY: cos
|
||||
BASE_LR: 0.1
|
||||
MAX_EPOCH: 100
|
||||
MOMENTUM: 0.9
|
||||
WEIGHT_DECAY: 1e-5
|
||||
TRAIN:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 456
|
||||
BATCH_SIZE: 64
|
||||
TEST:
|
||||
DATASET: imagenet
|
||||
IM_SIZE: 522
|
||||
BATCH_SIZE: 48
|
||||
NUM_GPUS: 8
|
||||
OUT_DIR: .
|
|
@ -1,14 +1,15 @@
|
|||
import torch
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
import torch.nn as nn
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from fastreid.layers import get_norm
|
||||
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||
from fastreid.utils import comm
|
||||
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||
from .config import cfg as regnet_cfg
|
||||
from ..build import BACKBONE_REGISTRY
|
||||
from .config import regnet_cfg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
model_urls = {
|
||||
|
@ -24,6 +25,7 @@ model_urls = {
|
|||
'6400y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160907112/RegNetY-6.4GF_dds_8gpu.pyth',
|
||||
}
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
"""Performs ResNet-style weight initialization."""
|
||||
if isinstance(m, nn.Conv2d):
|
||||
|
@ -67,6 +69,15 @@ def get_block_fun(block_type):
|
|||
return block_funs[block_type]
|
||||
|
||||
|
||||
def drop_connect(x, drop_ratio):
    """Drop connect (adapted from DARTS), applied in place.

    A Bernoulli keep-mask with one entry per element of the first dimension
    is sampled with probability ``1 - drop_ratio``. ``x`` is divided by that
    keep probability and multiplied by the mask.

    Args:
        x: Tensor to modify; the mask has shape ``[x.shape[0], 1, 1, 1]`` and
            is broadcast against it.
        drop_ratio: Probability of zeroing a sample.

    Returns:
        The same tensor object ``x``, modified in place.
    """
    keep_prob = 1.0 - drop_ratio
    n = x.shape[0]
    keep_mask = torch.empty([n, 1, 1, 1], dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
    # Both operations are in place, so the caller's tensor is updated as well.
    x.div_(keep_prob).mul_(keep_mask)
    return x
|
||||
|
||||
class AnyHead(nn.Module):
|
||||
"""AnyNet head."""
|
||||
|
||||
|
@ -117,7 +128,7 @@ class BasicTransform(nn.Module):
|
|||
super(BasicTransform, self).__init__()
|
||||
self.construct(w_in, w_out, stride, bn_norm)
|
||||
|
||||
def construct(self, w_in, w_out, stride, bn_norm, num_split):
|
||||
def construct(self, w_in, w_out, stride, bn_norm):
|
||||
# 3x3, BN, ReLU
|
||||
self.a = nn.Conv2d(
|
||||
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
|
||||
|
@ -269,7 +280,7 @@ class ResStemCifar(nn.Module):
|
|||
self.conv = nn.Conv2d(
|
||||
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
self.bn = get_norm(bn_norm, w_out, 1)
|
||||
self.bn = get_norm(bn_norm, w_out)
|
||||
self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -530,11 +541,12 @@ def init_pretrained_weights(key):
|
|||
@BACKBONE_REGISTRY.register()
|
||||
def build_regnet_backbone(cfg):
|
||||
# fmt: off
|
||||
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
|
||||
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
|
||||
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
|
||||
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
||||
bn_norm = cfg.MODEL.BACKBONE.NORM
|
||||
depth = cfg.MODEL.BACKBONE.DEPTH
|
||||
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
||||
bn_norm = cfg.MODEL.BACKBONE.NORM
|
||||
depth = cfg.MODEL.BACKBONE.DEPTH
|
||||
# fmt: on
|
||||
|
||||
cfg_files = {
|
||||
'800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
|
||||
|
@ -549,7 +561,7 @@ def build_regnet_backbone(cfg):
|
|||
'6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
|
||||
}[depth]
|
||||
|
||||
regnet_cfg.merge_from_file(cfg_files)
|
||||
cfg.merge_from_file(cfg_files)
|
||||
model = RegNet(last_stride, bn_norm)
|
||||
|
||||
if pretrain:
|
||||
|
|
Loading…
Reference in New Issue