Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge pull request #2092 from huggingface/mesa_ema
ModelEMAV3 + MESA experiments

This commit is contained in: commit 1b50b15145
@@ -32,16 +32,12 @@ class ToNumpy:


class ToTensor:
    """ ToTensor with no rescaling of values"""

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        np_img = np.array(pil_img, dtype=np.uint8)
        if np_img.ndim < 3:
            np_img = np.expand_dims(np_img, axis=-1)
        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
        return torch.from_numpy(np_img).to(dtype=self.dtype)
        return F.pil_to_tensor(pil_img).to(dtype=self.dtype)


# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
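The intent of the change above is that `F.pil_to_tensor` yields the same CHW uint8 layout the old NumPy path produced, just without the intermediate conversion. A quick sanity check, assuming a torchvision new enough to provide `pil_to_tensor` (0.9+):

```python
import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as F

pil_img = Image.fromarray(np.random.randint(0, 255, (8, 8, 3), dtype=np.uint8))

np_img = np.array(pil_img, dtype=np.uint8)               # old path: HWC uint8 -> CHW tensor
old = torch.from_numpy(np.rollaxis(np_img, 2)).to(dtype=torch.float32)
new = F.pil_to_tensor(pil_img).to(dtype=torch.float32)   # new path

print(torch.equal(old, new))  # True; neither path rescales values to [0, 1]
```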
@@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module):
        self.drop = nn.Dropout(drop_rate)
        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def reset(self, num_classes, global_pool=None):
        if global_pool is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
    def reset(self, num_classes, pool_type=None):
        if pool_type is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        self.use_conv = self.global_pool.is_identity()
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
        if self.hidden_size:
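The `global_pool` -> `pool_type` rename above is what `reset_classifier()` forwards into on models using these heads. A minimal sketch of the user-facing call (the model name is just an illustrative choice):

```python
import timm
import torch

model = timm.create_model('davit_tiny', pretrained=False)
model.reset_classifier(num_classes=0, global_pool='avg')  # routed to the head's reset()
with torch.no_grad():
    pooled = model(torch.randn(1, 3, 224, 224))
print(pooled.shape)  # (1, model.num_features)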
@@ -148,7 +148,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
    return _ACT_LAYER_DEFAULT[name]


def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
    act_layer = get_act_layer(name)
    if act_layer is None:
        return None
@@ -39,6 +39,7 @@ from .mobilevit import *
from .mvitv2 import *
from .nasnet import *
from .nest import *
from .nextvit import *
from .nfnet import *
from .pit import *
from .pnasnet import *
@@ -547,6 +547,17 @@ class DaVit(nn.Module):
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+).downsample', (0,)),
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,)),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
@@ -558,7 +569,7 @@ class DaVit(nn.Module):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        self.head.reset(num_classes, global_pool=global_pool)
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        x = self.stem(x)
timm/models/nextvit.py (new file, 685 lines)
@@ -0,0 +1,685 @@
""" Next-ViT

As described in https://arxiv.org/abs/2207.05501

Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-ViT, original copyright below
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
from timm.layers import ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model


def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
    """ Merge pre BN to reduce inference runtime.
    """
    weight = module.weight.data
    if module.bias is None:
        zeros = torch.zeros(module.out_chs, device=weight.device).type(weight.type())
        module.bias = nn.Parameter(zeros)
    bias = module.bias.data
    if pre_bn_2 is None:
        assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"

        scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        extra_weight = scale_invstd * pre_bn_1.weight
        extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
    else:
        assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"

        assert pre_bn_2.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
        assert pre_bn_2.affine is True, "Unsupported bn_module.affine is False"

        scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)

        extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
        extra_bias = (
            scale_invstd_2 * pre_bn_2.weight
            * (pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean)
            + pre_bn_2.bias
        )

    if isinstance(module, nn.Linear):
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
    elif isinstance(module, nn.Conv2d):
        assert weight.shape[2] == 1 and weight.shape[3] == 1
        weight = weight.reshape(weight.shape[0], weight.shape[1])
        extra_bias = weight @ extra_bias
        weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
        weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
    bias.add_(extra_bias)

    module.weight.data = weight
    module.bias.data = bias
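For intuition, the algebra `merge_pre_bn` relies on is standard BN folding: in eval mode a BatchNorm is an affine map `x * s + t`, which can be absorbed into the following linear layer. A self-contained sketch of that identity (it does not call `merge_pre_bn` itself, since that function expects modules carrying an `out_chs` attribute as in this file):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(8).eval()
bn.running_mean.uniform_(-1, 1)       # give the BN non-trivial statistics / affine params
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)
fc = nn.Linear(8, 4)

x = torch.randn(2, 8)
ref = fc(bn(x))

# BN(x) = x * s + t in eval mode; fold that affine map into the linear layer
with torch.no_grad():
    s = bn.weight * (bn.running_var + bn.eps).rsqrt()
    t = bn.bias - bn.running_mean * s
    fused = nn.Linear(8, 4)
    fused.weight.copy_(fc.weight * s.unsqueeze(0))
    fused.bias.copy_(fc.bias + fc.weight @ t)

print(torch.allclose(fused(x), ref, atol=1e-5))  # True
```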
class ConvNormAct(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=3,
            stride=1,
            groups=1,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.ReLU,
    ):
        super(ConvNormAct, self).__init__()
        self.conv = nn.Conv2d(
            in_chs, out_chs, kernel_size=kernel_size, stride=stride,
            padding=1, groups=groups, bias=False)
        self.norm = norm_layer(out_chs)
        self.act = act_layer()

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class PatchEmbed(nn.Module):
    def __init__(self,
            in_chs,
            out_chs,
            stride=1,
            norm_layer = nn.BatchNorm2d,
    ):
        super(PatchEmbed, self).__init__()

        if stride == 2:
            self.pool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_chs)
        elif in_chs != out_chs:
            self.pool = nn.Identity()
            self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_chs)
        else:
            self.pool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.pool(x)))


class ConvAttention(nn.Module):
    """
    Multi-Head Convolutional Attention
    """

    def __init__(self, out_chs, head_dim, norm_layer = nn.BatchNorm2d, act_layer = nn.ReLU):
        super(ConvAttention, self).__init__()
        self.group_conv3x3 = nn.Conv2d(
            out_chs, out_chs,
            kernel_size=3, stride=1, padding=1, groups=out_chs // head_dim, bias=False
        )
        self.norm = norm_layer(out_chs)
        self.act = act_layer()
        self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.group_conv3x3(x)
        out = self.norm(out)
        out = self.act(out)
        out = self.projection(out)
        return out


class NextConvBlock(nn.Module):
    """
    Next Convolution Block
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            stride=1,
            drop_path=0.,
            drop=0.,
            head_dim=32,
            mlp_ratio=3.,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.ReLU
    ):
        super(NextConvBlock, self).__init__()
        self.in_chs = in_chs
        self.out_chs = out_chs
        assert out_chs % head_dim == 0

        self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer)
        self.mhca = ConvAttention(
            out_chs,
            head_dim,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.attn_drop_path = DropPath(drop_path)

        self.norm = norm_layer(out_chs)
        self.mlp = ConvMlp(
            out_chs,
            hidden_features=int(out_chs * mlp_ratio),
            drop=drop,
            bias=True,
            act_layer=act_layer,
        )
        self.mlp_drop_path = DropPath(drop_path)
        self.is_fused = False

    @torch.no_grad()
    def reparameterize(self):
        if not self.is_fused:
            merge_pre_bn(self.mlp.fc1, self.norm)
            self.norm = None
            self.is_fused = True

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.attn_drop_path(self.mhca(x))

        out = self.norm(x)
        x = x + self.mlp_drop_path(self.mlp(out))
        return x


class EfficientAttention(nn.Module):
    """
    Efficient Multi-Head Self Attention
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim,
            out_dim=None,
            head_dim=32,
            qkv_bias=True,
            attn_drop=0.,
            proj_drop=0.,
            sr_ratio=1,
            norm_layer=nn.BatchNorm1d,
    ):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = norm_layer(dim)
        else:
            self.sr = None
            self.norm = None

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        if self.sr is not None:
            x = self.sr(x.transpose(1, 2))
            x = self.norm(x).transpose(1, 2)

        k = self.k(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-1, -2)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
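The `sr_ratio` path above shrinks only the key/value sequence (average-pooling `sr_ratio**2` tokens into one) while queries keep full resolution, so the output length is unchanged. A shape check, assuming this file is importable as `timm.models.nextvit` once merged:

```python
import torch
from timm.models.nextvit import EfficientAttention  # assumed import path for this new file

attn = EfficientAttention(dim=384, head_dim=32, sr_ratio=2).eval()
x = torch.randn(2, 56 * 56, 384)  # (B, N, C) tokens from a 56x56 feature map
with torch.no_grad():
    y = attn(x)
print(y.shape)  # torch.Size([2, 3136, 384]); K/V internally pooled to 3136 / 4 = 784 tokens
```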
class NextTransformerBlock(nn.Module):
    """
    Next Transformer Block
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            drop_path,
            stride=1,
            sr_ratio=1,
            mlp_ratio=2,
            head_dim=32,
            mix_block_ratio=0.75,
            attn_drop=0.,
            drop=0.,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.ReLU,
    ):
        super(NextTransformerBlock, self).__init__()
        self.in_chs = in_chs
        self.out_chs = out_chs
        self.mix_block_ratio = mix_block_ratio

        self.mhsa_out_chs = _make_divisible(int(out_chs * mix_block_ratio), 32)
        self.mhca_out_chs = out_chs - self.mhsa_out_chs

        self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride)
        self.norm1 = norm_layer(self.mhsa_out_chs)
        self.e_mhsa = EfficientAttention(
            self.mhsa_out_chs,
            head_dim=head_dim,
            sr_ratio=sr_ratio,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.mhsa_drop_path = DropPath(drop_path * mix_block_ratio)

        self.projection = PatchEmbed(self.mhsa_out_chs, self.mhca_out_chs, stride=1, norm_layer=norm_layer)
        self.mhca = ConvAttention(
            self.mhca_out_chs,
            head_dim=head_dim,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.mhca_drop_path = DropPath(drop_path * (1 - mix_block_ratio))

        self.norm2 = norm_layer(out_chs)
        self.mlp = ConvMlp(
            out_chs,
            hidden_features=int(out_chs * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp_drop_path = DropPath(drop_path)
        self.is_fused = False

    @torch.no_grad()
    def reparameterize(self):
        if not self.is_fused:
            merge_pre_bn(self.e_mhsa.q, self.norm1)
            if self.e_mhsa.norm is not None:
                merge_pre_bn(self.e_mhsa.k, self.norm1, self.e_mhsa.norm)
                merge_pre_bn(self.e_mhsa.v, self.norm1, self.e_mhsa.norm)
                self.e_mhsa.norm = nn.Identity()
            else:
                merge_pre_bn(self.e_mhsa.k, self.norm1)
                merge_pre_bn(self.e_mhsa.v, self.norm1)
            self.norm1 = nn.Identity()

            merge_pre_bn(self.mlp.fc1, self.norm2)
            self.norm2 = nn.Identity()
            self.is_fused = True

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape

        out = self.norm1(x)
        out = out.reshape(B, C, -1).transpose(-1, -2)
        out = self.mhsa_drop_path(self.e_mhsa(out))
        x = x + out.transpose(-1, -2).reshape(B, C, H, W)

        out = self.projection(x)
        out = out + self.mhca_drop_path(self.mhca(out))
        x = torch.cat([x, out], dim=1)

        out = self.norm2(x)
        x = x + self.mlp_drop_path(self.mlp(out))
        return x


class NextStage(nn.Module):

    def __init__(
            self,
            in_chs,
            block_chs,
            block_types,
            stride=2,
            sr_ratio=1,
            mix_block_ratio=1.0,
            drop=0.,
            attn_drop=0.,
            drop_path=0.,
            head_dim=32,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.ReLU,
    ):
        super().__init__()
        self.grad_checkpointing = False

        blocks = []
        for block_idx, block_type in enumerate(block_types):
            stride = stride if block_idx == 0 else 1
            out_chs = block_chs[block_idx]
            block_type = block_types[block_idx]
            dpr = drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path
            if block_type is NextConvBlock:
                layer = NextConvBlock(
                    in_chs,
                    out_chs,
                    stride=stride,
                    drop_path=dpr,
                    drop=drop,
                    head_dim=head_dim,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                blocks.append(layer)
            elif block_type is NextTransformerBlock:
                layer = NextTransformerBlock(
                    in_chs,
                    out_chs,
                    drop_path=dpr,
                    stride=stride,
                    sr_ratio=sr_ratio,
                    head_dim=head_dim,
                    mix_block_ratio=mix_block_ratio,
                    attn_drop=attn_drop,
                    drop=drop,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                blocks.append(layer)
            in_chs = out_chs

        self.blocks = nn.Sequential(*blocks)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
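In `NextTransformerBlock` above, the output channels are split between the E-MHSA and MHCA paths with `_make_divisible(int(out_chs * mix_block_ratio), 32)`. With the default `mix_block_ratio=0.75` and the transformer-block widths used by the NextViT stages below, the split works out as in this small standalone calculation (which mirrors `_make_divisible` rather than importing it):

```python
def make_divisible(v, divisor=32, min_value=None):
    # standalone mirror of _make_divisible above
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

for out_chs in (256, 512, 1024):
    mhsa_chs = make_divisible(int(out_chs * 0.75))
    print(out_chs, '->', mhsa_chs, '(E-MHSA) +', out_chs - mhsa_chs, '(MHCA)')
# 256 -> 192 + 64, 512 -> 384 + 128, 1024 -> 768 + 256
```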
class NextViT(nn.Module):
    def __init__(
            self,
            in_chans,
            num_classes=1000,
            global_pool='avg',
            stem_chs=(64, 32, 64),
            depths=(3, 4, 10, 3),
            strides=(1, 2, 2, 2),
            sr_ratios=(8, 4, 2, 1),
            drop_path_rate=0.1,
            attn_drop_rate=0.,
            drop_rate=0.,
            head_dim=32,
            mix_block_ratio=0.75,
            norm_layer=nn.BatchNorm2d,
            act_layer=None,
    ):
        super(NextViT, self).__init__()
        self.grad_checkpointing = False
        self.num_classes = num_classes
        norm_layer = get_norm_layer(norm_layer)
        if act_layer is None:
            act_layer = partial(nn.ReLU, inplace=True)
        else:
            act_layer = get_act_layer(act_layer)

        self.stage_out_chs = [
            [96] * (depths[0]),
            [192] * (depths[1] - 1) + [256],
            [384, 384, 384, 384, 512] * (depths[2] // 5),
            [768] * (depths[3] - 1) + [1024]
        ]
        self.feature_info = [dict(
            num_chs=sc[-1],
            reduction=2**(i + 2),
            module=f'stages.{i}'
        ) for i, sc in enumerate(self.stage_out_chs)]

        # Next Hybrid Strategy
        self.stage_block_types = [
            [NextConvBlock] * depths[0],
            [NextConvBlock] * (depths[1] - 1) + [NextTransformerBlock],
            [NextConvBlock, NextConvBlock, NextConvBlock, NextConvBlock, NextTransformerBlock] * (depths[2] // 5),
            [NextConvBlock] * (depths[3] - 1) + [NextTransformerBlock]]

        self.stem = nn.Sequential(
            ConvNormAct(in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer),
            ConvNormAct(stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer),
            ConvNormAct(stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer),
            ConvNormAct(stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer),
        )
        in_chs = out_chs = stem_chs[-1]
        stages = []
        idx = 0
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        for stage_idx in range(len(depths)):
            stage = NextStage(
                in_chs=in_chs,
                block_chs=self.stage_out_chs[stage_idx],
                block_types=self.stage_block_types[stage_idx],
                stride=strides[stage_idx],
                sr_ratio=sr_ratios[stage_idx],
                mix_block_ratio=mix_block_ratio,
                head_dim=head_dim,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[stage_idx],
                norm_layer=norm_layer,
                act_layer=act_layer,
            )
            in_chs = out_chs = self.stage_out_chs[stage_idx][-1]
            stages += [stage]
            idx += depths[stage_idx]
        self.num_features = out_chs
        self.stages = nn.Sequential(*stages)
        self.norm = norm_layer(out_chs)
        self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes)

        self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
        self._initialize_weights()

    def _initialize_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm', (99999,)),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm """
    if 'head.fc.weight' in state_dict:
        return state_dict  # non-original

    D = model.state_dict()
    out_dict = {}
    # remap originals based on order
    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
        out_dict[ka] = vb

    return out_dict


def _create_nextvit(variant, pretrained=False, **kwargs):
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)

    model = build_model_with_cfg(
        NextViT,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)

    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'nextvit_small.bd_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'nextvit_base.bd_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'nextvit_large.bd_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'nextvit_small.bd_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    'nextvit_base.bd_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    'nextvit_large.bd_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),

    'nextvit_small.bd_ssld_6m_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'nextvit_base.bd_ssld_6m_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'nextvit_large.bd_ssld_6m_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'nextvit_small.bd_ssld_6m_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    'nextvit_base.bd_ssld_6m_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    'nextvit_large.bd_ssld_6m_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
})


@register_model
def nextvit_small(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 4, 10, 3), drop_path_rate=0.1)
    model = _create_nextvit(
        'nextvit_small', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def nextvit_base(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 4, 20, 3), drop_path_rate=0.2)
    model = _create_nextvit(
        'nextvit_base', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def nextvit_large(pretrained=False, **kwargs):
    model_args = dict(depths=(3, 4, 30, 3), drop_path_rate=0.2)
    model = _create_nextvit(
        'nextvit_large', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
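Assuming the registrations above are present in the installed timm build, the new models are created like any other timm model; `in_chans` is passed explicitly here since the constructor in this version takes it as a required argument:

```python
import timm
import torch

model = timm.create_model('nextvit_small', pretrained=False, in_chans=3).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])

# the feature_info / out_indices wiring in _create_nextvit enables features_only use
feat_model = timm.create_model('nextvit_small', pretrained=False, in_chans=3, features_only=True)
with torch.no_grad():
    feats = feat_model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # strides 4/8/16/32, channels 96/256/512/1024
```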
@@ -535,7 +535,7 @@ class TinyVit(nn.Module):

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool=global_pool)
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.patch_embed(x)
@@ -421,6 +421,7 @@ class VisionTransformer(nn.Module):
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
            fix_init: bool = False,
            embed_layer: Callable = PatchEmbed,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
@@ -449,6 +450,7 @@ class VisionTransformer(nn.Module):
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            fix_init: Apply weight initialization fix (scaling w/ layer index).
            embed_layer: Patch embedding layer.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
@@ -536,8 +538,18 @@ class VisionTransformer(nn.Module):

        if weight_init != 'skip':
            self.init_weights(weight_init)
        if fix_init:
            self.fix_init_weight()

    def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
    def fix_init_weight(self):
        def rescale(param, _layer_id):
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def init_weights(self, mode: str = '') -> None:
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
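`fix_init` applies the depth-dependent rescaling `weight /= sqrt(2 * layer_id)` to each block's attention projection and second MLP weight after initialization. Since it is a plain constructor flag, it can be toggled through `create_model` kwargs; a minimal sketch, assuming a timm build that contains this change:

```python
import timm

# fix_init is forwarded through create_model kwargs into VisionTransformer.__init__
model = timm.create_model('vit_base_patch16_224', pretrained=False, fix_init=True)
```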
@@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
        module.init_weights()


def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None:
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
    if 'jax' in mode:
        return partial(init_weights_vit_jax, head_bias=head_bias)
    elif 'moco' in mode:
@@ -1723,7 +1735,12 @@ default_cfgs = {
        input_size=(3, 256, 256)),
    'vit_medium_patch16_reg4_gap_256': _cfg(
        input_size=(3, 256, 256)),
    'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
    'vit_base_patch16_reg4_gap_256': _cfg(
        input_size=(3, 256, 256)),
    'vit_so150m_patch16_reg4_gap_256': _cfg(
        input_size=(3, 256, 256)),
    'vit_so150m_patch16_reg4_map_256': _cfg(
        input_size=(3, 256, 256)),
}

_quick_gelu_cfgs = [
@@ -2623,13 +2640,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:


@register_model
def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
        no_embed_class=True, global_pool='avg', reg_tokens=8,
        no_embed_class=True, global_pool='avg', reg_tokens=4,
    )
    model = _create_vision_transformer(
        'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
        'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
        class_token=False, reg_tokens=4, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
        class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
    )
    model = _create_vision_transformer(
        'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
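The `vit_so150m_*` entries above register new shape-optimized ViT configurations (embed_dim=896, depth=18, mlp_ratio≈2.57) with no pretrained weights referenced in this diff. A quick instantiation to inspect the parameter count, assuming a timm build containing these registrations:

```python
import timm

model = timm.create_model('vit_so150m_patch16_reg4_gap_256', pretrained=False)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M params")
```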
@@ -7,7 +7,12 @@ Hacked together by / Copyright 2022, Ross Wightman
import logging
import math
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Tuple, Type, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn
@@ -15,9 +20,11 @@ from torch.jit import Final
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit

__all__ = ['VisionTransformerRelPos']  # model_registry will add each entrypoint fn to this

@@ -215,59 +222,61 @@ class VisionTransformerRelPos(nn.Module):

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_norm=False,
            init_values=1e-6,
            class_token=False,
            fc_norm=False,
            rel_pos_type='mlp',
            rel_pos_dim=None,
            shared_rel_pos=False,
            drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            weight_init='skip',
            embed_layer=PatchEmbed,
            norm_layer=None,
            act_layer=None,
            block_fn=RelPosBlock
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = 1e-6,
            class_token: bool = False,
            fc_norm: bool = False,
            rel_pos_type: str = 'mlp',
            rel_pos_dim: Optional[int] = None,
            shared_rel_pos: bool = False,
            drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
            fix_init: bool = False,
            embed_layer: Type[nn.Module] = PatchEmbed,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = RelPosBlock
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'avg')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_norm (bool): Enable normalization of query and key in attention
            init_values: (float): layer-scale init values
            class_token (bool): use class token (default: False)
            fc_norm (bool): use pre classifier norm instead of pre-pool
            rel_pos_type (str): type of relative position
            shared_rel_pos (bool): share relative pos across all blocks
            drop_rate (float): dropout rate
            proj_drop_rate (float): projection dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
            img_size: input image size
            patch_size: patch size
            in_chans: number of input channels
            num_classes: number of classes for classification head
            global_pool: type of global pooling for final sequence (default: 'avg')
            embed_dim: embedding dimension
            depth: depth of transformer
            num_heads: number of attention heads
            mlp_ratio: ratio of mlp hidden dim to embedding dim
            qkv_bias: enable bias for qkv if True
            qk_norm: Enable normalization of query and key in attention
            init_values: layer-scale init values
            class_token: use class token (default: False)
            fc_norm: use pre classifier norm instead of pre-pool
            rel_pos_type: type of relative position
            shared_rel_pos: share relative pos across all blocks
            drop_rate: dropout rate
            proj_drop_rate: projection dropout rate
            attn_drop_rate: attention dropout rate
            drop_path_rate: stochastic depth rate
            weight_init: weight init scheme
            fix_init: apply weight initialization fix (scaling w/ layer index)
            embed_layer: patch embedding layer
            norm_layer: normalization layer
            act_layer: MLP activation layer
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
@@ -332,13 +341,22 @@ class VisionTransformerRelPos(nn.Module):

        if weight_init != 'skip':
            self.init_weights(weight_init)
        if fix_init:
            self.fix_init_weight()

    def init_weights(self, mode=''):
        assert mode in ('jax', 'moco', '')
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        # FIXME weight init scheme using PyTorch defaults curently
        #named_apply(get_init_weights_vit(mode, head_bias), self)
        named_apply(get_init_weights_vit(mode), self)

    def fix_init_weight(self):
        def rescale(param, _layer_id):
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @torch.jit.ignore
    def no_weight_decay(self):
@@ -10,6 +10,6 @@ from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
from .model_ema import ModelEma, ModelEmaV2
from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3
from .random import random_seed
from .summary import update_summary, get_outdir
@@ -2,18 +2,17 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
from typing import Optional

import torch
from torch import distributed as dist

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from .model import unwrap_model

_logger = logging.getLogger(__name__)


def reduce_tensor(tensor, n):
    rt = tensor.clone()
@@ -84,9 +83,39 @@ def init_distributed_device(args):
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    result = init_distributed_device_so(
        device=getattr(args, 'device', 'cuda'),
        dist_backend=getattr(args, 'dist_backend', None),
        dist_url=getattr(args, 'dist_url', None),
    )
    args.device = result['device']
    args.world_size = result['world_size']
    args.rank = result['global_rank']
    args.local_rank = result['local_rank']
    args.distributed = result['distributed']
    device = torch.device(args.device)
    return device


def init_distributed_device_so(
        device: str = 'cuda',
        dist_backend: Optional[str] = None,
        dist_url: Optional[str] = None,
):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    distributed = False
    world_size = 1
    global_rank = 0
    local_rank = 0
    if dist_backend is None:
        # FIXME sane defaults for other device backends?
        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
    dist_url = dist_url or 'env://'

    # TBD, support horovod?
    # if args.horovod:
    #     import horovod.torch as hvd
    #     assert hvd is not None, "Horovod is not installed"
    #     hvd.init()
    #     args.local_rank = int(hvd.local_rank())
@@ -96,42 +125,51 @@ def init_distributed_device(args):
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)
    #     os.environ['RANK'] = str(args.rank)
    #     os.environ['WORLD_SIZE'] = str(args.world_size)
    dist_backend = getattr(args, 'dist_backend', 'nccl')
    dist_url = getattr(args, 'dist_url', 'env://')
    if is_distributed_env():
        if 'SLURM_PROCID' in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            local_rank, global_rank, world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ['LOCAL_RANK'] = str(args.local_rank)
            os.environ['RANK'] = str(args.rank)
            os.environ['WORLD_SIZE'] = str(args.world_size)
            os.environ['LOCAL_RANK'] = str(local_rank)
            os.environ['RANK'] = str(global_rank)
            os.environ['WORLD_SIZE'] = str(world_size)
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
                world_size=args.world_size,
                rank=args.rank,
                world_size=world_size,
                rank=global_rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
            )
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        args.distributed = True
        world_size = torch.distributed.get_world_size()
        global_rank = torch.distributed.get_rank()
        distributed = True

    if torch.cuda.is_available():
        if args.distributed:
            device = 'cuda:%d' % args.local_rank
        else:
            device = 'cuda:0'
    if 'cuda' in device:
        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

    if distributed and device != 'cpu':
        device, *device_idx = device.split(':', maxsplit=1)

        # Ignore manually specified device index in distributed mode and
        # override with resolved local rank, fewer headaches in most setups.
        if device_idx:
            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')

        device = f'{device}:{local_rank}'

    if device.startswith('cuda:'):
        torch.cuda.set_device(device)
    else:
        device = 'cpu'

    args.device = device
    device = torch.device(device)
    return device
    return dict(
        device=device,
        global_rank=global_rank,
        local_rank=local_rank,
        world_size=world_size,
        distributed=distributed,
    )
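`init_distributed_device_so` is the new argparse-free entry point: it returns plain values instead of mutating an `args` namespace, and falls back to single-process defaults when no distributed environment is detected. A minimal sketch of direct use (importing from the module where the diff defines it):

```python
import torch
from timm.utils.distributed import init_distributed_device_so

result = init_distributed_device_so(device='cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device(result['device'])
print(result['distributed'], result['global_rank'], result['world_size'], device)
# e.g. under torchrun: True 0 8 cuda:0; standalone: False 0 1 cuda (or cpu)
```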
@@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import logging
from collections import OrderedDict
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
@@ -103,7 +104,7 @@ class ModelEmaV2(nn.Module):
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
@@ -124,3 +125,136 @@ class ModelEmaV2(nn.Module):

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class ModelEmaV3(nn.Module):
    """ Model Exponential Moving Average V3

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    V3 of this module leverages for_each and in-place operations for faster performance.

    Decay warmup based on code by @crowsonkb, her comments:
        If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
        good values for models you plan to train for a million or more steps (reaches decay
        factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
        you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
        215.4k steps).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(
            self,
            model,
            decay: float = 0.9999,
            min_decay: float = 0.0,
            update_after_step: int = 0,
            use_warmup: bool = False,
            warmup_gamma: float = 1.0,
            warmup_power: float = 2/3,
            device: Optional[torch.device] = None,
            foreach: bool = True,
            exclude_buffers: bool = False,
    ):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_warmup = use_warmup
        self.warmup_gamma = warmup_gamma
        self.warmup_power = warmup_power
        self.foreach = foreach
        self.device = device  # perform ema on different device from model if set
        self.exclude_buffers = exclude_buffers
        if self.device is not None and device != next(model.parameters()).device:
            self.foreach = False  # cannot use foreach methods with different devices
            self.module.to(device=device)

    def get_decay(self, step: Optional[int] = None) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        if step is None:
            return self.decay

        step = max(0, step - self.update_after_step - 1)
        if step <= 0:
            return 0.0

        if self.use_warmup:
            decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
            decay = max(min(decay, self.decay), self.min_decay)
        else:
            decay = self.decay

        return decay
    @torch.no_grad()
    def update(self, model, step: Optional[int] = None):
        decay = self.get_decay(step)
        if self.exclude_buffers:
            self.apply_update_no_buffers_(model, decay)
        else:
            self.apply_update_(model, decay)

    def apply_update_(self, model, decay: float):
        # interpolate parameters and buffers
        if self.foreach:
            ema_lerp_values = []
            model_lerp_values = []
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if ema_v.is_floating_point():
                    ema_lerp_values.append(ema_v)
                    model_lerp_values.append(model_v)
                else:
                    ema_v.copy_(model_v)

            if hasattr(torch, '_foreach_lerp_'):
                torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
            else:
                torch._foreach_mul_(ema_lerp_values, scalar=decay)
                torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
        else:
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if ema_v.is_floating_point():
                    ema_v.lerp_(model_v, weight=1. - decay)
                else:
                    ema_v.copy_(model_v)

    def apply_update_no_buffers_(self, model, decay: float):
        # interpolate parameters, copy buffers
        ema_params = tuple(self.module.parameters())
        model_params = tuple(model.parameters())
        if self.foreach:
            if hasattr(torch, '_foreach_lerp_'):
                torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
            else:
                torch._foreach_mul_(ema_params, scalar=decay)
                torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
        else:
            for ema_p, model_p in zip(ema_params, model_params):
                ema_p.lerp_(model_p, weight=1. - decay)

        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
            ema_b.copy_(model_b.to(device=self.device))

    @torch.no_grad()
    def set(self, model):
        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
            ema_v.copy_(model_v.to(device=self.device))

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
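A minimal training-loop sketch of `ModelEmaV3`, roughly what `train.py` below wires up via `--model-ema` / `--model-ema-warmup` (the toy model, data, and hyperparameters here are purely illustrative):

```python
import torch
import torch.nn as nn
from timm.utils import ModelEmaV3  # exported via timm/utils/__init__.py above

model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ema = ModelEmaV3(model, decay=0.9998, use_warmup=True)  # create after device placement / AMP, before DDP

num_updates = 0
for _ in range(100):
    x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    num_updates += 1
    ema.update(model, step=num_updates)  # step drives the warmup decay schedule

ema_state = ema.module.state_dict()  # averaged weights for eval / checkpointing
```

With `use_warmup=True` the effective decay follows `1 - (1 + step / warmup_gamma) ** -warmup_power`, ramping from 0 toward the configured ceiling (about 0.999 around 31.6K steps for the default gamma=1, power=2/3), which keeps the EMA from lagging badly early in training.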
train.py (58 changed lines)
@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import importlib
import json
import logging
import os
@@ -168,6 +169,24 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")

# Device & distributed
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,
                   help="Device (accelerator) to use.")
group.add_argument('--amp', action='store_true', default=False,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
                   help='torch.cuda.synchronize() end of each step')
group.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--device-modules', default=None, type=str, nargs='+',
                    help="Python imports for device backend modules.")

# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -330,11 +349,13 @@ group.add_argument('--split-bn', action='store_true',
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
                   help='Enable tracking moving average of model weights')
                   help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                   help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
                   help='decay factor for model weights moving average (default: 0.9998)')
                   help='Decay factor for model weights moving average (default: 0.9998)')
group.add_argument('--model-ema-warmup', action='store_true',
                   help='Enable warmup for model EMA decay.')

# Misc
group = parser.add_argument_group('Miscellaneous parameters')
@@ -352,16 +373,6 @@ group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                   help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
                   help='save images of input bathes every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
                   help='torch.cuda.synchronize() end of each step')
group.add_argument('--pin-mem', action='store_true', default=False,
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -374,7 +385,6 @@ group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METR
                   help='Best metric (default: "top1"')
group.add_argument('--tta', type=int, default=0, metavar='N',
                   help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
@@ -402,6 +412,10 @@ def main():
    utils.setup_default_logging()
    args, args_text = _parse_args()

    if args.device_modules:
        for module in args.device_modules:
            importlib.import_module(module)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
@@ -586,10 +600,16 @@ def main():
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        model_ema = utils.ModelEmaV3(
            model,
            decay=args.model_ema_decay,
            use_warmup=args.model_ema_warmup,
            device='cpu' if args.model_ema_force_cpu else None,
        )
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
        if args.torchcompile:
            model_ema = torch.compile(model_ema, backend=args.torchcompile)

    # setup distributed training
    if args.distributed:
@@ -847,6 +867,7 @@ def main():
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                num_updates_total=num_epochs * updates_per_epoch,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -860,6 +881,7 @@ def main():
                    loader_eval,
                    validate_loss_fn,
                    args,
                    device=device,
                    amp_autocast=amp_autocast,
                )

@@ -868,10 +890,11 @@ def main():
                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                ema_eval_metrics = validate(
                    model_ema.module,
                    model_ema,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    device=device,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                )
@@ -935,6 +958,7 @@ def train_one_epoch(
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        num_updates_total=None,
):
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
@@ -1026,7 +1050,7 @@ def train_one_epoch(
            num_updates += 1
            optimizer.zero_grad()
            if model_ema is not None:
                model_ema.update(model)
                model_ema.update(model, step=num_updates)

            if args.synchronize_step and device.type == 'cuda':
                torch.cuda.synchronize()