mirror of https://github.com/JDAI-CV/fast-reid.git
Support vision transformer backbone
parent 2b65882447
commit 2cabc3428a
@@ -0,0 +1,88 @@
MODEL:
  META_ARCHITECTURE: Baseline
  PIXEL_MEAN: [127.5, 127.5, 127.5]
  PIXEL_STD: [127.5, 127.5, 127.5]

  BACKBONE:
    NAME: build_vit_backbone
    DEPTH: base
    FEAT_DIM: 768
    PRETRAIN: True
    PRETRAIN_PATH: /export/home/lxy/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth
    STRIDE_SIZE: (16, 16)
    DROP_PATH_RATIO: 0.1
    DROP_RATIO: 0.0
    ATT_DROP_RATE: 0.0

  HEADS:
    NAME: EmbeddingHead
    NORM: BN
    WITH_BNNECK: True
    POOL_LAYER: Identity
    NECK_FEAT: before
    CLS_LAYER: Linear

  LOSSES:
    NAME: ("CrossEntropyLoss", "TripletLoss",)

    CE:
      EPSILON: 0. # no smooth
      SCALE: 1.

    TRI:
      MARGIN: 0.0
      HARD_MINING: True
      NORM_FEAT: False
      SCALE: 1.

INPUT:
  SIZE_TRAIN: [ 256, 128 ]
  SIZE_TEST: [ 256, 128 ]

  REA:
    ENABLED: True
    PROB: 0.5

  FLIP:
    ENABLED: True

  PADDING:
    ENABLED: True

DATALOADER:
  SAMPLER_TRAIN: NaiveIdentitySampler
  NUM_INSTANCE: 4
  NUM_WORKERS: 8

SOLVER:
  AMP:
    ENABLED: False
  OPT: SGD
  MAX_EPOCH: 120
  BASE_LR: 0.008
  WEIGHT_DECAY: 0.0001
  IMS_PER_BATCH: 64

  SCHED: CosineAnnealingLR
  ETA_MIN_LR: 0.000016

  WARMUP_FACTOR: 0.01
  WARMUP_ITERS: 1000

  CLIP_GRADIENTS:
    ENABLED: True

  CHECKPOINT_PERIOD: 30

TEST:
  EVAL_PERIOD: 5
  IMS_PER_BATCH: 128

CUDNN_BENCHMARK: True

DATASETS:
  NAMES: ("Market1501",)
  TESTS: ("Market1501",)

OUTPUT_DIR: logs/market1501/sbs_vit_base

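For context, a minimal usage sketch (not part of the commit) of how a config like the one above is consumed. It assumes fastreid's standard get_cfg() / build_backbone() helpers and a hypothetical path for this YAML file:

import torch
from fastreid.config import get_cfg
from fastreid.modeling.backbones import build_backbone

cfg = get_cfg()
cfg.merge_from_file("configs/Market1501/bagtricks_vit.yml")  # hypothetical path to the YAML above
cfg.MODEL.BACKBONE.PRETRAIN = False  # skip the ImageNet checkpoint for a quick smoke test

backbone = build_backbone(cfg)                 # dispatches to build_vit_backbone via the registry
feat = backbone(torch.randn(2, 3, 256, 128))   # SIZE_TRAIN = [256, 128]
print(feat.shape)                              # expected torch.Size([2, 768, 1, 1]): cls token as NCHW
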
@@ -23,7 +23,7 @@ _C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "Baseline"

_C.MODEL.FREEZE_LAYERS = ['']
_C.MODEL.FREEZE_LAYERS = []

# MoCo memory size
_C.MODEL.QUEUE_SIZE = 8192

@@ -46,6 +46,12 @@ _C.MODEL.BACKBONE.WITH_IBN = False
_C.MODEL.BACKBONE.WITH_SE = False
# If use Non-local block in backbone
_C.MODEL.BACKBONE.WITH_NL = False
# Vision Transformer options
_C.MODEL.BACKBONE.SIE_COE = 3.0
_C.MODEL.BACKBONE.STRIDE_SIZE = (16, 16)
_C.MODEL.BACKBONE.DROP_PATH_RATIO = 0.1
_C.MODEL.BACKBONE.DROP_RATIO = 0.0
_C.MODEL.BACKBONE.ATT_DROP_RATE = 0.0
# If use ImageNet pretrain model
_C.MODEL.BACKBONE.PRETRAIN = False
# Pretrain model path

@@ -128,8 +134,10 @@ _C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
# -----------------------------------------------------------------------------

_C.KD = CN()
_C.KD.MODEL_CONFIG = ['',]
_C.KD.MODEL_WEIGHTS = ['',]
_C.KD.MODEL_CONFIG = []
_C.KD.MODEL_WEIGHTS = []
_C.KD.EMA = CN({"ENABLED": False})
_C.KD.EMA.MOMENTUM = 0.999

# -----------------------------------------------------------------------------
# INPUT

@@ -223,14 +231,25 @@ _C.SOLVER.OPT = "Adam"
_C.SOLVER.MAX_EPOCH = 120

_C.SOLVER.BASE_LR = 3e-4
_C.SOLVER.BIAS_LR_FACTOR = 1.

# This LR is applied to the last classification layer if
# you want to 10x higher than BASE_LR.
_C.SOLVER.HEADS_LR_FACTOR = 1.

_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.NESTEROV = False

_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
# The weight decay that's applied to parameters of normalization layers
# (typically the affine transformation)
_C.SOLVER.WEIGHT_DECAY_NORM = 0.0

# The previous detection code used a 2x higher LR and 0 WD for bias.
# This is not useful (at least for recent models). You should avoid
# changing these and they exists only to reproduce previous model
# training if desired.
_C.SOLVER.BIAS_LR_FACTOR = 1.0
_C.SOLVER.WEIGHT_DECAY_BIAS = _C.SOLVER.WEIGHT_DECAY

# Multi-step learning rate options
_C.SOLVER.SCHED = "MultiStepLR"

@@ -251,33 +270,31 @@ _C.SOLVER.WARMUP_METHOD = "linear"
# Backbone freeze iters
_C.SOLVER.FREEZE_ITERS = 0

# FC freeze iters
_C.SOLVER.FREEZE_FC_ITERS = 0


# SWA options
# _C.SOLVER.SWA = CN()
# _C.SOLVER.SWA.ENABLED = False
# _C.SOLVER.SWA.ITER = 10
# _C.SOLVER.SWA.PERIOD = 2
# _C.SOLVER.SWA.LR_FACTOR = 10.
# _C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
# _C.SOLVER.SWA.LR_SCHED = False

_C.SOLVER.CHECKPOINT_PERIOD = 20

# Number of images per batch across all machines.
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 256, each GPU will
# see 32 images per batch
_C.SOLVER.IMS_PER_BATCH = 64

# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
# Gradient clipping
_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
# Type of gradient clipping, currently 2 values are supported:
# - "value": the absolute values of elements of each gradients are clipped
# - "norm": the norm of the gradient for each parameter is clipped thus
#   affecting all elements in the parameter
_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"
# Maximum absolute value used for clipping gradients
_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 5.0
# Floating point number p for L-p norm to be used with the "norm"
# gradient clipping type; for L-inf, please specify .inf
_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0

_C.TEST = CN()

_C.TEST.EVAL_PERIOD = 20

# Number of images per batch in one process.
# Number of images per batch across all machines.
_C.TEST.IMS_PER_BATCH = 64
_C.TEST.METRIC = "cosine"
_C.TEST.ROC = CN({"ENABLED": False})

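The ViT options added above are plain yacs entries, so they can be overridden like any other key. A small sketch (assumed usage, not part of the commit):

from fastreid.config import get_cfg

cfg = get_cfg()
cfg.merge_from_list([
    "MODEL.BACKBONE.NAME", "build_vit_backbone",
    "MODEL.BACKBONE.DEPTH", "base",
    "MODEL.BACKBONE.DROP_PATH_RATIO", 0.2,  # stronger stochastic depth than the 0.1 default
    "MODEL.BACKBONE.SIE_COE", 1.5,
])
print(cfg.MODEL.BACKBONE.DROP_PATH_RATIO)   # 0.2
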
@@ -0,0 +1,161 @@
""" DropBlock, DropPath
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
Code:
DropBlock impl inspired by two Tensorflow impl that I liked:
 - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
 - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_block_2d(
        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    # Forces the block to be inside the feature map.
    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


def drop_block_fast_2d(
        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
    block mask at edges.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    if batchwise:
        # one mask for whole batch, quite a bit faster
        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
    else:
        # mask per batch element
        block_mask = torch.rand_like(x) < gamma
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        block_mask = 1 - block_mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


class DropBlock2d(nn.Module):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    """

    def __init__(self,
                 drop_prob=0.1,
                 block_size=7,
                 gamma_scale=1.0,
                 with_noise=False,
                 inplace=False,
                 batchwise=False,
                 fast=True):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        if self.fast:
            return drop_block_fast_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
        else:
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

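A quick illustration of the drop_path semantics above (a sketch, assuming DropPath is exposed as fastreid.layers.DropPath, which is how the ViT backbone later in this commit imports it):

import torch
from fastreid.layers import DropPath

torch.manual_seed(0)
layer = DropPath(drop_prob=0.25).train()
x = torch.ones(8, 4, 768)               # a batch of 8 token sequences

out = layer(x)
# Each sample is either zeroed entirely or rescaled by 1 / keep_prob = 1 / 0.75,
# so the expected value of the residual branch is preserved.
print(out.flatten(1).mean(dim=1))       # entries are either 0.0000 or 1.3333
print(layer.eval()(x).equal(x))         # True: identity at inference time
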
@@ -0,0 +1,31 @@
""" Layer/Module Helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import collections.abc
from itertools import repeat


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


def make_divisible(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

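A tiny usage note (assumed, not part of the commit): _ntuple normalizes scalar-or-iterable arguments, which is how the patch embedding below can accept both img_size=224 and img_size=[256, 128], while make_divisible rounds channel counts.

# Assuming the helpers above are in scope (to_2tuple is also re-exported from fastreid.layers):
print(to_2tuple(16))            # (16, 16)
print(to_2tuple((256, 128)))    # (256, 128) -- iterables pass through unchanged
print(make_divisible(100))      # 104: rounded to a multiple of 8, never below 90% of the input
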
@@ -0,0 +1,122 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import math
import warnings

import torch
from torch import nn


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)


from torch.nn.init import _calculate_fan_in_and_fan_out


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')

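A small behavioral check (a sketch, assuming trunc_normal_ is importable from fastreid.layers, as the ViT backbone below imports it):

import torch
from fastreid.layers import trunc_normal_

pos_embed = torch.empty(1, 129, 768)
trunc_normal_(pos_embed, std=.02)            # same call the ViT uses for cls_token / pos_embed

print(pos_embed.abs().max().item() <= 2.0)   # True: samples are clamped to the default [a, b] = [-2, 2]
print(round(pos_embed.std().item(), 3))      # ~0.02
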
@@ -14,3 +14,4 @@ from .regnet import build_regnet_backbone, build_effnet_backbone
from .shufflenet import build_shufflenetv2_backbone
from .mobilenet import build_mobilenetv2_backbone
from .repvgg import build_repvgg_backbone
from .vision_transformer import build_vit_backbone

@@ -0,0 +1,399 @@
""" Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
  for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
"""

import logging
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from fastreid.layers import DropPath, trunc_normal_, to_2tuple
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class PatchEmbed_overlap(nn.Module):
    """ Image to Patch Embedding with overlapping patches
    """

    def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride_size_tuple = to_2tuple(stride_size)
        self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
        self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
        num_patches = self.num_x * self.num_y
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        B, C, H, W = x.shape

        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)

        x = x.flatten(2).transpose(1, 2)  # [64, 8, 768]
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., camera=0, drop_path_rate=0., hybrid_backbone=None,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), sie_xishu=1.0):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed_overlap(
                img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans,
                embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cam_num = camera
        self.sie_xishu = sie_xishu
        # Initialize SIE Embedding
        if camera > 1:
            self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim))
            trunc_normal_(self.sie_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x, camera_id=None):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        if self.cam_num > 0:
            x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id]
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x[:, 0].reshape(x.shape[0], -1, 1, 1)


def resize_pos_embed(posemb, posemb_new, hight, width):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    ntok_new = posemb_new.shape[1]

    posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
    ntok_new -= 1

    gs_old = int(math.sqrt(len(posemb_grid)))
    logger.info('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(
        posemb.shape, posemb_new.shape, hight, width))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
    posemb = torch.cat([posemb_token, posemb_grid], dim=1)
    return posemb


@BACKBONE_REGISTRY.register()
def build_vit_backbone(cfg):
    """
    Create a Vision Transformer instance from config.
    Returns:
        VisionTransformer: a :class:`VisionTransformer` instance.
    """
    # fmt: off
    input_size = cfg.INPUT.SIZE_TRAIN
    pretrain = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    depth = cfg.MODEL.BACKBONE.DEPTH
    sie_xishu = cfg.MODEL.BACKBONE.SIE_COE
    stride_size = cfg.MODEL.BACKBONE.STRIDE_SIZE
    drop_ratio = cfg.MODEL.BACKBONE.DROP_RATIO
    drop_path_ratio = cfg.MODEL.BACKBONE.DROP_PATH_RATIO
    attn_drop_rate = cfg.MODEL.BACKBONE.ATT_DROP_RATE
    # fmt: on

    num_depth = {
        'small': 8,
        'base': 12,
    }[depth]

    num_heads = {
        'small': 8,
        'base': 12,
    }[depth]

    mlp_ratio = {
        'small': 3.,
        'base': 4.
    }[depth]

    qkv_bias = {
        'small': False,
        'base': True
    }[depth]

    qk_scale = {
        'small': 768 ** -0.5,
        'base': None,
    }[depth]

    model = VisionTransformer(img_size=input_size, sie_xishu=sie_xishu, stride_size=stride_size, depth=num_depth,
                              num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              drop_path_rate=drop_path_ratio, drop_rate=drop_ratio, attn_drop_rate=attn_drop_rate)

    if pretrain:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
            logger.info(f"Loading pretrained model from {pretrain_path}")

            if 'model' in state_dict:
                state_dict = state_dict.pop('model')
            if 'state_dict' in state_dict:
                state_dict = state_dict.pop('state_dict')
            for k, v in state_dict.items():
                if 'head' in k or 'dist' in k:
                    continue
                if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
                    # For old models that I trained prior to conv based patchification
                    O, I, H, W = model.patch_embed.proj.weight.shape
                    v = v.reshape(O, -1, H, W)
                elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
                    # To resize pos embedding when using model at different size from pretrained weights
                    if 'distilled' in pretrain_path:
                        logger.info("distill need to choose right cls token in the pth.")
                        v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1)
                    v = resize_pos_embed(v, model.pos_embed.data, model.patch_embed.num_y, model.patch_embed.num_x)
                state_dict[k] = v
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e
        except KeyError as e:
            logger.info("State dict keys error! Please check the state dict.")
            raise e

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model

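A worked shape check for the overlapping patch embedding and the backbone output (a standalone sketch, not part of the commit, assuming the VisionTransformer class above is importable): with SIZE_TRAIN = [256, 128], patch_size = 16 and STRIDE_SIZE = (16, 16), PatchEmbed_overlap yields num_y = (256 - 16) // 16 + 1 = 16 and num_x = (128 - 16) // 16 + 1 = 8, i.e. 128 patches plus one cls token; a smaller stride such as 12 makes patches overlap (num_y = 21, num_x = 10), and resize_pos_embed rescales pretrained position embeddings to the new grid.

import torch

vit = VisionTransformer(img_size=[256, 128], patch_size=16, stride_size=16,
                        depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_path_rate=0.1)
print(vit.pos_embed.shape)               # torch.Size([1, 129, 768]) -> 16 * 8 patches + cls token
feat = vit(torch.randn(2, 3, 256, 128))
print(feat.shape)                        # torch.Size([2, 768, 1, 1]): cls token returned as an NCHW feature
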
@@ -1,36 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

from torch import nn

__all__ = [
    'weights_init_classifier',
    'weights_init_kaiming',
]


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)