Merge pull request #2092 from huggingface/mesa_ema

ModelEMAV3 + MESA experiments
Ross Wightman 2024-02-10 23:10:27 -08:00 committed by GitHub
commit 1b50b15145
13 changed files with 1065 additions and 119 deletions


@ -32,16 +32,12 @@ class ToNumpy:
class ToTensor:
""" ToTensor with no rescaling of values"""
def __init__(self, dtype=torch.float32):
self.dtype = dtype
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return torch.from_numpy(np_img).to(dtype=self.dtype)
return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
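A quick equivalence check for the ToTensor change above (a standalone sketch; assumes torchvision is available):

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image

pil_img = Image.fromarray(np.random.randint(0, 255, (8, 8, 3), dtype=np.uint8))

# old path: numpy conversion, HWC -> CHW, no value rescaling
np_img = np.rollaxis(np.array(pil_img, dtype=np.uint8), 2)
old = torch.from_numpy(np_img).to(dtype=torch.float32)

# new path: single torchvision call, also keeps raw 0-255 values (unlike to_tensor)
new = F.pil_to_tensor(pil_img).to(dtype=torch.float32)

assert torch.equal(old, new)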


@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module):
self.drop = nn.Dropout(drop_rate)
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
def reset(self, num_classes, pool_type=None):
if pool_type is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:


@ -148,7 +148,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
act_layer = get_act_layer(name)
if act_layer is None:
return None
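For reference, the two factories behave differently; a small usage sketch (names as exported from timm.layers):

from timm.layers import create_act_layer, get_act_layer

act_cls = get_act_layer('gelu')                    # returns the layer class (or None if unknown)
act_mod = create_act_layer('gelu')                 # returns an instantiated module
act_rel = create_act_layer('relu', inplace=True)   # inplace passed only if the layer supports it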


@ -39,6 +39,7 @@ from .mobilevit import *
from .mvitv2 import *
from .nasnet import *
from .nest import *
from .nextvit import *
from .nfnet import *
from .pit import *
from .pnasnet import *
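With the added import, the Next-ViT entrypoints defined later in this diff register with timm's model registry; a quick sketch:

import timm

print(timm.list_models('nextvit*'))   # e.g. ['nextvit_base', 'nextvit_large', 'nextvit_small']
model = timm.create_model('nextvit_small', pretrained=False)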


@ -547,6 +547,17 @@ class DaVit(nn.Module):
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm_pre', (99999,)),
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@ -558,7 +569,7 @@ class DaVit(nn.Module):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.head.reset(num_classes, global_pool=global_pool)
self.head.reset(num_classes, global_pool)
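The new group_matcher lets downstream code build per-stage parameter groups (e.g. for layer-wise LR decay or selective freezing). A sketch, assuming the group_parameters helper in timm.models._manipulate:

import timm
from timm.models._manipulate import group_parameters

model = timm.create_model('davit_tiny', pretrained=False)
groups = group_parameters(model, model.group_matcher(coarse=True))
for group_id, names in groups.items():
    print(group_id, len(names))       # stem group vs. per-stage groups

model.reset_classifier(num_classes=10, global_pool='avg')   # exercises the updated head.reset()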
def forward_features(self, x):
x = self.stem(x)

timm/models/nextvit.py (new file, 685 lines)

@ -0,0 +1,685 @@
""" Next-ViT
As described in https://arxiv.org/abs/2207.05501
Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-ViT, original copyright below
"""
# Copyright (c) ByteDance Inc. All rights reserved.
from functools import partial
import torch
import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn
from timm.layers import ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model
def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
""" Merge pre BN to reduce inference runtime.
"""
weight = module.weight.data
if module.bias is None:
zeros = torch.zeros(module.out_chs, device=weight.device).type(weight.type())
module.bias = nn.Parameter(zeros)
bias = module.bias.data
if pre_bn_2 is None:
assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"
scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
extra_weight = scale_invstd * pre_bn_1.weight
extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
else:
assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False"
assert pre_bn_2.track_running_stats is True, "Unsupported bn_module.track_running_stats is False"
assert pre_bn_2.affine is True, "Unsupported bn_module.affine is False"
scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)
extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
extra_bias = (
scale_invstd_2 * pre_bn_2.weight
* (pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean)
+ pre_bn_2.bias
)
if isinstance(module, nn.Linear):
extra_bias = weight @ extra_bias
weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
elif isinstance(module, nn.Conv2d):
assert weight.shape[2] == 1 and weight.shape[3] == 1
weight = weight.reshape(weight.shape[0], weight.shape[1])
extra_bias = weight @ extra_bias
weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight))
weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1)
bias.add_(extra_bias)
module.weight.data = weight
module.bias.data = bias
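A quick numerical check of the fusion above; this is a standalone sketch using a plain Linear + BatchNorm1d pair in eval mode rather than an actual Next-ViT block:

import torch
import torch.nn as nn
from timm.models.nextvit import merge_pre_bn

lin = nn.Linear(8, 4)
bn = nn.BatchNorm1d(8).eval()            # eval mode so running stats are used
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-0.5, 0.5)

x = torch.randn(2, 8)
ref = lin(bn(x))                         # original: linear applied to the BN output
merge_pre_bn(lin, bn)                    # fold the BN scale/shift into lin.weight / lin.bias
fused = lin(x)                           # the BN can now be dropped at inference
print(torch.allclose(ref, fused, atol=1e-5))   # True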
class ConvNormAct(nn.Module):
def __init__(
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
groups=1,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super(ConvNormAct, self).__init__()
self.conv = nn.Conv2d(
in_chs, out_chs, kernel_size=kernel_size, stride=stride,
padding=1, groups=groups, bias=False)
self.norm = norm_layer(out_chs)
self.act = act_layer()
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = self.act(x)
return x
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class PatchEmbed(nn.Module):
def __init__(self,
in_chs,
out_chs,
stride=1,
norm_layer = nn.BatchNorm2d,
):
super(PatchEmbed, self).__init__()
if stride == 2:
self.pool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False)
self.norm = norm_layer(out_chs)
elif in_chs != out_chs:
self.pool = nn.Identity()
self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False)
self.norm = norm_layer(out_chs)
else:
self.pool = nn.Identity()
self.conv = nn.Identity()
self.norm = nn.Identity()
def forward(self, x):
return self.norm(self.conv(self.pool(x)))
class ConvAttention(nn.Module):
"""
Multi-Head Convolutional Attention
"""
def __init__(self, out_chs, head_dim, norm_layer = nn.BatchNorm2d, act_layer = nn.ReLU):
super(ConvAttention, self).__init__()
self.group_conv3x3 = nn.Conv2d(
out_chs, out_chs,
kernel_size=3, stride=1, padding=1, groups=out_chs // head_dim, bias=False
)
self.norm = norm_layer(out_chs)
self.act = act_layer()
self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False)
def forward(self, x):
out = self.group_conv3x3(x)
out = self.norm(out)
out = self.act(out)
out = self.projection(out)
return out
class NextConvBlock(nn.Module):
"""
Next Convolution Block
"""
def __init__(
self,
in_chs,
out_chs,
stride=1,
drop_path=0.,
drop=0.,
head_dim=32,
mlp_ratio=3.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU
):
super(NextConvBlock, self).__init__()
self.in_chs = in_chs
self.out_chs = out_chs
assert out_chs % head_dim == 0
self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer)
self.mhca = ConvAttention(
out_chs,
head_dim,
norm_layer=norm_layer,
act_layer=act_layer,
)
self.attn_drop_path = DropPath(drop_path)
self.norm = norm_layer(out_chs)
self.mlp = ConvMlp(
out_chs,
hidden_features=int(out_chs * mlp_ratio),
drop=drop,
bias=True,
act_layer=act_layer,
)
self.mlp_drop_path = DropPath(drop_path)
self.is_fused = False
@torch.no_grad()
def reparameterize(self):
if not self.is_fused:
merge_pre_bn(self.mlp.fc1, self.norm)
self.norm = None
self.is_fused = True
def forward(self, x):
x = self.patch_embed(x)
x = x + self.attn_drop_path(self.mhca(x))
out = self.norm(x)
x = x + self.mlp_drop_path(self.mlp(out))
return x
class EfficientAttention(nn.Module):
"""
Efficient Multi-Head Self Attention
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim,
out_dim=None,
head_dim=32,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1,
norm_layer=nn.BatchNorm1d,
):
super().__init__()
self.dim = dim
self.out_dim = out_dim if out_dim is not None else dim
self.num_heads = self.dim // head_dim
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
self.proj = nn.Linear(self.dim, self.out_dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
self.N_ratio = sr_ratio ** 2
if sr_ratio > 1:
self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
self.norm = norm_layer(dim)
else:
self.sr = None
self.norm = None
def forward(self, x):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
if self.sr is not None:
x = self.sr(x.transpose(1, 2))
x = self.norm(x).transpose(1, 2)
k = self.k(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-1, -2)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
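Shape sanity check for the spatial-reduction attention above (a sketch with illustrative sizes):

import torch
from timm.models.nextvit import EfficientAttention

attn = EfficientAttention(dim=64, head_dim=32, sr_ratio=2)
x = torch.randn(2, 196, 64)          # B, N, C tokens from a 14x14 feature map
out = attn(x)                        # k/v are computed on N / sr_ratio**2 = 49 pooled tokens
print(out.shape)                     # torch.Size([2, 196, 64])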
class NextTransformerBlock(nn.Module):
"""
Next Transformer Block
"""
def __init__(
self,
in_chs,
out_chs,
drop_path,
stride=1,
sr_ratio=1,
mlp_ratio=2,
head_dim=32,
mix_block_ratio=0.75,
attn_drop=0.,
drop=0.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super(NextTransformerBlock, self).__init__()
self.in_chs = in_chs
self.out_chs = out_chs
self.mix_block_ratio = mix_block_ratio
self.mhsa_out_chs = _make_divisible(int(out_chs * mix_block_ratio), 32)
self.mhca_out_chs = out_chs - self.mhsa_out_chs
self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride)
self.norm1 = norm_layer(self.mhsa_out_chs)
self.e_mhsa = EfficientAttention(
self.mhsa_out_chs,
head_dim=head_dim,
sr_ratio=sr_ratio,
attn_drop=attn_drop,
proj_drop=drop,
)
self.mhsa_drop_path = DropPath(drop_path * mix_block_ratio)
self.projection = PatchEmbed(self.mhsa_out_chs, self.mhca_out_chs, stride=1, norm_layer=norm_layer)
self.mhca = ConvAttention(
self.mhca_out_chs,
head_dim=head_dim,
norm_layer=norm_layer,
act_layer=act_layer,
)
self.mhca_drop_path = DropPath(drop_path * (1 - mix_block_ratio))
self.norm2 = norm_layer(out_chs)
self.mlp = ConvMlp(
out_chs,
hidden_features=int(out_chs * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.mlp_drop_path = DropPath(drop_path)
self.is_fused = False
@torch.no_grad()
def reparameterize(self):
if not self.is_fused:
merge_pre_bn(self.e_mhsa.q, self.norm1)
if self.e_mhsa.norm is not None:
merge_pre_bn(self.e_mhsa.k, self.norm1, self.e_mhsa.norm)
merge_pre_bn(self.e_mhsa.v, self.norm1, self.e_mhsa.norm)
self.e_mhsa.norm = nn.Identity()
else:
merge_pre_bn(self.e_mhsa.k, self.norm1)
merge_pre_bn(self.e_mhsa.v, self.norm1)
self.norm1 = nn.Identity()
merge_pre_bn(self.mlp.fc1, self.norm2)
self.norm2 = nn.Identity()
self.is_fused = True
def forward(self, x):
x = self.patch_embed(x)
B, C, H, W = x.shape
out = self.norm1(x)
out = out.reshape(B, C, -1).transpose(-1, -2)
out = self.mhsa_drop_path(self.e_mhsa(out))
x = x + out.transpose(-1, -2).reshape(B, C, H, W)
out = self.projection(x)
out = out + self.mhca_drop_path(self.mhca(out))
x = torch.cat([x, out], dim=1)
out = self.norm2(x)
x = x + self.mlp_drop_path(self.mlp(out))
return x
class NextStage(nn.Module):
def __init__(
self,
in_chs,
block_chs,
block_types,
stride=2,
sr_ratio=1,
mix_block_ratio=1.0,
drop=0.,
attn_drop=0.,
drop_path=0.,
head_dim=32,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super().__init__()
self.grad_checkpointing = False
blocks = []
for block_idx, block_type in enumerate(block_types):
stride = stride if block_idx == 0 else 1
out_chs = block_chs[block_idx]
block_type = block_types[block_idx]
dpr = drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path
if block_type is NextConvBlock:
layer = NextConvBlock(
in_chs,
out_chs,
stride=stride,
drop_path=dpr,
drop=drop,
head_dim=head_dim,
norm_layer=norm_layer,
act_layer=act_layer,
)
blocks.append(layer)
elif block_type is NextTransformerBlock:
layer = NextTransformerBlock(
in_chs,
out_chs,
drop_path=dpr,
stride=stride,
sr_ratio=sr_ratio,
head_dim=head_dim,
mix_block_ratio=mix_block_ratio,
attn_drop=attn_drop,
drop=drop,
norm_layer=norm_layer,
act_layer=act_layer,
)
blocks.append(layer)
in_chs = out_chs
self.blocks = nn.Sequential(*blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class NextViT(nn.Module):
def __init__(
self,
in_chans,
num_classes=1000,
global_pool='avg',
stem_chs=(64, 32, 64),
depths=(3, 4, 10, 3),
strides=(1, 2, 2, 2),
sr_ratios=(8, 4, 2, 1),
drop_path_rate=0.1,
attn_drop_rate=0.,
drop_rate=0.,
head_dim=32,
mix_block_ratio=0.75,
norm_layer=nn.BatchNorm2d,
act_layer=None,
):
super(NextViT, self).__init__()
self.grad_checkpointing = False
self.num_classes = num_classes
norm_layer = get_norm_layer(norm_layer)
if act_layer is None:
act_layer = partial(nn.ReLU, inplace=True)
else:
act_layer = get_act_layer(act_layer)
self.stage_out_chs = [
[96] * (depths[0]),
[192] * (depths[1] - 1) + [256],
[384, 384, 384, 384, 512] * (depths[2] // 5),
[768] * (depths[3] - 1) + [1024]
]
self.feature_info = [dict(
num_chs=sc[-1],
reduction=2**(i + 2),
module=f'stages.{i}'
) for i, sc in enumerate(self.stage_out_chs)]
# Next Hybrid Strategy
self.stage_block_types = [
[NextConvBlock] * depths[0],
[NextConvBlock] * (depths[1] - 1) + [NextTransformerBlock],
[NextConvBlock, NextConvBlock, NextConvBlock, NextConvBlock, NextTransformerBlock] * (depths[2] // 5),
[NextConvBlock] * (depths[3] - 1) + [NextTransformerBlock]]
self.stem = nn.Sequential(
ConvNormAct(in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer),
ConvNormAct(stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer),
ConvNormAct(stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer),
ConvNormAct(stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer),
)
in_chs = out_chs = stem_chs[-1]
stages = []
idx = 0
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
for stage_idx in range(len(depths)):
stage = NextStage(
in_chs=in_chs,
block_chs=self.stage_out_chs[stage_idx],
block_types=self.stage_block_types[stage_idx],
stride=strides[stage_idx],
sr_ratio=sr_ratios[stage_idx],
mix_block_ratio=mix_block_ratio,
head_dim=head_dim,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[stage_idx],
norm_layer=norm_layer,
act_layer=act_layer,
)
in_chs = out_chs = self.stage_out_chs[stage_idx][-1]
stages += [stage]
idx += depths[stage_idx]
self.num_features = out_chs
self.stages = nn.Sequential(*stages)
self.norm = norm_layer(out_chs)
self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes)
self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
self._initialize_weights()
def _initialize_weights(self):
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm', (99999,)),
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'head.fc.weight' in state_dict:
return state_dict # non-original
D = model.state_dict()
out_dict = {}
# remap originals based on order
for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
out_dict[ka] = vb
return out_dict
def _create_nextvit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
NextViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = generate_default_cfgs({
'nextvit_small.bd_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_base.bd_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_large.bd_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_small.bd_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_base.bd_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_large.bd_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_small.bd_ssld_6m_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_base.bd_ssld_6m_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_large.bd_ssld_6m_in1k': _cfg(
hf_hub_id='timm/',
),
'nextvit_small.bd_ssld_6m_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_base.bd_ssld_6m_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'nextvit_large.bd_ssld_6m_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
})
@register_model
def nextvit_small(pretrained=False, **kwargs):
model_args = dict(depths=(3, 4, 10, 3), drop_path_rate=0.1)
model = _create_nextvit(
'nextvit_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def nextvit_base(pretrained=False, **kwargs):
model_args = dict(depths=(3, 4, 20, 3), drop_path_rate=0.2)
model = _create_nextvit(
'nextvit_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def nextvit_large(pretrained=False, **kwargs):
model_args = dict(depths=(3, 4, 30, 3), drop_path_rate=0.2)
model = _create_nextvit(
'nextvit_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model
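End-to-end sketch for the new registrations (no pretrained weights; expected shapes follow stage_out_chs and feature_info above):

import torch
import timm

model = timm.create_model('nextvit_small', pretrained=False)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)                   # torch.Size([1, 1000])

backbone = timm.create_model('nextvit_small', pretrained=False, features_only=True)
print([f.shape for f in backbone(x)])   # 96/256/512/1024 channels at strides 4/8/16/32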


@ -535,7 +535,7 @@ class TinyVit(nn.Module):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool=global_pool)
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.patch_embed(x)


@ -421,6 +421,7 @@ class VisionTransformer(nn.Module):
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
fix_init: bool = False,
embed_layer: Callable = PatchEmbed,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
@ -449,6 +450,7 @@ class VisionTransformer(nn.Module):
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
fix_init: Apply weight initialization fix (scaling w/ layer index).
embed_layer: Patch embedding layer.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
@ -536,8 +538,18 @@ class VisionTransformer(nn.Module):
if weight_init != 'skip':
self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
def fix_init_weight(self):
def rescale(param, _layer_id):
param.div_(math.sqrt(2.0 * _layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
module.init_weights()
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None:
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
if 'jax' in mode:
return partial(init_weights_vit_jax, head_bias=head_bias)
elif 'moco' in mode:
@ -1723,7 +1735,12 @@ default_cfgs = {
input_size=(3, 256, 256)),
'vit_medium_patch16_reg4_gap_256': _cfg(
input_size=(3, 256, 256)),
'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)),
'vit_base_patch16_reg4_gap_256': _cfg(
input_size=(3, 256, 256)),
'vit_so150m_patch16_reg4_gap_256': _cfg(
input_size=(3, 256, 256)),
'vit_so150m_patch16_reg4_map_256': _cfg(
input_size=(3, 256, 256)),
}
_quick_gelu_cfgs = [
@ -2623,13 +2640,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
@register_model
def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
no_embed_class=True, global_pool='avg', reg_tokens=8,
no_embed_class=True, global_pool='avg', reg_tokens=4,
)
model = _create_vision_transformer(
'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
class_token=False, reg_tokens=4, global_pool='map',
)
model = _create_vision_transformer(
'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
)
model = _create_vision_transformer(
'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
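The new fix_init flag and so150m variants can be exercised through create_model kwargs; fix_init divides each block's attn.proj and mlp.fc2 weights by sqrt(2 * layer_index), 1-based. A sketch:

import timm

model = timm.create_model('vit_so150m_patch16_reg4_gap_256', pretrained=False, fix_init=True)
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.1f}M params')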


@ -7,7 +7,12 @@ Hacked together by / Copyright 2022, Ross Wightman
import logging
import math
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Tuple, Type, Union
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn
@ -15,9 +20,11 @@ from torch.jit import Final
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
@ -215,59 +222,61 @@ class VisionTransformerRelPos(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='avg',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
qk_norm=False,
init_values=1e-6,
class_token=False,
fc_norm=False,
rel_pos_type='mlp',
rel_pos_dim=None,
shared_rel_pos=False,
drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='skip',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=RelPosBlock
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: Optional[float] = 1e-6,
class_token: bool = False,
fc_norm: bool = False,
rel_pos_type: str = 'mlp',
rel_pos_dim: Optional[int] = None,
shared_rel_pos: bool = False,
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
fix_init: bool = False,
embed_layer: Type[nn.Module] = PatchEmbed,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = RelPosBlock
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'avg')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_norm (bool): Enable normalization of query and key in attention
init_values: (float): layer-scale init values
class_token (bool): use class token (default: False)
fc_norm (bool): use pre classifier norm instead of pre-pool
rel_pos_type (str): type of relative position
shared_rel_pos (bool): share relative pos across all blocks
drop_rate (float): dropout rate
proj_drop_rate (float): projection dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
img_size: input image size
patch_size: patch size
in_chans: number of input channels
num_classes: number of classes for classification head
global_pool: type of global pooling for final sequence (default: 'avg')
embed_dim: embedding dimension
depth: depth of transformer
num_heads: number of attention heads
mlp_ratio: ratio of mlp hidden dim to embedding dim
qkv_bias: enable bias for qkv if True
qk_norm: Enable normalization of query and key in attention
init_values: layer-scale init values
class_token: use class token (default: False)
fc_norm: use pre classifier norm instead of pre-pool
rel_pos_type: type of relative position
shared_rel_pos: share relative pos across all blocks
drop_rate: dropout rate
proj_drop_rate: projection dropout rate
attn_drop_rate: attention dropout rate
drop_path_rate: stochastic depth rate
weight_init: weight init scheme
fix_init: apply weight initialization fix (scaling w/ layer index)
embed_layer: patch embedding layer
norm_layer: normalization layer
act_layer: MLP activation layer
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
@ -332,13 +341,22 @@ class VisionTransformerRelPos(nn.Module):
if weight_init != 'skip':
self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
def init_weights(self, mode=''):
assert mode in ('jax', 'moco', '')
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
# FIXME weight init scheme using PyTorch defaults currently
#named_apply(get_init_weights_vit(mode, head_bias), self)
named_apply(get_init_weights_vit(mode), self)
def fix_init_weight(self):
def rescale(param, _layer_id):
param.div_(math.sqrt(2.0 * _layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
@torch.jit.ignore
def no_weight_decay(self):


@ -10,6 +10,6 @@ from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
from .model_ema import ModelEma, ModelEmaV2
from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3
from .random import random_seed
from .summary import update_summary, get_outdir


@ -2,18 +2,17 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
from typing import Optional
import torch
from torch import distributed as dist
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from .model import unwrap_model
_logger = logging.getLogger(__name__)
def reduce_tensor(tensor, n):
rt = tensor.clone()
@ -84,9 +83,39 @@ def init_distributed_device(args):
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
result = init_distributed_device_so(
device=getattr(args, 'device', 'cuda'),
dist_backend=getattr(args, 'dist_backend', None),
dist_url=getattr(args, 'dist_url', None),
)
args.device = result['device']
args.world_size = result['world_size']
args.rank = result['global_rank']
args.local_rank = result['local_rank']
args.distributed = result['distributed']
device = torch.device(args.device)
return device
def init_distributed_device_so(
device: str = 'cuda',
dist_backend: Optional[str] = None,
dist_url: Optional[str] = None,
):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
distributed = False
world_size = 1
global_rank = 0
local_rank = 0
if dist_backend is None:
# FIXME sane defaults for other device backends?
dist_backend = 'nccl' if 'cuda' in device else 'gloo'
dist_url = dist_url or 'env://'
# TBD, support horovod?
# if args.horovod:
# import horovod.torch as hvd
# assert hvd is not None, "Horovod is not installed"
# hvd.init()
# args.local_rank = int(hvd.local_rank())
@ -96,42 +125,51 @@ def init_distributed_device(args):
# os.environ['LOCAL_RANK'] = str(args.local_rank)
# os.environ['RANK'] = str(args.rank)
# os.environ['WORLD_SIZE'] = str(args.world_size)
dist_backend = getattr(args, 'dist_backend', 'nccl')
dist_url = getattr(args, 'dist_url', 'env://')
if is_distributed_env():
if 'SLURM_PROCID' in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
local_rank, global_rank, world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['RANK'] = str(global_rank)
os.environ['WORLD_SIZE'] = str(world_size)
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
world_size=args.world_size,
rank=args.rank,
world_size=world_size,
rank=global_rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=dist_backend,
init_method=dist_url,
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
world_size = torch.distributed.get_world_size()
global_rank = torch.distributed.get_rank()
distributed = True
if torch.cuda.is_available():
if args.distributed:
device = 'cuda:%d' % args.local_rank
else:
device = 'cuda:0'
if 'cuda' in device:
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
if distributed and device != 'cpu':
device, *device_idx = device.split(':', maxsplit=1)
# Ignore manually specified device index in distributed mode and
# override with resolved local rank, fewer headaches in most setups.
if device_idx:
_logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
device = f'{device}:{local_rank}'
if device.startswith('cuda:'):
torch.cuda.set_device(device)
else:
device = 'cpu'
args.device = device
device = torch.device(device)
return device
return dict(
device=device,
global_rank=global_rank,
local_rank=local_rank,
world_size=world_size,
distributed=distributed,
)
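The args-free init_distributed_device_so can now be called directly by code that does not carry an argparse namespace; a sketch:

import torch
from timm.utils.distributed import init_distributed_device_so

dist_info = init_distributed_device_so(device='cuda')    # or 'cpu', 'cuda:0', ...
device = torch.device(dist_info['device'])
if dist_info['distributed'] and dist_info['global_rank'] == 0:
    print(f"{dist_info['world_size']} ranks, local_rank={dist_info['local_rank']}")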


@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import logging
from collections import OrderedDict
from copy import deepcopy
from typing import Optional
import torch
import torch.nn as nn
@ -103,7 +104,7 @@ class ModelEmaV2(nn.Module):
GPU assignment and distributed training wrappers.
"""
def __init__(self, model, decay=0.9999, device=None):
super(ModelEmaV2, self).__init__()
super().__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
@ -124,3 +125,136 @@ class ModelEmaV2(nn.Module):
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
class ModelEmaV3(nn.Module):
""" Model Exponential Moving Average V3
Keep a moving average of everything in the model state_dict (parameters and buffers).
V3 of this module leverages for_each and in-place operations for faster performance.
Decay warmup based on code by @crowsonkb, her comments:
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
def __init__(
self,
model,
decay: float = 0.9999,
min_decay: float = 0.0,
update_after_step: int = 0,
use_warmup: bool = False,
warmup_gamma: float = 1.0,
warmup_power: float = 2/3,
device: Optional[torch.device] = None,
foreach: bool = True,
exclude_buffers: bool = False,
):
super().__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
self.module.eval()
self.decay = decay
self.min_decay = min_decay
self.update_after_step = update_after_step
self.use_warmup = use_warmup
self.warmup_gamma = warmup_gamma
self.warmup_power = warmup_power
self.foreach = foreach
self.device = device # perform ema on different device from model if set
self.exclude_buffers = exclude_buffers
if self.device is not None and device != next(model.parameters()).device:
self.foreach = False # cannot use foreach methods with different devices
self.module.to(device=device)
def get_decay(self, step: Optional[int] = None) -> float:
"""
Compute the decay factor for the exponential moving average.
"""
if step is None:
return self.decay
step = max(0, step - self.update_after_step - 1)
if step <= 0:
return 0.0
if self.use_warmup:
decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
decay = max(min(decay, self.decay), self.min_decay)
else:
decay = self.decay
return decay
@torch.no_grad()
def update(self, model, step: Optional[int] = None):
decay = self.get_decay(step)
if self.exclude_buffers:
self.apply_update_no_buffers_(model, decay)
else:
self.apply_update_(model, decay)
def apply_update_(self, model, decay: float):
# interpolate parameters and buffers
if self.foreach:
ema_lerp_values = []
model_lerp_values = []
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if ema_v.is_floating_point():
ema_lerp_values.append(ema_v)
model_lerp_values.append(model_v)
else:
ema_v.copy_(model_v)
if hasattr(torch, '_foreach_lerp_'):
torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay)
else:
torch._foreach_mul_(ema_lerp_values, scalar=decay)
torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay)
else:
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if ema_v.is_floating_point():
ema_v.lerp_(model_v, weight=1. - decay)
else:
ema_v.copy_(model_v)
def apply_update_no_buffers_(self, model, decay: float):
# interpolate parameters, copy buffers
ema_params = tuple(self.module.parameters())
model_params = tuple(model.parameters())
if self.foreach:
if hasattr(torch, '_foreach_lerp_'):
torch._foreach_lerp_(ema_params, model_params, weight=1. - decay)
else:
torch._foreach_mul_(ema_params, scalar=decay)
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
else:
for ema_p, model_p in zip(ema_params, model_params):
ema_p.lerp_(model_p, weight=1. - decay)
for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
ema_b.copy_(model_b.to(device=self.device))
@torch.no_grad()
def set(self, model):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.copy_(model_v.to(device=self.device))
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
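A minimal training-loop sketch for ModelEmaV3 (illustrative values). With warmup_gamma=1 and warmup_power=2/3 the effective decay is min(decay, 1 - (1 + step) ** -(2/3)), reaching ~0.999 around 31.6k updates, matching the docstring above:

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.utils import ModelEmaV3

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
ema = ModelEmaV3(model, decay=0.9998, use_warmup=True, update_after_step=0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(200):
    x = torch.randn(64, 10)
    y = (x.sum(dim=1) > 0).long()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model, step=step)     # decay ramps from 0 toward 0.9998 per get_decay()

# evaluate or checkpoint the averaged weights via ema.module
print(round(ema.get_decay(200), 4), ema.module(torch.randn(1, 10)).shape)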


@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import importlib
import json
import logging
import os
@ -168,6 +169,24 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
# Device & distributed
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
help='torch.cuda.synchronize() end of each step')
group.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--device-modules', default=None, type=str, nargs='+',
help="Python imports for device backend modules.")
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@ -330,11 +349,13 @@ group.add_argument('--split-bn', action='store_true',
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
help='Decay factor for model weights moving average (default: 0.9998)')
group.add_argument('--model-ema-warmup', action='store_true',
help='Enable warmup for model EMA decay.')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
@ -352,16 +373,6 @@ group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
help='torch.cuda.synchronize() end of each step')
group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
@ -374,7 +385,6 @@ group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
@ -402,6 +412,10 @@ def main():
utils.setup_default_logging()
args, args_text = _parse_args()
if args.device_modules:
for module in args.device_modules:
importlib.import_module(module)
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
@ -586,10 +600,16 @@ def main():
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
model_ema = utils.ModelEmaV2(
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
model_ema = utils.ModelEmaV3(
model,
decay=args.model_ema_decay,
use_warmup=args.model_ema_warmup,
device='cpu' if args.model_ema_force_cpu else None,
)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
if args.torchcompile:
model_ema = torch.compile(model_ema, backend=args.torchcompile)
# setup distributed training
if args.distributed:
@ -847,6 +867,7 @@ def main():
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
num_updates_total=num_epochs * updates_per_epoch,
)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@ -860,6 +881,7 @@ def main():
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
)
@ -868,10 +890,11 @@ def main():
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module,
model_ema,
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)
@ -935,6 +958,7 @@ def train_one_epoch(
loss_scaler=None,
model_ema=None,
mixup_fn=None,
num_updates_total=None,
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
@ -1026,7 +1050,7 @@ def train_one_epoch(
num_updates += 1
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model)
model_ema.update(model, step=num_updates)
if args.synchronize_step and device.type == 'cuda':
torch.cuda.synchronize()