Merge pull request #2092 from huggingface/mesa_ema

ModelEMAV3 + MESA experiments
Ross Wightman 2024-02-10 23:10:27 -08:00 committed by GitHub
commit 1b50b15145
13 changed files with 1065 additions and 119 deletions

timm/data/transforms.py

@@ -32,16 +32,12 @@ class ToNumpy:

 class ToTensor:
+    """ ToTensor with no rescaling of values"""

     def __init__(self, dtype=torch.float32):
         self.dtype = dtype

     def __call__(self, pil_img):
-        np_img = np.array(pil_img, dtype=np.uint8)
-        if np_img.ndim < 3:
-            np_img = np.expand_dims(np_img, axis=-1)
-        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
-        return torch.from_numpy(np_img).to(dtype=self.dtype)
+        return F.pil_to_tensor(pil_img).to(dtype=self.dtype)

 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
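The new path keeps the old behavior of leaving values unscaled (unlike torchvision's standard ToTensor, which rescales to [0, 1]), it just goes through F.pil_to_tensor instead of numpy. A minimal check, not part of this diff, assuming ToTensor is still importable from timm.data.transforms:

import numpy as np
import torch
from PIL import Image
from timm.data.transforms import ToTensor

img = Image.fromarray(np.full((8, 8, 3), 200, dtype=np.uint8))
t = ToTensor(dtype=torch.float32)(img)
print(t.shape, t.max())  # torch.Size([3, 8, 8]) tensor(200.) -- values stay in 0-255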

timm/layers/classifier.py

@@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

-    def reset(self, num_classes, global_pool=None):
-        if global_pool is not None:
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+    def reset(self, num_classes, pool_type=None):
+        if pool_type is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
         self.use_conv = self.global_pool.is_identity()
         linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
         if self.hidden_size:

timm/layers/create_act.py

@@ -148,7 +148,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
     return _ACT_LAYER_DEFAULT[name]


-def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
+def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None

timm/models/__init__.py

@@ -39,6 +39,7 @@ from .mobilevit import *
 from .mvitv2 import *
 from .nasnet import *
 from .nest import *
+from .nextvit import *
 from .nfnet import *
 from .pit import *
 from .pnasnet import *

timm/models/davit.py

@@ -547,6 +547,17 @@ class DaVit(nn.Module):
         if isinstance(m, nn.Linear) and m.bias is not None:
             nn.init.constant_(m.bias, 0)

+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm_pre', (99999,)),
+            ]
+        )
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable

@@ -558,7 +569,7 @@ class DaVit(nn.Module):
         return self.head.fc

     def reset_classifier(self, num_classes, global_pool=None):
-        self.head.reset(num_classes, global_pool=global_pool)
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.stem(x)

timm/models/nextvit.py (new file, 685 lines)

@ -0,0 +1,685 @@
""" Next-ViT
As described in https://arxiv.org/abs/2207.05501
Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-ViT, original copyright below
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
import torch
import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
from timm.layers import ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
""" Merge pre BN to reduce inference runtime.
"""
weight = module.weight.data
if module.bias is None:
zeros = torch.zeros(module.out_chs, device=weight.device).type(weight.type())
module.bias = nn.Parameter(zeros)
bias = module.bias.data
if pre_bn_2 is None:
assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"
scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
extra_weight = scale_invstd * pre_bn_1.weight
extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
else:
assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"
assert pre_bn_2.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
assert pre_bn_2.affine is True, "Unsupported bn_module.affine is False"
scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)
extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
extra_bias = (
scale_invstd_2 * pre_bn_2.weight
* (pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean)
+ pre_bn_2.bias
)
if isinstance(module, nn.Linear):
extra_bias = weight @ extra_bias
weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
elif isinstance(module, nn.Conv2d):
assert weight.shape[2] == 1 and weight.shape[3] == 1
weight = weight.reshape(weight.shape[0], weight.shape[1])
extra_bias = weight @ extra_bias
weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
bias.add_(extra_bias)
module.weight.data = weight
module.bias.data = bias
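The folding above uses the fact that a BatchNorm in eval mode is an affine map, BN(x) = a * x + c with a = weight / sqrt(running_var + eps) and c = bias - weight * running_mean / sqrt(running_var + eps), so a following linear or 1x1 conv layer W x + b can absorb it as (W * a) x + (W c + b). A quick numerical check of that equivalence (a sketch, not part of this file; it assumes merge_pre_bn is importable from timm.models.nextvit):

import torch
from torch import nn
from timm.models.nextvit import merge_pre_bn

torch.manual_seed(0)
bn = nn.BatchNorm1d(16).eval()
bn.running_mean.uniform_(-1, 1)     # give BN non-trivial running stats
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)   # and non-trivial affine parameters
bn.bias.data.uniform_(-0.5, 0.5)
fc = nn.Linear(16, 8).eval()

x = torch.randn(4, 16)
ref = fc(bn(x))        # pre-norm followed by the layer, as in the blocks below
merge_pre_bn(fc, bn)   # fold the BN statistics into fc's weight and bias in place
print(torch.allclose(ref, fc(x), atol=1e-5))  # True, the BN can be dropped at inference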
class ConvNormAct(nn.Module):
def __init__(
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
groups=1,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super(ConvNormAct, self).__init__()
self.conv = nn.Conv2d(
in_chs, out_chs, kernel_size=kernel_size, stride=stride,
padding=1, groups=groups, bias=False)
self.norm = norm_layer(out_chs)
self.act = act_layer()
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = self.act(x)
return x
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
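This is the usual channel-rounding helper from the MobileNet family: round to the nearest multiple of divisor, then round up instead if that would drop the value by more than 10%. Continuing from the definition above, a few illustrative values (the asserts are only for clarity, not part of the file):

assert _make_divisible(75, 8) == 72     # nearest multiple of 8, within 10% of 75
assert _make_divisible(72, 32) == 96    # 64 would lose more than 10% of 72, so round up
assert _make_divisible(192, 32) == 192  # typical call below: int(256 * 0.75) with head_dim granularity 32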
class PatchEmbed(nn.Module):
def __init__(self,
in_chs,
out_chs,
stride=1,
norm_layer = nn.BatchNorm2d,
):
super(PatchEmbed, self).__init__()
if stride == 2:
self.pool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False)
self.norm = norm_layer(out_chs)
elif in_chs != out_chs:
self.pool = nn.Identity()
self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False)
self.norm = norm_layer(out_chs)
else:
self.pool = nn.Identity()
self.conv = nn.Identity()
self.norm = nn.Identity()
def forward(self, x):
return self.norm(self.conv(self.pool(x)))
class ConvAttention(nn.Module):
"""
Multi-Head Convolutional Attention
"""
def __init__(self, out_chs, head_dim, norm_layer = nn.BatchNorm2d, act_layer = nn.ReLU):
super(ConvAttention, self).__init__()
self.group_conv3x3 = nn.Conv2d(
out_chs, out_chs,
kernel_size=3, stride=1, padding=1, groups=out_chs // head_dim, bias=False
)
self.norm = norm_layer(out_chs)
self.act = act_layer()
self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False)
def forward(self, x):
out = self.group_conv3x3(x)
out = self.norm(out)
out = self.act(out)
out = self.projection(out)
return out
class NextConvBlock(nn.Module):
"""
Next Convolution Block
"""
def __init__(
self,
in_chs,
out_chs,
stride=1,
drop_path=0.,
drop=0.,
head_dim=32,
mlp_ratio=3.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU
):
super(NextConvBlock, self).__init__()
self.in_chs = in_chs
self.out_chs = out_chs
assert out_chs % head_dim == 0
self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer)
self.mhca = ConvAttention(
out_chs,
head_dim,
norm_layer=norm_layer,
act_layer=act_layer,
)
self.attn_drop_path = DropPath(drop_path)
self.norm = norm_layer(out_chs)
self.mlp = ConvMlp(
out_chs,
hidden_features=int(out_chs * mlp_ratio),
drop=drop,
bias=True,
act_layer=act_layer,
)
self.mlp_drop_path = DropPath(drop_path)
self.is_fused = False
@torch.no_grad()
def reparameterize(self):
if not self.is_fused:
merge_pre_bn(self.mlp.fc1, self.norm)
self.norm = None
self.is_fused = True
def forward(self, x):
x = self.patch_embed(x)
x = x + self.attn_drop_path(self.mhca(x))
out = self.norm(x)
x = x + self.mlp_drop_path(self.mlp(out))
return x
class EfficientAttention(nn.Module):
"""
Efficient Multi-Head Self Attention
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim,
out_dim=None,
head_dim=32,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1,
norm_layer=nn.BatchNorm1d,
):
super().__init__()
self.dim = dim
self.out_dim = out_dim if out_dim is not None else dim
self.num_heads = self.dim // head_dim
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
self.proj = nn.Linear(self.dim, self.out_dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
self.N_ratio = sr_ratio ** 2
if sr_ratio > 1:
self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
self.norm = norm_layer(dim)
else:
self.sr = None
self.norm = None
def forward(self, x):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
if self.sr is not None:
x = self.sr(x.transpose(1, 2))
x = self.norm(x).transpose(1, 2)
k = self.k(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-1, -2)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class NextTransformerBlock(nn.Module):
"""
Next Transformer Block
"""
def __init__(
self,
in_chs,
out_chs,
drop_path,
stride=1,
sr_ratio=1,
mlp_ratio=2,
head_dim=32,
mix_block_ratio=0.75,
attn_drop=0.,
drop=0.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super(NextTransformerBlock, self).__init__()
self.in_chs = in_chs
self.out_chs = out_chs
self.mix_block_ratio = mix_block_ratio
self.mhsa_out_chs = _make_divisible(int(out_chs * mix_block_ratio), 32)
self.mhca_out_chs = out_chs - self.mhsa_out_chs
self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride)
self.norm1 = norm_layer(self.mhsa_out_chs)
self.e_mhsa = EfficientAttention(
self.mhsa_out_chs,
head_dim=head_dim,
sr_ratio=sr_ratio,
attn_drop=attn_drop,
proj_drop=drop,
)
self.mhsa_drop_path = DropPath(drop_path * mix_block_ratio)
self.projection = PatchEmbed(self.mhsa_out_chs, self.mhca_out_chs, stride=1, norm_layer=norm_layer)
self.mhca = ConvAttention(
self.mhca_out_chs,
head_dim=head_dim,
norm_layer=norm_layer,
act_layer=act_layer,
)
self.mhca_drop_path = DropPath(drop_path * (1 - mix_block_ratio))
self.norm2 = norm_layer(out_chs)
self.mlp = ConvMlp(
out_chs,
hidden_features=int(out_chs * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.mlp_drop_path = DropPath(drop_path)
self.is_fused = False
@torch.no_grad()
def reparameterize(self):
if not self.is_fused:
merge_pre_bn(self.e_mhsa.q, self.norm1)
if self.e_mhsa.norm is not None:
merge_pre_bn(self.e_mhsa.k, self.norm1, self.e_mhsa.norm)
merge_pre_bn(self.e_mhsa.v, self.norm1, self.e_mhsa.norm)
self.e_mhsa.norm = nn.Identity()
else:
merge_pre_bn(self.e_mhsa.k, self.norm1)
merge_pre_bn(self.e_mhsa.v, self.norm1)
self.norm1 = nn.Identity()
merge_pre_bn(self.mlp.fc1, self.norm2)
self.norm2 = nn.Identity()
self.is_fused = True
def forward(self, x):
x = self.patch_embed(x)
B, C, H, W = x.shape
out = self.norm1(x)
out = out.reshape(B, C, -1).transpose(-1, -2)
out = self.mhsa_drop_path(self.e_mhsa(out))
x = x + out.transpose(-1, -2).reshape(B, C, H, W)
out = self.projection(x)
out = out + self.mhca_drop_path(self.mhca(out))
x = torch.cat([x, out], dim=1)
out = self.norm2(x)
x = x + self.mlp_drop_path(self.mlp(out))
return x
class NextStage(nn.Module):
def __init__(
self,
in_chs,
block_chs,
block_types,
stride=2,
sr_ratio=1,
mix_block_ratio=1.0,
drop=0.,
attn_drop=0.,
drop_path=0.,
head_dim=32,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super().__init__()
self.grad_checkpointing = False
blocks = []
for block_idx, block_type in enumerate(block_types):
stride = stride if block_idx == 0 else 1
out_chs = block_chs[block_idx]
block_type = block_types[block_idx]
dpr = drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path
if block_type is NextConvBlock:
layer = NextConvBlock(
in_chs,
out_chs,
stride=stride,
drop_path=dpr,
drop=drop,
head_dim=head_dim,
norm_layer=norm_layer,
act_layer=act_layer,
)
blocks.append(layer)
elif block_type is NextTransformerBlock:
layer = NextTransformerBlock(
in_chs,
out_chs,
drop_path=dpr,
stride=stride,
sr_ratio=sr_ratio,
head_dim=head_dim,
mix_block_ratio=mix_block_ratio,
attn_drop=attn_drop,
drop=drop,
norm_layer=norm_layer,
act_layer=act_layer,
)
blocks.append(layer)
in_chs = out_chs
self.blocks = nn.Sequential(*blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class NextViT(nn.Module):
def __init__(
self,
in_chans,
num_classes=1000,
global_pool='avg',
stem_chs=(64, 32, 64),
depths=(3, 4, 10, 3),
strides=(1, 2, 2, 2),
sr_ratios=(8, 4, 2, 1),
drop_path_rate=0.1,
attn_drop_rate=0.,
drop_rate=0.,
head_dim=32,
mix_block_ratio=0.75,
norm_layer=nn.BatchNorm2d,
act_layer=None,
):
super(NextViT, self).__init__()
self.grad_checkpointing = False
self.num_classes = num_classes
norm_layer = get_norm_layer(norm_layer)
if act_layer is None:
act_layer = partial(nn.ReLU, inplace=True)
else:
act_layer = get_act_layer(act_layer)
self.stage_out_chs = [
[96] * (depths[0]),
[192] * (depths[1] - 1) + [256],
[384, 384, 384, 384, 512] * (depths[2] // 5),
[768] * (depths[3] - 1) + [1024]
]
self.feature_info = [dict(
num_chs=sc[-1],
reduction=2**(i + 2),
module=f'stages.{i}'
) for i, sc in enumerate(self.stage_out_chs)]
# Next Hybrid Strategy
self.stage_block_types = [
[NextConvBlock] * depths[0],
[NextConvBlock] * (depths[1] - 1) + [NextTransformerBlock],
[NextConvBlock, NextConvBlock, NextConvBlock, NextConvBlock, NextTransformerBlock] * (depths[2] // 5),
[NextConvBlock] * (depths[3] - 1) + [NextTransformerBlock]]
self.stem = nn.Sequential(
ConvNormAct(in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer),
ConvNormAct(stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer),
ConvNormAct(stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer),
ConvNormAct(stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer),
)
in_chs = out_chs = stem_chs[-1]
stages = []
idx = 0
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
for stage_idx in range(len(depths)):
stage = NextStage(
in_chs=in_chs,
block_chs=self.stage_out_chs[stage_idx],
block_types=self.stage_block_types[stage_idx],
stride=strides[stage_idx],
sr_ratio=sr_ratios[stage_idx],
mix_block_ratio=mix_block_ratio,
head_dim=head_dim,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[stage_idx],
norm_layer=norm_layer,
act_layer=act_layer,
)
in_chs = out_chs = self.stage_out_chs[stage_idx][-1]
stages += [stage]
idx += depths[stage_idx]
self.num_features = out_chs
self.stages = nn.Sequential(*stages)
self.norm = norm_layer(out_chs)
self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes)
self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
self._initialize_weights()
def _initialize_weights(self):
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm', (99999,)),
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'head.fc.weight' in state_dict:
return state_dict # non-original
D = model.state_dict()
out_dict = {}
# remap originals based on order
for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
out_dict[ka] = vb
return out_dict
def _create_nextvit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
NextViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = generate_default_cfgs({
'nextvit_small.bd_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_base.bd_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_large.bd_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_small.bd_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_base.bd_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_large.bd_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_small.bd_ssld_6m_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_base.bd_ssld_6m_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_large.bd_ssld_6m_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_small.bd_ssld_6m_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_base.bd_ssld_6m_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_large.bd_ssld_6m_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
})
@register_model
def nextvit_small(pretrained=False, **kwargs):
model_args = dict(depths=(3, 4, 10, 3), drop_path_rate=0.1)
model = _create_nextvit(
'nextvit_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def nextvit_base(pretrained=False, **kwargs):
model_args = dict(depths=(3, 4, 20, 3), drop_path_rate=0.2)
model = _create_nextvit(
'nextvit_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def nextvit_large(pretrained=False, **kwargs):
model_args = dict(depths=(3, 4, 30, 3), drop_path_rate=0.2)
model = _create_nextvit(
'nextvit_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model
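Once merged, the registered variants build through the usual timm factory. A minimal usage sketch (not part of the diff; the shapes assume the default 224x224 configs above, and the feature channels/strides come from the feature_info declared in NextViT.__init__):

import torch
import timm

model = timm.create_model('nextvit_small')  # pretrained=True pulls the 'bd_in1k' weights listed above
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1000])

features = timm.create_model('nextvit_small', features_only=True)
for f in features(x):
    print(f.shape)  # channels 96 / 256 / 512 / 1024 at strides 4 / 8 / 16 / 32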

timm/models/tiny_vit.py

@@ -535,7 +535,7 @@ class TinyVit(nn.Module):
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
-        self.head.reset(num_classes, global_pool=global_pool)
+        self.head.reset(num_classes, pool_type=global_pool)

     def forward_features(self, x):
         x = self.patch_embed(x)

timm/models/vision_transformer.py

@@ -421,6 +421,7 @@ class VisionTransformer(nn.Module):
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
             weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
+            fix_init: bool = False,
             embed_layer: Callable = PatchEmbed,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
@@ -449,6 +450,7 @@ class VisionTransformer(nn.Module):
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             weight_init: Weight initialization scheme.
+            fix_init: Apply weight initialization fix (scaling w/ layer index).
             embed_layer: Patch embedding layer.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
@@ -536,8 +538,18 @@ class VisionTransformer(nn.Module):
         if weight_init != 'skip':
             self.init_weights(weight_init)
+        if fix_init:
+            self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

-    def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
+    def init_weights(self, mode: str = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         trunc_normal_(self.pos_embed, std=.02)
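The fix_init path applies the depth-dependent rescaling used in several ViT training codebases: the residual-branch output projections of block i (1-based) are divided by sqrt(2 * i). A small sketch of enabling it (assuming, as usual in timm, that extra kwargs pass through create_model to the model class):

import math
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False, fix_init=True)
# e.g. the 12th block's attn.proj and mlp.fc2 weights get scaled by 1 / math.sqrt(24) ~= 0.204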
@@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
         module.init_weights()


-def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None:
+def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     if 'jax' in mode:
         return partial(init_weights_vit_jax, head_bias=head_bias)
     elif 'moco' in mode:
@@ -1723,7 +1735,12 @@ default_cfgs = {
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
-    'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
+    'vit_base_patch16_reg4_gap_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_so150m_patch16_reg4_gap_256': _cfg(
+        input_size=(3, 256, 256)),
+    'vit_so150m_patch16_reg4_map_256': _cfg(
+        input_size=(3, 256, 256)),
 }


 _quick_gelu_cfgs = [
@@ -2623,13 +2640,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
 @register_model
-def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
-        no_embed_class=True, global_pool='avg', reg_tokens=8,
+        no_embed_class=True, global_pool='avg', reg_tokens=4,
     )
     model = _create_vision_transformer(
-        'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
+        class_token=False, reg_tokens=4, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
+        class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

timm/models/vision_transformer_relpos.py

@@ -7,7 +7,12 @@ Hacked together by / Copyright 2022, Ross Wightman
 import logging
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Type, Union
+
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal

 import torch
 import torch.nn as nn
@@ -15,9 +20,11 @@ from torch.jit import Final
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
+from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
 from ._builder import build_model_with_cfg
+from ._manipulate import named_apply
 from ._registry import generate_default_cfgs, register_model
+from .vision_transformer import get_init_weights_vit

 __all__ = ['VisionTransformerRelPos']  # model_registry will add each entrypoint fn to this
@@ -215,59 +222,61 @@ class VisionTransformerRelPos(nn.Module):
     def __init__(
             self,
-            img_size=224,
-            patch_size=16,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='avg',
-            embed_dim=768,
-            depth=12,
-            num_heads=12,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            qk_norm=False,
-            init_values=1e-6,
-            class_token=False,
-            fc_norm=False,
-            rel_pos_type='mlp',
-            rel_pos_dim=None,
-            shared_rel_pos=False,
-            drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            weight_init='skip',
-            embed_layer=PatchEmbed,
-            norm_layer=None,
-            act_layer=None,
-            block_fn=RelPosBlock
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 16,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
+            embed_dim: int = 768,
+            depth: int = 12,
+            num_heads: int = 12,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            init_values: Optional[float] = 1e-6,
+            class_token: bool = False,
+            fc_norm: bool = False,
+            rel_pos_type: str = 'mlp',
+            rel_pos_dim: Optional[int] = None,
+            shared_rel_pos: bool = False,
+            drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
+            fix_init: bool = False,
+            embed_layer: Type[nn.Module] = PatchEmbed,
+            norm_layer: Optional[LayerType] = None,
+            act_layer: Optional[LayerType] = None,
+            block_fn: Type[nn.Module] = RelPosBlock
     ):
         """
         Args:
-            img_size (int, tuple): input image size
-            patch_size (int, tuple): patch size
-            in_chans (int): number of input channels
-            num_classes (int): number of classes for classification head
-            global_pool (str): type of global pooling for final sequence (default: 'avg')
-            embed_dim (int): embedding dimension
-            depth (int): depth of transformer
-            num_heads (int): number of attention heads
-            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
-            qkv_bias (bool): enable bias for qkv if True
-            qk_norm (bool): Enable normalization of query and key in attention
-            init_values: (float): layer-scale init values
-            class_token (bool): use class token (default: False)
-            fc_norm (bool): use pre classifier norm instead of pre-pool
-            rel_pos_type (str): type of relative position
-            shared_rel_pos (bool): share relative pos across all blocks
-            drop_rate (float): dropout rate
-            proj_drop_rate (float): projection dropout rate
-            attn_drop_rate (float): attention dropout rate
-            drop_path_rate (float): stochastic depth rate
-            weight_init (str): weight init scheme
-            embed_layer (nn.Module): patch embedding layer
-            norm_layer: (nn.Module): normalization layer
-            act_layer: (nn.Module): MLP activation layer
+            img_size: input image size
+            patch_size: patch size
+            in_chans: number of input channels
+            num_classes: number of classes for classification head
+            global_pool: type of global pooling for final sequence (default: 'avg')
+            embed_dim: embedding dimension
+            depth: depth of transformer
+            num_heads: number of attention heads
+            mlp_ratio: ratio of mlp hidden dim to embedding dim
+            qkv_bias: enable bias for qkv if True
+            qk_norm: Enable normalization of query and key in attention
+            init_values: layer-scale init values
+            class_token: use class token (default: False)
+            fc_norm: use pre classifier norm instead of pre-pool
+            rel_pos_type: type of relative position
+            shared_rel_pos: share relative pos across all blocks
+            drop_rate: dropout rate
+            proj_drop_rate: projection dropout rate
+            attn_drop_rate: attention dropout rate
+            drop_path_rate: stochastic depth rate
+            weight_init: weight init scheme
+            fix_init: apply weight initialization fix (scaling w/ layer index)
+            embed_layer: patch embedding layer
+            norm_layer: normalization layer
+            act_layer: MLP activation layer
         """
         super().__init__()
         assert global_pool in ('', 'avg', 'token')
@@ -332,13 +341,22 @@ class VisionTransformerRelPos(nn.Module):
         if weight_init != 'skip':
             self.init_weights(weight_init)
+        if fix_init:
+            self.fix_init_weight()

     def init_weights(self, mode=''):
         assert mode in ('jax', 'moco', '')
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
-        # FIXME weight init scheme using PyTorch defaults curently
-        #named_apply(get_init_weights_vit(mode, head_bias), self)
+        named_apply(get_init_weights_vit(mode), self)
+
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

     @torch.jit.ignore
     def no_weight_decay(self):

timm/utils/__init__.py

@@ -10,6 +10,6 @@ from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg, ParseKwargs
 from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
-from .model_ema import ModelEma, ModelEmaV2
+from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3
 from .random import random_seed
 from .summary import update_summary, get_outdir

timm/utils/distributed.py

@@ -2,18 +2,17 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
 import os
+from typing import Optional

 import torch
 from torch import distributed as dist

-try:
-    import horovod.torch as hvd
-except ImportError:
-    hvd = None
-
 from .model import unwrap_model

+_logger = logging.getLogger(__name__)
+

 def reduce_tensor(tensor, n):
     rt = tensor.clone()
@@ -84,9 +83,39 @@ def init_distributed_device(args):
     args.world_size = 1
     args.rank = 0  # global rank
     args.local_rank = 0
+    result = init_distributed_device_so(
+        device=getattr(args, 'device', 'cuda'),
+        dist_backend=getattr(args, 'dist_backend', None),
+        dist_url=getattr(args, 'dist_url', None),
+    )
+    args.device = result['device']
+    args.world_size = result['world_size']
+    args.rank = result['global_rank']
+    args.local_rank = result['local_rank']
+    args.distributed = result['distributed']
+    device = torch.device(args.device)
+    return device
+
+
+def init_distributed_device_so(
+        device: str = 'cuda',
+        dist_backend: Optional[str] = None,
+        dist_url: Optional[str] = None,
+):
+    # Distributed training = training on more than one GPU.
+    # Works in both single and multi-node scenarios.
+    distributed = False
+    world_size = 1
+    global_rank = 0
+    local_rank = 0
+    if dist_backend is None:
+        # FIXME sane defaults for other device backends?
+        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+    dist_url = dist_url or 'env://'
+
     # TBD, support horovod?
     # if args.horovod:
+    #     import horovod.torch as hvd
     #     assert hvd is not None, "Horovod is not installed"
     #     hvd.init()
     #     args.local_rank = int(hvd.local_rank())
@@ -96,42 +125,51 @@ def init_distributed_device(args):
     # os.environ['LOCAL_RANK'] = str(args.local_rank)
     # os.environ['RANK'] = str(args.rank)
     # os.environ['WORLD_SIZE'] = str(args.world_size)

-    dist_backend = getattr(args, 'dist_backend', 'nccl')
-    dist_url = getattr(args, 'dist_url', 'env://')
-
     if is_distributed_env():
         if 'SLURM_PROCID' in os.environ:
             # DDP via SLURM
-            args.local_rank, args.rank, args.world_size = world_info_from_env()
+            local_rank, global_rank, world_size = world_info_from_env()
             # SLURM var -> torch.distributed vars in case needed
-            os.environ['LOCAL_RANK'] = str(args.local_rank)
-            os.environ['RANK'] = str(args.rank)
-            os.environ['WORLD_SIZE'] = str(args.world_size)
+            os.environ['LOCAL_RANK'] = str(local_rank)
+            os.environ['RANK'] = str(global_rank)
+            os.environ['WORLD_SIZE'] = str(world_size)
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
-                world_size=args.world_size,
-                rank=args.rank,
+                world_size=world_size,
+                rank=global_rank,
             )
         else:
            # DDP via torchrun, torch.distributed.launch
-            args.local_rank, _, _ = world_info_from_env()
+            local_rank, _, _ = world_info_from_env()
             torch.distributed.init_process_group(
                 backend=dist_backend,
                 init_method=dist_url,
             )
-            args.world_size = torch.distributed.get_world_size()
-            args.rank = torch.distributed.get_rank()
-        args.distributed = True
+            world_size = torch.distributed.get_world_size()
+            global_rank = torch.distributed.get_rank()
+        distributed = True

-    if torch.cuda.is_available():
-        if args.distributed:
-            device = 'cuda:%d' % args.local_rank
-        else:
-            device = 'cuda:0'
+    if 'cuda' in device:
+        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+
+    if distributed and device != 'cpu':
+        device, *device_idx = device.split(':', maxsplit=1)
+
+        # Ignore manually specified device index in distributed mode and
+        # override with resolved local rank, fewer headaches in most setups.
+        if device_idx:
+            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
+
+        device = f'{device}:{local_rank}'
+
+    if device.startswith('cuda:'):
         torch.cuda.set_device(device)
-    else:
-        device = 'cpu'

-    args.device = device
-    device = torch.device(device)
-    return device
+    return dict(
+        device=device,
+        global_rank=global_rank,
+        local_rank=local_rank,
+        world_size=world_size,
+        distributed=distributed,
+    )
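The new args-free entry point can also be called directly from custom scripts instead of going through init_distributed_device(args). A minimal sketch (assumes a CUDA machine; under torchrun or SLURM the environment variables are already set and the process group is initialized, otherwise it just resolves the device):

import torch
from timm.utils.distributed import init_distributed_device_so

info = init_distributed_device_so(device='cuda')
device = torch.device(info['device'])
if info['distributed']:
    print(f"rank {info['global_rank']} of {info['world_size']} using {device}")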

timm/utils/model_ema.py

@@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import logging
 from collections import OrderedDict
 from copy import deepcopy
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -103,7 +104,7 @@ class ModelEmaV2(nn.Module):
     GPU assignment and distributed training wrappers.
     """
     def __init__(self, model, decay=0.9999, device=None):
-        super(ModelEmaV2, self).__init__()
+        super().__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
         self.module.eval()
@@ -124,3 +125,136 @@ class ModelEmaV2(nn.Module):
     def set(self, model):
         self._update(model, update_fn=lambda e, m: m)

+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
class ModelEmaV3(nn.Module):
""" Model Exponential Moving Average V3
Keep a moving average of everything in the model state_dict (parameters and buffers).
V3 of this module leverages for_each and in-place operations for faster performance.
Decay warmup based on code by @crowsonkb, her comments:
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
def __init__(
self,
model,
decay: float = 0.9999,
min_decay: float = 0.0,
update_after_step: int = 0,
use_warmup: bool = False,
warmup_gamma: float = 1.0,
warmup_power: float = 2/3,
device: Optional[torch.device] = None,
foreach: bool = True,
exclude_buffers: bool = False,
):
super().__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
self.decay = decay
self.min_decay = min_decay
self.update_after_step = update_after_step
self.use_warmup = use_warmup
self.warmup_gamma = warmup_gamma
self.warmup_power = warmup_power
self.foreach = foreach
self.device = device # perform ema on different device from model if set
self.exclude_buffers = exclude_buffers
if self.device is not None and device != next(model.parameters()).device:
self.foreach = False # cannot use foreach methods with different devices
self.module.to(device=device)
def get_decay(self, step: Optional[int] = None) -> float:
"""
Compute the decay factor for the exponential moving average.
"""
if step is None:
return self.decay
step = max(0, step - self.update_after_step - 1)
if step <= 0:
return 0.0
if self.use_warmup:
decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
decay = max(min(decay, self.decay), self.min_decay)
else:
decay = self.decay
return decay
@torch.no_grad()
def update(self, model, step: Optional[int] = None):
decay = self.get_decay(step)
if self.exclude_buffers:
self.apply_update_no_buffers_(model, decay)
else:
self.apply_update_(model, decay)
def apply_update_(self, model, decay: float):
# interpolate parameters and buffers
if self.foreach:
ema_lerp_values = []
model_lerp_values = []
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if ema_v.is_floating_point():
ema_lerp_values.append(ema_v)
model_lerp_values.append(model_v)
else:
ema_v.copy_(model_v)
if hasattr(torch, '_foreach_lerp_'):
torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
else:
torch._foreach_mul_(ema_lerp_values, scalar=decay)
torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
else:
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if ema_v.is_floating_point():
ema_v.lerp_(model_v, weight=1. - decay)
else:
ema_v.copy_(model_v)
def apply_update_no_buffers_(self, model, decay: float):
# interpolate parameters, copy buffers
ema_params = tuple(self.module.parameters())
model_params = tuple(model.parameters())
if self.foreach:
if hasattr(torch, '_foreach_lerp_'):
torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
else:
torch._foreach_mul_(ema_params, scalar=decay)
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
else:
for ema_p, model_p in zip(ema_params, model_params):
ema_p.lerp_(model_p, weight=1. - decay)
for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
ema_b.copy_(model_b.to(device=self.device))
@torch.no_grad()
def set(self, model):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.copy_(model_v.to(device=self.device))
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
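For reference, the warmup numbers quoted in the docstring can be checked directly against get_decay, and update() is driven with the global update count the same way train.py does further below. A minimal sketch (toy model, illustrative values only, not part of the diff):

import torch
from torch import nn
from timm.utils import ModelEmaV3

model = nn.Linear(10, 10)
ema = ModelEmaV3(model, decay=0.9999, use_warmup=True)  # warmup_gamma=1.0, warmup_power=2/3

# roughly: decay = 1 - (1 + step / warmup_gamma) ** -warmup_power, clamped to [min_decay, decay]
print(round(ema.get_decay(100), 4))     # 0.9536
print(round(ema.get_decay(31_600), 4))  # 0.999, matching the docstring above
print(ema.get_decay(10_000_000))        # 0.9999, capped at `decay`

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(1, 11):
    loss = model(torch.randn(8, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model, step=step)  # step = number of optimizer updates so far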

train.py

@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import importlib
 import json
 import logging
 import os
@@ -168,6 +169,24 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
 scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                              help="Enable compilation w/ specified backend (default: inductor).")

+# Device & distributed
+group = parser.add_argument_group('Device parameters')
+group.add_argument('--device', default='cuda', type=str,
+                   help="Device (accelerator) to use.")
+group.add_argument('--amp', action='store_true', default=False,
+                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+group.add_argument('--amp-dtype', default='float16', type=str,
+                   help='lower precision AMP dtype (default: float16)')
+group.add_argument('--amp-impl', default='native', type=str,
+                   help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--no-ddp-bb', action='store_true', default=False,
+                   help='Force broadcast buffers for native DDP to off.')
+group.add_argument('--synchronize-step', action='store_true', default=False,
+                   help='torch.cuda.synchronize() end of each step')
+group.add_argument("--local_rank", default=0, type=int)
+parser.add_argument('--device-modules', default=None, type=str, nargs='+',
+                    help="Python imports for device backend modules.")
+
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
 group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -330,11 +349,13 @@ group.add_argument('--split-bn', action='store_true',
 # Model Exponential Moving Average
 group = parser.add_argument_group('Model exponential moving average parameters')
 group.add_argument('--model-ema', action='store_true', default=False,
-                   help='Enable tracking moving average of model weights')
+                   help='Enable tracking moving average of model weights.')
 group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 group.add_argument('--model-ema-decay', type=float, default=0.9998,
-                   help='decay factor for model weights moving average (default: 0.9998)')
+                   help='Decay factor for model weights moving average (default: 0.9998)')
+group.add_argument('--model-ema-warmup', action='store_true',
+                   help='Enable warmup for model EMA decay.')

 # Misc
 group = parser.add_argument_group('Miscellaneous parameters')
@@ -352,16 +373,6 @@ group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
 group.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input bathes every log interval for debugging')
-group.add_argument('--amp', action='store_true', default=False,
-                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
-group.add_argument('--amp-dtype', default='float16', type=str,
-                   help='lower precision AMP dtype (default: float16)')
-group.add_argument('--amp-impl', default='native', type=str,
-                   help='AMP impl to use, "native" or "apex" (default: native)')
-group.add_argument('--no-ddp-bb', action='store_true', default=False,
-                   help='Force broadcast buffers for native DDP to off.')
-group.add_argument('--synchronize-step', action='store_true', default=False,
-                   help='torch.cuda.synchronize() end of each step')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -374,7 +385,6 @@ group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METR
                    help='Best metric (default: "top1"')
 group.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
-group.add_argument("--local_rank", default=0, type=int)
 group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 group.add_argument('--log-wandb', action='store_true', default=False,
@@ -402,6 +412,10 @@ def main():
     utils.setup_default_logging()
     args, args_text = _parse_args()

+    if args.device_modules:
+        for module in args.device_modules:
+            importlib.import_module(module)
+
     if torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.benchmark = True
@@ -586,10 +600,16 @@ def main():
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
-        model_ema = utils.ModelEmaV2(
-            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
+        model_ema = utils.ModelEmaV3(
+            model,
+            decay=args.model_ema_decay,
+            use_warmup=args.model_ema_warmup,
+            device='cpu' if args.model_ema_force_cpu else None,
+        )
         if args.resume:
             load_checkpoint(model_ema.module, args.resume, use_ema=True)
+        if args.torchcompile:
+            model_ema = torch.compile(model_ema, backend=args.torchcompile)

     # setup distributed training
     if args.distributed:
@@ -847,6 +867,7 @@ def main():
                 loss_scaler=loss_scaler,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
+                num_updates_total=num_epochs * updates_per_epoch,
             )

             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -860,6 +881,7 @@ def main():
                 loader_eval,
                 validate_loss_fn,
                 args,
+                device=device,
                 amp_autocast=amp_autocast,
             )
@@ -868,10 +890,11 @@ def main():
                 utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                 ema_eval_metrics = validate(
-                    model_ema.module,
+                    model_ema,
                     loader_eval,
                     validate_loss_fn,
                     args,
+                    device=device,
                     amp_autocast=amp_autocast,
                     log_suffix=' (EMA)',
                 )
@@ -935,6 +958,7 @@ def train_one_epoch(
         loss_scaler=None,
         model_ema=None,
         mixup_fn=None,
+        num_updates_total=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -1026,7 +1050,7 @@ def train_one_epoch(
             num_updates += 1
             optimizer.zero_grad()
             if model_ema is not None:
-                model_ema.update(model)
+                model_ema.update(model, step=num_updates)

             if args.synchronize_step and device.type == 'cuda':
                 torch.cuda.synchronize()