mirror of https://github.com/JDAI-CV/fast-reid.git
add efficientnet support
parent
296fbea989
commit
648198e6e5
|
@ -10,4 +10,4 @@ from .resnet import build_resnet_backbone
|
||||||
from .osnet import build_osnet_backbone
|
from .osnet import build_osnet_backbone
|
||||||
from .resnest import build_resnest_backbone
|
from .resnest import build_resnest_backbone
|
||||||
from .resnext import build_resnext_backbone
|
from .resnext import build_resnext_backbone
|
||||||
from .regnet import build_regnet_backbone
|
from .regnet import build_regnet_backbone, build_effnet_backbone
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
|
|
||||||
from .regnet import build_regnet_backbone
|
from .regnet import build_regnet_backbone
|
||||||
|
from .effnet import build_effnet_backbone
|
||||||
|
|
|
@ -1,19 +1,31 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Configuration file (powered by YACS)."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
from yacs.config import CfgNode as CN
|
import sys
|
||||||
|
|
||||||
|
from yacs.config import CfgNode as CfgNode
|
||||||
|
|
||||||
|
|
||||||
# Global config object
|
# Global config object
|
||||||
_C = CN()
|
_C = CfgNode()
|
||||||
|
|
||||||
# Example usage:
|
# Example usage:
|
||||||
# from core.config import cfg
|
# from core.config import cfg
|
||||||
regnet_cfg = _C
|
cfg = _C
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Model options
|
# Model options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.MODEL = CN()
|
_C.MODEL = CfgNode()
|
||||||
|
|
||||||
# Model type
|
# Model type
|
||||||
_C.MODEL.TYPE = ""
|
_C.MODEL.TYPE = ""
|
||||||
|
@ -28,10 +40,10 @@ _C.MODEL.NUM_CLASSES = 10
|
||||||
_C.MODEL.LOSS_FUN = "cross_entropy"
|
_C.MODEL.LOSS_FUN = "cross_entropy"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# ResNet options
|
# ResNet options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.RESNET = CN()
|
_C.RESNET = CfgNode()
|
||||||
|
|
||||||
# Transformation function (see pycls/models/resnet.py for options)
|
# Transformation function (see pycls/models/resnet.py for options)
|
||||||
_C.RESNET.TRANS_FUN = "basic_transform"
|
_C.RESNET.TRANS_FUN = "basic_transform"
|
||||||
|
@ -46,19 +58,19 @@ _C.RESNET.WIDTH_PER_GROUP = 64
|
||||||
_C.RESNET.STRIDE_1X1 = True
|
_C.RESNET.STRIDE_1X1 = True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# AnyNet options
|
# AnyNet options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.ANYNET = CN()
|
_C.ANYNET = CfgNode()
|
||||||
|
|
||||||
# Stem type
|
# Stem type
|
||||||
_C.ANYNET.STEM_TYPE = "plain_block"
|
_C.ANYNET.STEM_TYPE = "simple_stem_in"
|
||||||
|
|
||||||
# Stem width
|
# Stem width
|
||||||
_C.ANYNET.STEM_W = 32
|
_C.ANYNET.STEM_W = 32
|
||||||
|
|
||||||
# Block type
|
# Block type
|
||||||
_C.ANYNET.BLOCK_TYPE = "plain_block"
|
_C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"
|
||||||
|
|
||||||
# Depth for each stage (number of blocks in the stage)
|
# Depth for each stage (number of blocks in the stage)
|
||||||
_C.ANYNET.DEPTHS = []
|
_C.ANYNET.DEPTHS = []
|
||||||
|
@ -81,41 +93,51 @@ _C.ANYNET.SE_ON = False
|
||||||
# SE ratio
|
# SE ratio
|
||||||
_C.ANYNET.SE_R = 0.25
|
_C.ANYNET.SE_R = 0.25
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
# RegNet options
|
# RegNet options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.REGNET = CN()
|
_C.REGNET = CfgNode()
|
||||||
|
|
||||||
# Stem type
|
# Stem type
|
||||||
_C.REGNET.STEM_TYPE = "simple_stem_in"
|
_C.REGNET.STEM_TYPE = "simple_stem_in"
|
||||||
|
|
||||||
# Stem width
|
# Stem width
|
||||||
_C.REGNET.STEM_W = 32
|
_C.REGNET.STEM_W = 32
|
||||||
|
|
||||||
# Block type
|
# Block type
|
||||||
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
|
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
|
||||||
|
|
||||||
# Stride of each stage
|
# Stride of each stage
|
||||||
_C.REGNET.STRIDE = 2
|
_C.REGNET.STRIDE = 2
|
||||||
|
|
||||||
# Squeeze-and-Excitation (RegNetY)
|
# Squeeze-and-Excitation (RegNetY)
|
||||||
_C.REGNET.SE_ON = False
|
_C.REGNET.SE_ON = False
|
||||||
_C.REGNET.SE_R = 0.25
|
_C.REGNET.SE_R = 0.25
|
||||||
|
|
||||||
# Depth
|
# Depth
|
||||||
_C.REGNET.DEPTH = 10
|
_C.REGNET.DEPTH = 10
|
||||||
|
|
||||||
# Initial width
|
# Initial width
|
||||||
_C.REGNET.W0 = 32
|
_C.REGNET.W0 = 32
|
||||||
|
|
||||||
# Slope
|
# Slope
|
||||||
_C.REGNET.WA = 5.0
|
_C.REGNET.WA = 5.0
|
||||||
|
|
||||||
# Quantization
|
# Quantization
|
||||||
_C.REGNET.WM = 2.5
|
_C.REGNET.WM = 2.5
|
||||||
|
|
||||||
# Group width
|
# Group width
|
||||||
_C.REGNET.GROUP_W = 16
|
_C.REGNET.GROUP_W = 16
|
||||||
|
|
||||||
# Bottleneck multiplier (bm = 1 / b from the paper)
|
# Bottleneck multiplier (bm = 1 / b from the paper)
|
||||||
_C.REGNET.BOT_MUL = 1.0
|
_C.REGNET.BOT_MUL = 1.0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# EfficientNet options
|
# EfficientNet options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.EN = CN()
|
_C.EN = CfgNode()
|
||||||
|
|
||||||
# Stem width
|
# Stem width
|
||||||
_C.EN.STEM_W = 32
|
_C.EN.STEM_W = 32
|
||||||
|
@ -148,10 +170,10 @@ _C.EN.DC_RATIO = 0.0
|
||||||
_C.EN.DROPOUT_RATIO = 0.0
|
_C.EN.DROPOUT_RATIO = 0.0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Batch norm options
|
# Batch norm options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.BN = CN()
|
_C.BN = CfgNode()
|
||||||
|
|
||||||
# BN epsilon
|
# BN epsilon
|
||||||
_C.BN.EPS = 1e-5
|
_C.BN.EPS = 1e-5
|
||||||
|
@ -160,8 +182,8 @@ _C.BN.EPS = 1e-5
|
||||||
_C.BN.MOM = 0.1
|
_C.BN.MOM = 0.1
|
||||||
|
|
||||||
# Precise BN stats
|
# Precise BN stats
|
||||||
_C.BN.USE_PRECISE_STATS = False
|
_C.BN.USE_PRECISE_STATS = True
|
||||||
_C.BN.NUM_SAMPLES_PRECISE = 1024
|
_C.BN.NUM_SAMPLES_PRECISE = 8192
|
||||||
|
|
||||||
# Initialize the gamma of the final BN of each block to zero
|
# Initialize the gamma of the final BN of each block to zero
|
||||||
_C.BN.ZERO_INIT_FINAL_GAMMA = False
|
_C.BN.ZERO_INIT_FINAL_GAMMA = False
|
||||||
|
@ -170,10 +192,11 @@ _C.BN.ZERO_INIT_FINAL_GAMMA = False
|
||||||
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
|
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
|
||||||
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
|
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Optimizer options
|
# Optimizer options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.OPTIM = CN()
|
_C.OPTIM = CfgNode()
|
||||||
|
|
||||||
# Base learning rate
|
# Base learning rate
|
||||||
_C.OPTIM.BASE_LR = 0.1
|
_C.OPTIM.BASE_LR = 0.1
|
||||||
|
@ -212,10 +235,10 @@ _C.OPTIM.WARMUP_FACTOR = 0.1
|
||||||
_C.OPTIM.WARMUP_EPOCHS = 0
|
_C.OPTIM.WARMUP_EPOCHS = 0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Training options
|
# Training options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.TRAIN = CN()
|
_C.TRAIN = CfgNode()
|
||||||
|
|
||||||
# Dataset and split
|
# Dataset and split
|
||||||
_C.TRAIN.DATASET = ""
|
_C.TRAIN.DATASET = ""
|
||||||
|
@ -240,10 +263,10 @@ _C.TRAIN.AUTO_RESUME = True
|
||||||
_C.TRAIN.WEIGHTS = ""
|
_C.TRAIN.WEIGHTS = ""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Testing options
|
# Testing options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.TEST = CN()
|
_C.TEST = CfgNode()
|
||||||
|
|
||||||
# Dataset and split
|
# Dataset and split
|
||||||
_C.TEST.DATASET = ""
|
_C.TEST.DATASET = ""
|
||||||
|
@ -259,31 +282,31 @@ _C.TEST.IM_SIZE = 256
|
||||||
_C.TEST.WEIGHTS = ""
|
_C.TEST.WEIGHTS = ""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Common train/test data loader options
|
# Common train/test data loader options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.DATA_LOADER = CN()
|
_C.DATA_LOADER = CfgNode()
|
||||||
|
|
||||||
# Number of data loader workers per training process
|
# Number of data loader workers per process
|
||||||
_C.DATA_LOADER.NUM_WORKERS = 4
|
_C.DATA_LOADER.NUM_WORKERS = 8
|
||||||
|
|
||||||
# Load data to pinned host memory
|
# Load data to pinned host memory
|
||||||
_C.DATA_LOADER.PIN_MEMORY = True
|
_C.DATA_LOADER.PIN_MEMORY = True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Memory options
|
# Memory options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.MEM = CN()
|
_C.MEM = CfgNode()
|
||||||
|
|
||||||
# Perform ReLU inplace
|
# Perform ReLU inplace
|
||||||
_C.MEM.RELU_INPLACE = True
|
_C.MEM.RELU_INPLACE = True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# CUDNN options
|
# CUDNN options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.CUDNN = CN()
|
_C.CUDNN = CfgNode()
|
||||||
|
|
||||||
# Perform benchmarking to select the fastest CUDNN algorithms to use
|
# Perform benchmarking to select the fastest CUDNN algorithms to use
|
||||||
# Note that this may increase the memory usage and will likely not result
|
# Note that this may increase the memory usage and will likely not result
|
||||||
|
@ -291,16 +314,10 @@ _C.CUDNN = CN()
|
||||||
_C.CUDNN.BENCHMARK = True
|
_C.CUDNN.BENCHMARK = True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Precise timing options
|
# Precise timing options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
_C.PREC_TIME = CN()
|
_C.PREC_TIME = CfgNode()
|
||||||
|
|
||||||
# Perform precise timing at the start of training
|
|
||||||
_C.PREC_TIME.ENABLED = False
|
|
||||||
|
|
||||||
# Total mini-batch size
|
|
||||||
_C.PREC_TIME.BATCH_SIZE = 128
|
|
||||||
|
|
||||||
# Number of iterations to warm up the caches
|
# Number of iterations to warm up the caches
|
||||||
_C.PREC_TIME.WARMUP_ITER = 3
|
_C.PREC_TIME.WARMUP_ITER = 3
|
||||||
|
@ -309,9 +326,9 @@ _C.PREC_TIME.WARMUP_ITER = 3
|
||||||
_C.PREC_TIME.NUM_ITER = 30
|
_C.PREC_TIME.NUM_ITER = 30
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
# Misc options
|
# Misc options
|
||||||
# ---------------------------------------------------------------------------- #
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
|
||||||
# Number of GPUs to use (applies to both training and testing)
|
# Number of GPUs to use (applies to both training and testing)
|
||||||
_C.NUM_GPUS = 1
|
_C.NUM_GPUS = 1
|
||||||
|
@ -335,45 +352,44 @@ _C.LOG_PERIOD = 10
|
||||||
# Distributed backend
|
# Distributed backend
|
||||||
_C.DIST_BACKEND = "nccl"
|
_C.DIST_BACKEND = "nccl"
|
||||||
|
|
||||||
# Hostname and port for initializing multi-process groups
|
# Hostname and port range for multi-process groups (actual port selected randomly)
|
||||||
_C.HOST = "localhost"
|
_C.HOST = "localhost"
|
||||||
_C.PORT = 10001
|
_C.PORT_RANGE = [10000, 65000]
|
||||||
|
|
||||||
# Models weights referred to by URL are downloaded to this local cache
|
# Models weights referred to by URL are downloaded to this local cache
|
||||||
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
|
_C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
# Deprecated keys
|
||||||
|
# ------------------------------------------------------------------------------------ #
|
||||||
|
|
||||||
|
_C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
|
||||||
|
_C.register_deprecated_key("PREC_TIME.ENABLED")
|
||||||
|
_C.register_deprecated_key("PORT")
|
||||||
|
|
||||||
|
|
||||||
def assert_and_infer_cfg(cache_urls=True):
|
def assert_and_infer_cfg(cache_urls=True):
|
||||||
"""Checks config values invariants."""
|
"""Checks config values invariants."""
|
||||||
assert (
|
err_str = "The first lr step must start at 0"
|
||||||
not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0
|
assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
|
||||||
), "The first lr step must start at 0"
|
data_splits = ["train", "val", "test"]
|
||||||
assert _C.TRAIN.SPLIT in [
|
err_str = "Data split '{}' not supported"
|
||||||
"train",
|
assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
|
||||||
"val",
|
assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
|
||||||
"test",
|
err_str = "Mini-batch size should be a multiple of NUM_GPUS."
|
||||||
], "Train split '{}' not supported".format(_C.TRAIN.SPLIT)
|
assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
|
||||||
assert (
|
assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str
|
||||||
_C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0
|
err_str = "Log destination '{}' not supported"
|
||||||
), "Train mini-batch size should be a multiple of NUM_GPUS."
|
assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST)
|
||||||
assert _C.TEST.SPLIT in [
|
if cache_urls:
|
||||||
"train",
|
cache_cfg_urls()
|
||||||
"val",
|
|
||||||
"test",
|
|
||||||
], "Test split '{}' not supported".format(_C.TEST.SPLIT)
|
def cache_cfg_urls():
|
||||||
assert (
|
"""Download URLs in config, cache them, and rewrite cfg to use cached file."""
|
||||||
_C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0
|
_C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
|
||||||
), "Test mini-batch size should be a multiple of NUM_GPUS."
|
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
|
||||||
assert (
|
|
||||||
not _C.BN.USE_PRECISE_STATS or _C.NUM_GPUS == 1
|
|
||||||
), "Precise BN stats computation not verified for > 1 GPU"
|
|
||||||
assert _C.LOG_DEST in [
|
|
||||||
"stdout",
|
|
||||||
"file",
|
|
||||||
], "Log destination '{}' not supported".format(_C.LOG_DEST)
|
|
||||||
assert (
|
|
||||||
not _C.PREC_TIME.ENABLED or _C.NUM_GPUS == 1
|
|
||||||
), "Precise iter time computation not verified for > 1 GPU"
|
|
||||||
|
|
||||||
|
|
||||||
def dump_cfg():
|
def dump_cfg():
|
||||||
|
@ -387,3 +403,18 @@ def load_cfg(out_dir, cfg_dest="config.yaml"):
|
||||||
"""Loads config from specified output directory."""
|
"""Loads config from specified output directory."""
|
||||||
cfg_file = os.path.join(out_dir, cfg_dest)
|
cfg_file = os.path.join(out_dir, cfg_dest)
|
||||||
_C.merge_from_file(cfg_file)
|
_C.merge_from_file(cfg_file)
|
||||||
|
|
||||||
|
|
||||||
|
def load_cfg_fom_args(description="Config file options."):
|
||||||
|
"""Load config from command line arguments and set any specified options."""
|
||||||
|
parser = argparse.ArgumentParser(description=description)
|
||||||
|
help_s = "Config file location"
|
||||||
|
parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str)
|
||||||
|
help_s = "See pycls/core/config.py for all options"
|
||||||
|
parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
|
||||||
|
if len(sys.argv) == 1:
|
||||||
|
parser.print_help()
|
||||||
|
sys.exit(1)
|
||||||
|
args = parser.parse_args()
|
||||||
|
_C.merge_from_file(args.cfg_file)
|
||||||
|
_C.merge_from_list(args.opts)
|
|
@ -0,0 +1,281 @@
|
||||||
|
# !/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
"""EfficientNet models."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from fastreid.layers import *
|
||||||
|
from fastreid.modeling.backbones.build import BACKBONE_REGISTRY
|
||||||
|
from fastreid.utils import comm
|
||||||
|
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||||
|
from .config import cfg as effnet_cfg
|
||||||
|
from .regnet import drop_connect, init_weights
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
model_urls = {
|
||||||
|
'b0': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305613/EN-B0_dds_8gpu.pyth',
|
||||||
|
'b1': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B1_dds_8gpu.pyth',
|
||||||
|
'b2': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B2_dds_8gpu.pyth',
|
||||||
|
'b3': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B3_dds_8gpu.pyth',
|
||||||
|
'b4': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161305098/EN-B4_dds_8gpu.pyth',
|
||||||
|
'b5': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B5_dds_8gpu.pyth',
|
||||||
|
'b6': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B6_dds_8gpu.pyth',
|
||||||
|
'b7': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161304979/EN-B7_dds_8gpu.pyth',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EffHead(nn.Module):
|
||||||
|
"""EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC."""
|
||||||
|
|
||||||
|
def __init__(self, w_in, w_out, bn_norm):
|
||||||
|
super(EffHead, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
|
||||||
|
self.conv_bn = get_norm(bn_norm, w_out)
|
||||||
|
self.conv_swish = Swish()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv_swish(self.conv_bn(self.conv(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Swish(nn.Module):
|
||||||
|
"""Swish activation function: x * sigmoid(x)."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Swish, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SE(nn.Module):
|
||||||
|
"""Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""
|
||||||
|
|
||||||
|
def __init__(self, w_in, w_se):
|
||||||
|
super(SE, self).__init__()
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.f_ex = nn.Sequential(
|
||||||
|
nn.Conv2d(w_in, w_se, 1, bias=True),
|
||||||
|
Swish(),
|
||||||
|
nn.Conv2d(w_se, w_in, 1, bias=True),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.f_ex(self.avg_pool(x))
|
||||||
|
|
||||||
|
|
||||||
|
class MBConv(nn.Module):
|
||||||
|
"""Mobile inverted bottleneck block w/ SE (MBConv)."""
|
||||||
|
|
||||||
|
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, bn_norm):
|
||||||
|
# expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection
|
||||||
|
super(MBConv, self).__init__()
|
||||||
|
self.exp = None
|
||||||
|
w_exp = int(w_in * exp_r)
|
||||||
|
if w_exp != w_in:
|
||||||
|
self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
|
||||||
|
self.exp_bn = get_norm(bn_norm, w_exp)
|
||||||
|
self.exp_swish = Swish()
|
||||||
|
dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
|
||||||
|
self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
|
||||||
|
self.dwise_bn = get_norm(bn_norm, w_exp)
|
||||||
|
self.dwise_swish = Swish()
|
||||||
|
self.se = SE(w_exp, int(w_in * se_r))
|
||||||
|
self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
|
||||||
|
self.lin_proj_bn = get_norm(bn_norm, w_out)
|
||||||
|
# Skip connection if in and out shapes are the same (MN-V2 style)
|
||||||
|
self.has_skip = stride == 1 and w_in == w_out
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f_x = x
|
||||||
|
if self.exp:
|
||||||
|
f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
|
||||||
|
f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
|
||||||
|
f_x = self.se(f_x)
|
||||||
|
f_x = self.lin_proj_bn(self.lin_proj(f_x))
|
||||||
|
if self.has_skip:
|
||||||
|
if self.training and effnet_cfg.EN.DC_RATIO > 0.0:
|
||||||
|
f_x = drop_connect(f_x, effnet_cfg.EN.DC_RATIO)
|
||||||
|
f_x = x + f_x
|
||||||
|
return f_x
|
||||||
|
|
||||||
|
|
||||||
|
class EffStage(nn.Module):
|
||||||
|
"""EfficientNet stage."""
|
||||||
|
|
||||||
|
def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d, bn_norm):
|
||||||
|
super(EffStage, self).__init__()
|
||||||
|
for i in range(d):
|
||||||
|
b_stride = stride if i == 0 else 1
|
||||||
|
b_w_in = w_in if i == 0 else w_out
|
||||||
|
name = "b{}".format(i + 1)
|
||||||
|
self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out, bn_norm))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for block in self.children():
|
||||||
|
x = block(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class StemIN(nn.Module):
|
||||||
|
"""EfficientNet stem for ImageNet: 3x3, BN, Swish."""
|
||||||
|
|
||||||
|
def __init__(self, w_in, w_out, bn_norm):
|
||||||
|
super(StemIN, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
|
||||||
|
self.bn = get_norm(bn_norm, w_out)
|
||||||
|
self.swish = Swish()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for layer in self.children():
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class EffNet(nn.Module):
|
||||||
|
"""EfficientNet model."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_args():
|
||||||
|
return {
|
||||||
|
"stem_w": effnet_cfg.EN.STEM_W,
|
||||||
|
"ds": effnet_cfg.EN.DEPTHS,
|
||||||
|
"ws": effnet_cfg.EN.WIDTHS,
|
||||||
|
"exp_rs": effnet_cfg.EN.EXP_RATIOS,
|
||||||
|
"se_r": effnet_cfg.EN.SE_R,
|
||||||
|
"ss": effnet_cfg.EN.STRIDES,
|
||||||
|
"ks": effnet_cfg.EN.KERNELS,
|
||||||
|
"head_w": effnet_cfg.EN.HEAD_W,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, last_stride, bn_norm, **kwargs):
|
||||||
|
super(EffNet, self).__init__()
|
||||||
|
kwargs = self.get_args() if not kwargs else kwargs
|
||||||
|
self._construct(**kwargs, last_stride=last_stride, bn_norm=bn_norm)
|
||||||
|
self.apply(init_weights)
|
||||||
|
|
||||||
|
def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, last_stride, bn_norm):
|
||||||
|
stage_params = list(zip(ds, ws, exp_rs, ss, ks))
|
||||||
|
self.stem = StemIN(3, stem_w, bn_norm)
|
||||||
|
prev_w = stem_w
|
||||||
|
for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
|
||||||
|
name = "s{}".format(i + 1)
|
||||||
|
if i == 5: stride = last_stride
|
||||||
|
self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d, bn_norm))
|
||||||
|
prev_w = w
|
||||||
|
self.head = EffHead(prev_w, head_w, bn_norm)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for module in self.children():
|
||||||
|
x = module(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def init_pretrained_weights(key):
|
||||||
|
"""Initializes model with pretrained weights.
|
||||||
|
|
||||||
|
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import errno
|
||||||
|
import gdown
|
||||||
|
|
||||||
|
def _get_torch_home():
|
||||||
|
ENV_TORCH_HOME = 'TORCH_HOME'
|
||||||
|
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
||||||
|
DEFAULT_CACHE_DIR = '~/.cache'
|
||||||
|
torch_home = os.path.expanduser(
|
||||||
|
os.getenv(
|
||||||
|
ENV_TORCH_HOME,
|
||||||
|
os.path.join(
|
||||||
|
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return torch_home
|
||||||
|
|
||||||
|
torch_home = _get_torch_home()
|
||||||
|
model_dir = os.path.join(torch_home, 'checkpoints')
|
||||||
|
try:
|
||||||
|
os.makedirs(model_dir)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno == errno.EEXIST:
|
||||||
|
# Directory already exists, ignore.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Unexpected OSError, re-raise.
|
||||||
|
raise
|
||||||
|
|
||||||
|
filename = model_urls[key].split('/')[-1]
|
||||||
|
|
||||||
|
cached_file = os.path.join(model_dir, filename)
|
||||||
|
|
||||||
|
if not os.path.exists(cached_file):
|
||||||
|
if comm.is_main_process():
|
||||||
|
gdown.download(model_urls[key], cached_file, quiet=False)
|
||||||
|
|
||||||
|
comm.synchronize()
|
||||||
|
|
||||||
|
logger.info(f"Loading pretrained model from {cached_file}")
|
||||||
|
state_dict = torch.load(cached_file, map_location=torch.device('cpu'))['model_state']
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONE_REGISTRY.register()
|
||||||
|
def build_effnet_backbone(cfg):
|
||||||
|
# fmt: off
|
||||||
|
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
|
||||||
|
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
|
||||||
|
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
||||||
|
bn_norm = cfg.MODEL.BACKBONE.NORM
|
||||||
|
depth = cfg.MODEL.BACKBONE.DEPTH
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
cfg_files = {
|
||||||
|
'b0': 'fastreid/modeling/backbones/regnet/effnet/EN-B0_dds_8gpu.yaml',
|
||||||
|
'b1': 'fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml',
|
||||||
|
'b2': 'fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml',
|
||||||
|
'b3': 'fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml',
|
||||||
|
'b4': 'fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml',
|
||||||
|
'b5': 'fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml',
|
||||||
|
}[depth]
|
||||||
|
|
||||||
|
effnet_cfg.merge_from_file(cfg_files)
|
||||||
|
model = EffNet(last_stride, bn_norm)
|
||||||
|
|
||||||
|
if pretrain:
|
||||||
|
# Load pretrain path if specifically
|
||||||
|
if pretrain_path:
|
||||||
|
try:
|
||||||
|
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
|
||||||
|
logger.info(f"Loading pretrained model from {pretrain_path}")
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.info(f'{pretrain_path} is not found! Please check this path.')
|
||||||
|
raise e
|
||||||
|
except KeyError as e:
|
||||||
|
logger.info("State dict keys error! Please check the state dict.")
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
key = depth
|
||||||
|
state_dict = init_pretrained_weights(key)
|
||||||
|
|
||||||
|
incompatible = model.load_state_dict(state_dict, strict=False)
|
||||||
|
if incompatible.missing_keys:
|
||||||
|
logger.info(
|
||||||
|
get_missing_parameters_message(incompatible.missing_keys)
|
||||||
|
)
|
||||||
|
if incompatible.unexpected_keys:
|
||||||
|
logger.info(
|
||||||
|
get_unexpected_parameters_message(incompatible.unexpected_keys)
|
||||||
|
)
|
||||||
|
return model
|
|
@ -0,0 +1,27 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: effnet
|
||||||
|
NUM_CLASSES: 1000
|
||||||
|
EN:
|
||||||
|
STEM_W: 32
|
||||||
|
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||||
|
DEPTHS: [1, 2, 2, 3, 3, 4, 1]
|
||||||
|
WIDTHS: [16, 24, 40, 80, 112, 192, 320]
|
||||||
|
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||||
|
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||||
|
HEAD_W: 1280
|
||||||
|
OPTIM:
|
||||||
|
LR_POLICY: cos
|
||||||
|
BASE_LR: 0.4
|
||||||
|
MAX_EPOCH: 100
|
||||||
|
MOMENTUM: 0.9
|
||||||
|
WEIGHT_DECAY: 1e-5
|
||||||
|
TRAIN:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 224
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
TEST:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 256
|
||||||
|
BATCH_SIZE: 200
|
||||||
|
NUM_GPUS: 8
|
||||||
|
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: effnet
|
||||||
|
NUM_CLASSES: 1000
|
||||||
|
EN:
|
||||||
|
STEM_W: 32
|
||||||
|
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||||
|
DEPTHS: [2, 3, 3, 4, 4, 5, 2]
|
||||||
|
WIDTHS: [16, 24, 40, 80, 112, 192, 320]
|
||||||
|
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||||
|
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||||
|
HEAD_W: 1280
|
||||||
|
OPTIM:
|
||||||
|
LR_POLICY: cos
|
||||||
|
BASE_LR: 0.4
|
||||||
|
MAX_EPOCH: 100
|
||||||
|
MOMENTUM: 0.9
|
||||||
|
WEIGHT_DECAY: 1e-5
|
||||||
|
TRAIN:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 240
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
TEST:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 274
|
||||||
|
BATCH_SIZE: 200
|
||||||
|
NUM_GPUS: 8
|
||||||
|
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: effnet
|
||||||
|
NUM_CLASSES: 1000
|
||||||
|
EN:
|
||||||
|
STEM_W: 32
|
||||||
|
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||||
|
DEPTHS: [2, 3, 3, 4, 4, 5, 2]
|
||||||
|
WIDTHS: [16, 24, 48, 88, 120, 208, 352]
|
||||||
|
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||||
|
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||||
|
HEAD_W: 1408
|
||||||
|
OPTIM:
|
||||||
|
LR_POLICY: cos
|
||||||
|
BASE_LR: 0.4
|
||||||
|
MAX_EPOCH: 100
|
||||||
|
MOMENTUM: 0.9
|
||||||
|
WEIGHT_DECAY: 1e-5
|
||||||
|
TRAIN:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 260
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
TEST:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 298
|
||||||
|
BATCH_SIZE: 200
|
||||||
|
NUM_GPUS: 8
|
||||||
|
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: effnet
|
||||||
|
NUM_CLASSES: 1000
|
||||||
|
EN:
|
||||||
|
STEM_W: 40
|
||||||
|
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||||
|
DEPTHS: [2, 3, 3, 5, 5, 6, 2]
|
||||||
|
WIDTHS: [24, 32, 48, 96, 136, 232, 384]
|
||||||
|
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||||
|
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||||
|
HEAD_W: 1536
|
||||||
|
OPTIM:
|
||||||
|
LR_POLICY: cos
|
||||||
|
BASE_LR: 0.4
|
||||||
|
MAX_EPOCH: 100
|
||||||
|
MOMENTUM: 0.9
|
||||||
|
WEIGHT_DECAY: 1e-5
|
||||||
|
TRAIN:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 300
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
TEST:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 342
|
||||||
|
BATCH_SIZE: 200
|
||||||
|
NUM_GPUS: 8
|
||||||
|
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: effnet
|
||||||
|
NUM_CLASSES: 1000
|
||||||
|
EN:
|
||||||
|
STEM_W: 48
|
||||||
|
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||||
|
DEPTHS: [2, 4, 4, 6, 6, 8, 2]
|
||||||
|
WIDTHS: [24, 32, 56, 112, 160, 272, 448]
|
||||||
|
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||||
|
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||||
|
HEAD_W: 1792
|
||||||
|
OPTIM:
|
||||||
|
LR_POLICY: cos
|
||||||
|
BASE_LR: 0.2
|
||||||
|
MAX_EPOCH: 100
|
||||||
|
MOMENTUM: 0.9
|
||||||
|
WEIGHT_DECAY: 1e-5
|
||||||
|
TRAIN:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 380
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
TEST:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 434
|
||||||
|
BATCH_SIZE: 104
|
||||||
|
NUM_GPUS: 8
|
||||||
|
OUT_DIR: .
|
|
@ -0,0 +1,27 @@
|
||||||
|
MODEL:
|
||||||
|
TYPE: effnet
|
||||||
|
NUM_CLASSES: 1000
|
||||||
|
EN:
|
||||||
|
STEM_W: 48
|
||||||
|
STRIDES: [1, 2, 2, 2, 1, 2, 1]
|
||||||
|
DEPTHS: [3, 5, 5, 7, 7, 9, 3]
|
||||||
|
WIDTHS: [24, 40, 64, 128, 176, 304, 512]
|
||||||
|
EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
|
||||||
|
KERNELS: [3, 3, 5, 3, 5, 5, 3]
|
||||||
|
HEAD_W: 2048
|
||||||
|
OPTIM:
|
||||||
|
LR_POLICY: cos
|
||||||
|
BASE_LR: 0.1
|
||||||
|
MAX_EPOCH: 100
|
||||||
|
MOMENTUM: 0.9
|
||||||
|
WEIGHT_DECAY: 1e-5
|
||||||
|
TRAIN:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 456
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TEST:
|
||||||
|
DATASET: imagenet
|
||||||
|
IM_SIZE: 522
|
||||||
|
BATCH_SIZE: 48
|
||||||
|
NUM_GPUS: 8
|
||||||
|
OUT_DIR: .
|
|
@ -1,14 +1,15 @@
|
||||||
import torch
|
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
from fastreid.layers import get_norm
|
from fastreid.layers import get_norm
|
||||||
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
|
||||||
from fastreid.utils import comm
|
from fastreid.utils import comm
|
||||||
|
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||||
|
from .config import cfg as regnet_cfg
|
||||||
from ..build import BACKBONE_REGISTRY
|
from ..build import BACKBONE_REGISTRY
|
||||||
from .config import regnet_cfg
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
model_urls = {
|
model_urls = {
|
||||||
|
@ -24,6 +25,7 @@ model_urls = {
|
||||||
'6400y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160907112/RegNetY-6.4GF_dds_8gpu.pyth',
|
'6400y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160907112/RegNetY-6.4GF_dds_8gpu.pyth',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def init_weights(m):
|
def init_weights(m):
|
||||||
"""Performs ResNet-style weight initialization."""
|
"""Performs ResNet-style weight initialization."""
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
|
@ -67,6 +69,15 @@ def get_block_fun(block_type):
|
||||||
return block_funs[block_type]
|
return block_funs[block_type]
|
||||||
|
|
||||||
|
|
||||||
|
def drop_connect(x, drop_ratio):
|
||||||
|
"""Drop connect (adapted from DARTS)."""
|
||||||
|
keep_ratio = 1.0 - drop_ratio
|
||||||
|
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
|
||||||
|
mask.bernoulli_(keep_ratio)
|
||||||
|
x.div_(keep_ratio)
|
||||||
|
x.mul_(mask)
|
||||||
|
return x
|
||||||
|
|
||||||
class AnyHead(nn.Module):
|
class AnyHead(nn.Module):
|
||||||
"""AnyNet head."""
|
"""AnyNet head."""
|
||||||
|
|
||||||
|
@ -117,7 +128,7 @@ class BasicTransform(nn.Module):
|
||||||
super(BasicTransform, self).__init__()
|
super(BasicTransform, self).__init__()
|
||||||
self.construct(w_in, w_out, stride, bn_norm)
|
self.construct(w_in, w_out, stride, bn_norm)
|
||||||
|
|
||||||
def construct(self, w_in, w_out, stride, bn_norm, num_split):
|
def construct(self, w_in, w_out, stride, bn_norm):
|
||||||
# 3x3, BN, ReLU
|
# 3x3, BN, ReLU
|
||||||
self.a = nn.Conv2d(
|
self.a = nn.Conv2d(
|
||||||
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
|
w_in, w_out, kernel_size=3, stride=stride, padding=1, bias=False
|
||||||
|
@ -269,7 +280,7 @@ class ResStemCifar(nn.Module):
|
||||||
self.conv = nn.Conv2d(
|
self.conv = nn.Conv2d(
|
||||||
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
|
w_in, w_out, kernel_size=3, stride=1, padding=1, bias=False
|
||||||
)
|
)
|
||||||
self.bn = get_norm(bn_norm, w_out, 1)
|
self.bn = get_norm(bn_norm, w_out)
|
||||||
self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
|
self.relu = nn.ReLU(regnet_cfg.MEM.RELU_INPLACE)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -535,6 +546,7 @@ def build_regnet_backbone(cfg):
|
||||||
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
||||||
bn_norm = cfg.MODEL.BACKBONE.NORM
|
bn_norm = cfg.MODEL.BACKBONE.NORM
|
||||||
depth = cfg.MODEL.BACKBONE.DEPTH
|
depth = cfg.MODEL.BACKBONE.DEPTH
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
cfg_files = {
|
cfg_files = {
|
||||||
'800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
|
'800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
|
||||||
|
@ -549,7 +561,7 @@ def build_regnet_backbone(cfg):
|
||||||
'6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
|
'6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
|
||||||
}[depth]
|
}[depth]
|
||||||
|
|
||||||
regnet_cfg.merge_from_file(cfg_files)
|
cfg.merge_from_file(cfg_files)
|
||||||
model = RegNet(last_stride, bn_norm)
|
model = RegNet(last_stride, bn_norm)
|
||||||
|
|
||||||
if pretrain:
|
if pretrain:
|
||||||
|
|
Loading…
Reference in New Issue