diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 822983fe..02a069bd 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -32,16 +32,12 @@ class ToNumpy: class ToTensor: - + """ ToTensor with no rescaling of values""" def __init__(self, dtype=torch.float32): self.dtype = dtype def __call__(self, pil_img): - np_img = np.array(pil_img, dtype=np.uint8) - if np_img.ndim < 3: - np_img = np.expand_dims(np_img, axis=-1) - np_img = np.rollaxis(np_img, 2) # HWC to CHW - return torch.from_numpy(np_img).to(dtype=self.dtype) + return F.pil_to_tensor(pil_img).to(dtype=self.dtype) # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 2eb4ec2e..71e45c87 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module): self.drop = nn.Dropout(drop_rate) self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def reset(self, num_classes, global_pool=None): - if global_pool is not None: - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + def reset(self, num_classes, pool_type=None): + if pool_type is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() self.use_conv = self.global_pool.is_identity() linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear if self.hidden_size: diff --git a/timm/layers/create_act.py b/timm/layers/create_act.py index c473c5a9..93bcbf0e 100644 --- a/timm/layers/create_act.py +++ b/timm/layers/create_act.py @@ -148,7 +148,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): return _ACT_LAYER_DEFAULT[name] -def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): +def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs): act_layer = get_act_layer(name) if act_layer is None: return None diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0eb9561d..6b6963dc 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -39,6 +39,7 @@ from .mobilevit import * from .mvitv2 import * from .nasnet import * from .nest import * +from .nextvit import * from .nfnet import * from .pit import * from .pnasnet import * diff --git a/timm/models/davit.py b/timm/models/davit.py index f00cf733..d4d6ad69 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -547,6 +547,17 @@ class DaVit(nn.Module): if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+).downsample', (0,)), + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)), + ] + ) + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @@ -558,7 +569,7 @@ class DaVit(nn.Module): return self.head.fc def reset_classifier(self, num_classes, global_pool=None): - self.head.reset(num_classes, global_pool=global_pool) + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py new file mode 100644 index 00000000..7ef56a38 --- /dev/null +++ b/timm/models/nextvit.py @@ -0,0 +1,685 @@ +""" Next-ViT + +As described in https://arxiv.org/abs/2207.05501 + +Next-ViT model defs and weights adapted from https://github.com/bytedance/Next-ViT, original copyright below +""" +# Copyright (c) ByteDance Inc. All rights reserved. +from functools import partial + +import torch +import torch.nn.functional as F +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, trunc_normal_, ConvMlp, get_norm_layer, get_act_layer, use_fused_attn +from timm.layers import ClassifierHead +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq +from ._registry import generate_default_cfgs, register_model + + +def merge_pre_bn(module, pre_bn_1, pre_bn_2=None): + """ Merge pre BN to reduce inference runtime. + """ + weight = module.weight.data + if module.bias is None: + zeros = torch.zeros(module.out_chs, device=weight.device).type(weight.type()) + module.bias = nn.Parameter(zeros) + bias = module.bias.data + if pre_bn_2 is None: + assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False" + assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False" + + scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5) + extra_weight = scale_invstd * pre_bn_1.weight + extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd + else: + assert pre_bn_1.track_running_stats is True, "Unsupported bn_module.track_running_stats is False" + assert pre_bn_1.affine is True, "Unsupported bn_module.affine is False" + + assert pre_bn_2.track_running_stats is True, "Unsupported bn_module.track_running_stats is False" + assert pre_bn_2.affine is True, "Unsupported bn_module.affine is False" + + scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5) + scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5) + + extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight + extra_bias = ( + scale_invstd_2 * pre_bn_2.weight + * (pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean) + + pre_bn_2.bias + ) + + if isinstance(module, nn.Linear): + extra_bias = weight @ extra_bias + weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight)) + elif isinstance(module, nn.Conv2d): + assert weight.shape[2] == 1 and weight.shape[3] == 1 + weight = weight.reshape(weight.shape[0], weight.shape[1]) + extra_bias = weight @ extra_bias + weight.mul_(extra_weight.view(1, weight.size(1)).expand_as(weight)) + weight = weight.reshape(weight.shape[0], weight.shape[1], 1, 1) + bias.add_(extra_bias) + + module.weight.data = weight + module.bias.data = bias + + +class ConvNormAct(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=3, + stride=1, + groups=1, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(ConvNormAct, self).__init__() + self.conv = nn.Conv2d( + in_chs, out_chs, kernel_size=kernel_size, stride=stride, + padding=1, groups=groups, bias=False) + self.norm = norm_layer(out_chs) + self.act = act_layer() + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + return x + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class PatchEmbed(nn.Module): + def __init__(self, + in_chs, + out_chs, + stride=1, + norm_layer = nn.BatchNorm2d, + ): + super(PatchEmbed, self).__init__() + + if stride == 2: + self.pool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False) + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False) + self.norm = norm_layer(out_chs) + elif in_chs != out_chs: + self.pool = nn.Identity() + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size=1, stride=1, bias=False) + self.norm = norm_layer(out_chs) + else: + self.pool = nn.Identity() + self.conv = nn.Identity() + self.norm = nn.Identity() + + def forward(self, x): + return self.norm(self.conv(self.pool(x))) + + +class ConvAttention(nn.Module): + """ + Multi-Head Convolutional Attention + """ + + def __init__(self, out_chs, head_dim, norm_layer = nn.BatchNorm2d, act_layer = nn.ReLU): + super(ConvAttention, self).__init__() + self.group_conv3x3 = nn.Conv2d( + out_chs, out_chs, + kernel_size=3, stride=1, padding=1, groups=out_chs // head_dim, bias=False + ) + self.norm = norm_layer(out_chs) + self.act = act_layer() + self.projection = nn.Conv2d(out_chs, out_chs, kernel_size=1, bias=False) + + def forward(self, x): + out = self.group_conv3x3(x) + out = self.norm(out) + out = self.act(out) + out = self.projection(out) + return out + +class NextConvBlock(nn.Module): + """ + Next Convolution Block + """ + + def __init__( + self, + in_chs, + out_chs, + stride=1, + drop_path=0., + drop=0., + head_dim=32, + mlp_ratio=3., + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU + ): + super(NextConvBlock, self).__init__() + self.in_chs = in_chs + self.out_chs = out_chs + assert out_chs % head_dim == 0 + + self.patch_embed = PatchEmbed(in_chs, out_chs, stride, norm_layer=norm_layer) + self.mhca = ConvAttention( + out_chs, + head_dim, + norm_layer=norm_layer, + act_layer=act_layer, + ) + self.attn_drop_path = DropPath(drop_path) + + self.norm = norm_layer(out_chs) + self.mlp = ConvMlp( + out_chs, + hidden_features=int(out_chs * mlp_ratio), + drop=drop, + bias=True, + act_layer=act_layer, + ) + self.mlp_drop_path = DropPath(drop_path) + self.is_fused = False + + @torch.no_grad() + def reparameterize(self): + if not self.is_fused: + merge_pre_bn(self.mlp.fc1, self.norm) + self.norm = None + self.is_fused = True + + def forward(self, x): + x = self.patch_embed(x) + x = x + self.attn_drop_path(self.mhca(x)) + + out = self.norm(x) + x = x + self.mlp_drop_path(self.mlp(out)) + return x + + +class EfficientAttention(nn.Module): + """ + Efficient Multi-Head Self Attention + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim, + out_dim=None, + head_dim=32, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + sr_ratio=1, + norm_layer=nn.BatchNorm1d, + ): + super().__init__() + self.dim = dim + self.out_dim = out_dim if out_dim is not None else dim + self.num_heads = self.dim // head_dim + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.q = nn.Linear(dim, self.dim, bias=qkv_bias) + self.k = nn.Linear(dim, self.dim, bias=qkv_bias) + self.v = nn.Linear(dim, self.dim, bias=qkv_bias) + self.proj = nn.Linear(self.dim, self.out_dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + self.N_ratio = sr_ratio ** 2 + if sr_ratio > 1: + self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio) + self.norm = norm_layer(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + if self.sr is not None: + x = self.sr(x.transpose(1, 2)) + x = self.norm(x).transpose(1, 2) + + k = self.k(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v(x).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-1, -2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class NextTransformerBlock(nn.Module): + """ + Next Transformer Block + """ + + def __init__( + self, + in_chs, + out_chs, + drop_path, + stride=1, + sr_ratio=1, + mlp_ratio=2, + head_dim=32, + mix_block_ratio=0.75, + attn_drop=0., + drop=0., + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(NextTransformerBlock, self).__init__() + self.in_chs = in_chs + self.out_chs = out_chs + self.mix_block_ratio = mix_block_ratio + + self.mhsa_out_chs = _make_divisible(int(out_chs * mix_block_ratio), 32) + self.mhca_out_chs = out_chs - self.mhsa_out_chs + + self.patch_embed = PatchEmbed(in_chs, self.mhsa_out_chs, stride) + self.norm1 = norm_layer(self.mhsa_out_chs) + self.e_mhsa = EfficientAttention( + self.mhsa_out_chs, + head_dim=head_dim, + sr_ratio=sr_ratio, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.mhsa_drop_path = DropPath(drop_path * mix_block_ratio) + + self.projection = PatchEmbed(self.mhsa_out_chs, self.mhca_out_chs, stride=1, norm_layer=norm_layer) + self.mhca = ConvAttention( + self.mhca_out_chs, + head_dim=head_dim, + norm_layer=norm_layer, + act_layer=act_layer, + ) + self.mhca_drop_path = DropPath(drop_path * (1 - mix_block_ratio)) + + self.norm2 = norm_layer(out_chs) + self.mlp = ConvMlp( + out_chs, + hidden_features=int(out_chs * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.mlp_drop_path = DropPath(drop_path) + self.is_fused = False + + @torch.no_grad() + def reparameterize(self): + if not self.is_fused: + merge_pre_bn(self.e_mhsa.q, self.norm1) + if self.e_mhsa.norm is not None: + merge_pre_bn(self.e_mhsa.k, self.norm1, self.e_mhsa.norm) + merge_pre_bn(self.e_mhsa.v, self.norm1, self.e_mhsa.norm) + self.e_mhsa.norm = nn.Identity() + else: + merge_pre_bn(self.e_mhsa.k, self.norm1) + merge_pre_bn(self.e_mhsa.v, self.norm1) + self.norm1 = nn.Identity() + + merge_pre_bn(self.mlp.fc1, self.norm2) + self.norm2 = nn.Identity() + self.is_fused = True + + def forward(self, x): + x = self.patch_embed(x) + B, C, H, W = x.shape + + out = self.norm1(x) + out = out.reshape(B, C, -1).transpose(-1, -2) + out = self.mhsa_drop_path(self.e_mhsa(out)) + x = x + out.transpose(-1, -2).reshape(B, C, H, W) + + out = self.projection(x) + out = out + self.mhca_drop_path(self.mhca(out)) + x = torch.cat([x, out], dim=1) + + out = self.norm2(x) + x = x + self.mlp_drop_path(self.mlp(out)) + return x + + +class NextStage(nn.Module): + + def __init__( + self, + in_chs, + block_chs, + block_types, + stride=2, + sr_ratio=1, + mix_block_ratio=1.0, + drop=0., + attn_drop=0., + drop_path=0., + head_dim=32, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super().__init__() + self.grad_checkpointing = False + + blocks = [] + for block_idx, block_type in enumerate(block_types): + stride = stride if block_idx == 0 else 1 + out_chs = block_chs[block_idx] + block_type = block_types[block_idx] + dpr = drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path + if block_type is NextConvBlock: + layer = NextConvBlock( + in_chs, + out_chs, + stride=stride, + drop_path=dpr, + drop=drop, + head_dim=head_dim, + norm_layer=norm_layer, + act_layer=act_layer, + ) + blocks.append(layer) + elif block_type is NextTransformerBlock: + layer = NextTransformerBlock( + in_chs, + out_chs, + drop_path=dpr, + stride=stride, + sr_ratio=sr_ratio, + head_dim=head_dim, + mix_block_ratio=mix_block_ratio, + attn_drop=attn_drop, + drop=drop, + norm_layer=norm_layer, + act_layer=act_layer, + ) + blocks.append(layer) + in_chs = out_chs + + self.blocks = nn.Sequential(*blocks) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + def forward(self, x): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class NextViT(nn.Module): + def __init__( + self, + in_chans, + num_classes=1000, + global_pool='avg', + stem_chs=(64, 32, 64), + depths=(3, 4, 10, 3), + strides=(1, 2, 2, 2), + sr_ratios=(8, 4, 2, 1), + drop_path_rate=0.1, + attn_drop_rate=0., + drop_rate=0., + head_dim=32, + mix_block_ratio=0.75, + norm_layer=nn.BatchNorm2d, + act_layer=None, + ): + super(NextViT, self).__init__() + self.grad_checkpointing = False + self.num_classes = num_classes + norm_layer = get_norm_layer(norm_layer) + if act_layer is None: + act_layer = partial(nn.ReLU, inplace=True) + else: + act_layer = get_act_layer(act_layer) + + self.stage_out_chs = [ + [96] * (depths[0]), + [192] * (depths[1] - 1) + [256], + [384, 384, 384, 384, 512] * (depths[2] // 5), + [768] * (depths[3] - 1) + [1024] + ] + self.feature_info = [dict( + num_chs=sc[-1], + reduction=2**(i + 2), + module=f'stages.{i}' + ) for i, sc in enumerate(self.stage_out_chs)] + + # Next Hybrid Strategy + self.stage_block_types = [ + [NextConvBlock] * depths[0], + [NextConvBlock] * (depths[1] - 1) + [NextTransformerBlock], + [NextConvBlock, NextConvBlock, NextConvBlock, NextConvBlock, NextTransformerBlock] * (depths[2] // 5), + [NextConvBlock] * (depths[3] - 1) + [NextTransformerBlock]] + + self.stem = nn.Sequential( + ConvNormAct(in_chans, stem_chs[0], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer), + ConvNormAct(stem_chs[0], stem_chs[1], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer), + ConvNormAct(stem_chs[1], stem_chs[2], kernel_size=3, stride=1, norm_layer=norm_layer, act_layer=act_layer), + ConvNormAct(stem_chs[2], stem_chs[2], kernel_size=3, stride=2, norm_layer=norm_layer, act_layer=act_layer), + ) + in_chs = out_chs = stem_chs[-1] + stages = [] + idx = 0 + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + for stage_idx in range(len(depths)): + stage = NextStage( + in_chs=in_chs, + block_chs=self.stage_out_chs[stage_idx], + block_types=self.stage_block_types[stage_idx], + stride=strides[stage_idx], + sr_ratio=sr_ratios[stage_idx], + mix_block_ratio=mix_block_ratio, + head_dim=head_dim, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[stage_idx], + norm_layer=norm_layer, + act_layer=act_layer, + ) + in_chs = out_chs = self.stage_out_chs[stage_idx][-1] + stages += [stage] + idx += depths[stage_idx] + self.num_features = out_chs + self.stages = nn.Sequential(*stages) + self.norm = norm_layer(out_chs) + self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes) + + self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))] + self._initialize_weights() + + def _initialize_weights(self): + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', # stem and embed + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + for stage in self.stages: + stage.set_grad_checkpointing(enable=enable) + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'head.fc.weight' in state_dict: + return state_dict # non-original + + D = model.state_dict() + out_dict = {} + # remap originals based on order + for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()): + out_dict[ka] = vb + + return out_dict + + +def _create_nextvit(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + NextViT, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'nextvit_small.bd_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_base.bd_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_large.bd_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_small.bd_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_base.bd_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_large.bd_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + + 'nextvit_small.bd_ssld_6m_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_base.bd_ssld_6m_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_large.bd_ssld_6m_in1k': _cfg( + hf_hub_id='timm/', + ), + 'nextvit_small.bd_ssld_6m_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_base.bd_ssld_6m_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), + 'nextvit_large.bd_ssld_6m_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, + ), +}) + + +@register_model +def nextvit_small(pretrained=False, **kwargs): + model_args = dict(depths=(3, 4, 10, 3), drop_path_rate=0.1) + model = _create_nextvit( + 'nextvit_small', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def nextvit_base(pretrained=False, **kwargs): + model_args = dict(depths=(3, 4, 20, 3), drop_path_rate=0.2) + model = _create_nextvit( + 'nextvit_base', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def nextvit_large(pretrained=False, **kwargs): + model_args = dict(depths=(3, 4, 30, 3), drop_path_rate=0.2) + model = _create_nextvit( + 'nextvit_large', pretrained=pretrained, **dict(model_args, **kwargs)) + return model diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 96a88db7..b4b29648 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -535,7 +535,7 @@ class TinyVit(nn.Module): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - self.head.reset(num_classes, global_pool=global_pool) + self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): x = self.patch_embed(x) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 380e3a64..70f91d58 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -421,6 +421,7 @@ class VisionTransformer(nn.Module): attn_drop_rate: float = 0., drop_path_rate: float = 0., weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', + fix_init: bool = False, embed_layer: Callable = PatchEmbed, norm_layer: Optional[LayerType] = None, act_layer: Optional[LayerType] = None, @@ -449,6 +450,7 @@ class VisionTransformer(nn.Module): attn_drop_rate: Attention dropout rate. drop_path_rate: Stochastic depth rate. weight_init: Weight initialization scheme. + fix_init: Apply weight initialization fix (scaling w/ layer index). embed_layer: Patch embedding layer. norm_layer: Normalization layer. act_layer: MLP activation layer. @@ -536,8 +538,18 @@ class VisionTransformer(nn.Module): if weight_init != 'skip': self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() - def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None: + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = '') -> None: assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) @@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None: module.init_weights() -def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None: +def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: if 'jax' in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif 'moco' in mode: @@ -1723,7 +1735,12 @@ default_cfgs = { input_size=(3, 256, 256)), 'vit_medium_patch16_reg4_gap_256': _cfg( input_size=(3, 256, 256)), - 'vit_base_patch16_reg8_gap_256': _cfg(input_size=(3, 256, 256)), + 'vit_base_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_so150m_patch16_reg4_gap_256': _cfg( + input_size=(3, 256, 256)), + 'vit_so150m_patch16_reg4_map_256': _cfg( + input_size=(3, 256, 256)), } _quick_gelu_cfgs = [ @@ -2623,13 +2640,35 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio @register_model -def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: +def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, - no_embed_class=True, global_pool='avg', reg_tokens=8, + no_embed_class=True, global_pool='avg', reg_tokens=4, ) model = _create_vision_transformer( - 'vit_base_patch16_reg8_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + 'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 2cd37cfe..ea8cf0ea 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -7,7 +7,12 @@ Hacked together by / Copyright 2022, Ross Wightman import logging import math from functools import partial -from typing import Optional, Tuple +from typing import Optional, Tuple, Type, Union + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal import torch import torch.nn as nn @@ -15,9 +20,11 @@ from torch.jit import Final from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn +from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType from ._builder import build_model_with_cfg +from ._manipulate import named_apply from ._registry import generate_default_cfgs, register_model +from .vision_transformer import get_init_weights_vit __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this @@ -215,59 +222,61 @@ class VisionTransformerRelPos(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - global_pool='avg', - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=True, - qk_norm=False, - init_values=1e-6, - class_token=False, - fc_norm=False, - rel_pos_type='mlp', - rel_pos_dim=None, - shared_rel_pos=False, - drop_rate=0., - proj_drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - weight_init='skip', - embed_layer=PatchEmbed, - norm_layer=None, - act_layer=None, - block_fn=RelPosBlock + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal['', 'avg', 'token', 'map'] = 'avg', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = 1e-6, + class_token: bool = False, + fc_norm: bool = False, + rel_pos_type: str = 'mlp', + rel_pos_dim: Optional[int] = None, + shared_rel_pos: bool = False, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip', + fix_init: bool = False, + embed_layer: Type[nn.Module] = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = RelPosBlock ): """ Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - global_pool (str): type of global pooling for final sequence (default: 'avg') - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - qk_norm (bool): Enable normalization of query and key in attention - init_values: (float): layer-scale init values - class_token (bool): use class token (default: False) - fc_norm (bool): use pre classifier norm instead of pre-pool - rel_pos_ty pe (str): type of relative position - shared_rel_pos (bool): share relative pos across all blocks - drop_rate (float): dropout rate - proj_drop_rate (float): projection dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - weight_init (str): weight init scheme - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer - act_layer: (nn.Module): MLP activation layer + img_size: input image size + patch_size: patch size + in_chans: number of input channels + num_classes: number of classes for classification head + global_pool: type of global pooling for final sequence (default: 'avg') + embed_dim: embedding dimension + depth: depth of transformer + num_heads: number of attention heads + mlp_ratio: ratio of mlp hidden dim to embedding dim + qkv_bias: enable bias for qkv if True + qk_norm: Enable normalization of query and key in attention + init_values: layer-scale init values + class_token: use class token (default: False) + fc_norm: use pre classifier norm instead of pre-pool + rel_pos_type: type of relative position + shared_rel_pos: share relative pos across all blocks + drop_rate: dropout rate + proj_drop_rate: projection dropout rate + attn_drop_rate: attention dropout rate + drop_path_rate: stochastic depth rate + weight_init: weight init scheme + fix_init: apply weight initialization fix (scaling w/ layer index) + embed_layer: patch embedding layer + norm_layer: normalization layer + act_layer: MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') @@ -332,13 +341,22 @@ class VisionTransformerRelPos(nn.Module): if weight_init != 'skip': self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() def init_weights(self, mode=''): assert mode in ('jax', 'moco', '') if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) - # FIXME weight init scheme using PyTorch defaults curently - #named_apply(get_init_weights_vit(mode, head_bias), self) + named_apply(get_init_weights_vit(mode), self) + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) @torch.jit.ignore def no_weight_decay(self): diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 63fcf4c5..4c6a00ca 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -10,6 +10,6 @@ from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg, ParseKwargs from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model -from .model_ema import ModelEma, ModelEmaV2 +from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3 from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index ee9a358c..286e8ba4 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -2,18 +2,17 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import logging import os +from typing import Optional import torch from torch import distributed as dist -try: - import horovod.torch as hvd -except ImportError: - hvd = None - from .model import unwrap_model +_logger = logging.getLogger(__name__) + def reduce_tensor(tensor, n): rt = tensor.clone() @@ -84,9 +83,39 @@ def init_distributed_device(args): args.world_size = 1 args.rank = 0 # global rank args.local_rank = 0 + result = init_distributed_device_so( + device=getattr(args, 'device', 'cuda'), + dist_backend=getattr(args, 'dist_backend', None), + dist_url=getattr(args, 'dist_url', None), + ) + args.device = result['device'] + args.world_size = result['world_size'] + args.rank = result['global_rank'] + args.local_rank = result['local_rank'] + args.distributed = result['distributed'] + device = torch.device(args.device) + return device + + +def init_distributed_device_so( + device: str = 'cuda', + dist_backend: Optional[str] = None, + dist_url: Optional[str] = None, +): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + distributed = False + world_size = 1 + global_rank = 0 + local_rank = 0 + if dist_backend is None: + # FIXME sane defaults for other device backends? + dist_backend = 'nccl' if 'cuda' in device else 'gloo' + dist_url = dist_url or 'env://' # TBD, support horovod? # if args.horovod: + # import horovod.torch as hvd # assert hvd is not None, "Horovod is not installed" # hvd.init() # args.local_rank = int(hvd.local_rank()) @@ -96,42 +125,51 @@ def init_distributed_device(args): # os.environ['LOCAL_RANK'] = str(args.local_rank) # os.environ['RANK'] = str(args.rank) # os.environ['WORLD_SIZE'] = str(args.world_size) - dist_backend = getattr(args, 'dist_backend', 'nccl') - dist_url = getattr(args, 'dist_url', 'env://') if is_distributed_env(): if 'SLURM_PROCID' in os.environ: # DDP via SLURM - args.local_rank, args.rank, args.world_size = world_info_from_env() + local_rank, global_rank, world_size = world_info_from_env() # SLURM var -> torch.distributed vars in case needed - os.environ['LOCAL_RANK'] = str(args.local_rank) - os.environ['RANK'] = str(args.rank) - os.environ['WORLD_SIZE'] = str(args.world_size) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['RANK'] = str(global_rank) + os.environ['WORLD_SIZE'] = str(world_size) torch.distributed.init_process_group( backend=dist_backend, init_method=dist_url, - world_size=args.world_size, - rank=args.rank, + world_size=world_size, + rank=global_rank, ) else: # DDP via torchrun, torch.distributed.launch - args.local_rank, _, _ = world_info_from_env() + local_rank, _, _ = world_info_from_env() torch.distributed.init_process_group( backend=dist_backend, init_method=dist_url, ) - args.world_size = torch.distributed.get_world_size() - args.rank = torch.distributed.get_rank() - args.distributed = True + world_size = torch.distributed.get_world_size() + global_rank = torch.distributed.get_rank() + distributed = True - if torch.cuda.is_available(): - if args.distributed: - device = 'cuda:%d' % args.local_rank - else: - device = 'cuda:0' + if 'cuda' in device: + assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.' + + if distributed and device != 'cpu': + device, *device_idx = device.split(':', maxsplit=1) + + # Ignore manually specified device index in distributed mode and + # override with resolved local rank, fewer headaches in most setups. + if device_idx: + _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).') + + device = f'{device}:{local_rank}' + + if device.startswith('cuda:'): torch.cuda.set_device(device) - else: - device = 'cpu' - args.device = device - device = torch.device(device) - return device + return dict( + device=device, + global_rank=global_rank, + local_rank=local_rank, + world_size=world_size, + distributed=distributed, + ) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 073d5c5e..3e491675 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman import logging from collections import OrderedDict from copy import deepcopy +from typing import Optional import torch import torch.nn as nn @@ -103,7 +104,7 @@ class ModelEmaV2(nn.Module): GPU assignment and distributed training wrappers. """ def __init__(self, model, decay=0.9999, device=None): - super(ModelEmaV2, self).__init__() + super().__init__() # make a copy of the model for accumulating moving average of weights self.module = deepcopy(model) self.module.eval() @@ -124,3 +125,136 @@ class ModelEmaV2(nn.Module): def set(self, model): self._update(model, update_fn=lambda e, m: m) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + +class ModelEmaV3(nn.Module): + """ Model Exponential Moving Average V3 + + Keep a moving average of everything in the model state_dict (parameters and buffers). + V3 of this module leverages for_each and in-place operations for faster performance. + + Decay warmup based on code by @crowsonkb, her comments: + If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are + good values for models you plan to train for a million or more steps (reaches decay + factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models + you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at + 215.4k steps). + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__( + self, + model, + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_warmup: bool = False, + warmup_gamma: float = 1.0, + warmup_power: float = 2/3, + device: Optional[torch.device] = None, + foreach: bool = True, + exclude_buffers: bool = False, + ): + super().__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_warmup = use_warmup + self.warmup_gamma = warmup_gamma + self.warmup_power = warmup_power + self.foreach = foreach + self.device = device # perform ema on different device from model if set + self.exclude_buffers = exclude_buffers + if self.device is not None and device != next(model.parameters()).device: + self.foreach = False # cannot use foreach methods with different devices + self.module.to(device=device) + + def get_decay(self, step: Optional[int] = None) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + if step is None: + return self.decay + + step = max(0, step - self.update_after_step - 1) + if step <= 0: + return 0.0 + + if self.use_warmup: + decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power + decay = max(min(decay, self.decay), self.min_decay) + else: + decay = self.decay + + return decay + + @torch.no_grad() + def update(self, model, step: Optional[int] = None): + decay = self.get_decay(step) + if self.exclude_buffers: + self.apply_update_no_buffers_(model, decay) + else: + self.apply_update_(model, decay) + + def apply_update_(self, model, decay: float): + # interpolate parameters and buffers + if self.foreach: + ema_lerp_values = [] + model_lerp_values = [] + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if ema_v.is_floating_point(): + ema_lerp_values.append(ema_v) + model_lerp_values.append(model_v) + else: + ema_v.copy_(model_v) + + if hasattr(torch, '_foreach_lerp_'): + torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1. - decay) + else: + torch._foreach_mul_(ema_lerp_values, scalar=decay) + torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1. - decay) + else: + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if ema_v.is_floating_point(): + ema_v.lerp_(model_v, weight=1. - decay) + else: + ema_v.copy_(model_v) + + def apply_update_no_buffers_(self, model, decay: float): + # interpolate parameters, copy buffers + ema_params = tuple(self.module.parameters()) + model_params = tuple(model.parameters()) + if self.foreach: + if hasattr(torch, '_foreach_lerp_'): + torch._foreach_lerp_(ema_params, model_params, weight=1. - decay) + else: + torch._foreach_mul_(ema_params, scalar=decay) + torch._foreach_add_(ema_params, model_params, alpha=1 - decay) + else: + for ema_p, model_p in zip(ema_params, model_params): + ema_p.lerp_(model_p, weight=1. - decay) + + for ema_b, model_b in zip(self.module.buffers(), model.buffers()): + ema_b.copy_(model_b.to(device=self.device)) + + @torch.no_grad() + def set(self, model): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + ema_v.copy_(model_v.to(device=self.device)) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) \ No newline at end of file diff --git a/train.py b/train.py index ba917773..539dff3d 100755 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ import argparse +import importlib import json import logging import os @@ -168,6 +169,24 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor', help="Enable compilation w/ specified backend (default: inductor).") +# Device & distributed +group = parser.add_argument_group('Device parameters') +group.add_argument('--device', default='cuda', type=str, + help="Device (accelerator) to use.") +group.add_argument('--amp', action='store_true', default=False, + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') +group.add_argument('--amp-dtype', default='float16', type=str, + help='lower precision AMP dtype (default: float16)') +group.add_argument('--amp-impl', default='native', type=str, + help='AMP impl to use, "native" or "apex" (default: native)') +group.add_argument('--no-ddp-bb', action='store_true', default=False, + help='Force broadcast buffers for native DDP to off.') +group.add_argument('--synchronize-step', action='store_true', default=False, + help='torch.cuda.synchronize() end of each step') +group.add_argument("--local_rank", default=0, type=int) +parser.add_argument('--device-modules', default=None, type=str, nargs='+', + help="Python imports for device backend modules.") + # Optimizer parameters group = parser.add_argument_group('Optimizer parameters') group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', @@ -330,11 +349,13 @@ group.add_argument('--split-bn', action='store_true', # Model Exponential Moving Average group = parser.add_argument_group('Model exponential moving average parameters') group.add_argument('--model-ema', action='store_true', default=False, - help='Enable tracking moving average of model weights') + help='Enable tracking moving average of model weights.') group.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') group.add_argument('--model-ema-decay', type=float, default=0.9998, - help='decay factor for model weights moving average (default: 0.9998)') + help='Decay factor for model weights moving average (default: 0.9998)') +group.add_argument('--model-ema-warmup', action='store_true', + help='Enable warmup for model EMA decay.') # Misc group = parser.add_argument_group('Miscellaneous parameters') @@ -352,16 +373,6 @@ group.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 4)') group.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') -group.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') -group.add_argument('--amp-dtype', default='float16', type=str, - help='lower precision AMP dtype (default: float16)') -group.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') -group.add_argument('--no-ddp-bb', action='store_true', default=False, - help='Force broadcast buffers for native DDP to off.') -group.add_argument('--synchronize-step', action='store_true', default=False, - help='torch.cuda.synchronize() end of each step') group.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') group.add_argument('--no-prefetcher', action='store_true', default=False, @@ -374,7 +385,6 @@ group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METR help='Best metric (default: "top1"') group.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') -group.add_argument("--local_rank", default=0, type=int) group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') group.add_argument('--log-wandb', action='store_true', default=False, @@ -402,6 +412,10 @@ def main(): utils.setup_default_logging() args, args_text = _parse_args() + if args.device_modules: + for module in args.device_modules: + importlib.import_module(module) + if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True @@ -586,10 +600,16 @@ def main(): model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper - model_ema = utils.ModelEmaV2( - model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) + model_ema = utils.ModelEmaV3( + model, + decay=args.model_ema_decay, + use_warmup=args.model_ema_warmup, + device='cpu' if args.model_ema_force_cpu else None, + ) if args.resume: load_checkpoint(model_ema.module, args.resume, use_ema=True) + if args.torchcompile: + model_ema = torch.compile(model_ema, backend=args.torchcompile) # setup distributed training if args.distributed: @@ -847,6 +867,7 @@ def main(): loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + num_updates_total=num_epochs * updates_per_epoch, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -860,6 +881,7 @@ def main(): loader_eval, validate_loss_fn, args, + device=device, amp_autocast=amp_autocast, ) @@ -868,10 +890,11 @@ def main(): utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') ema_eval_metrics = validate( - model_ema.module, + model_ema, loader_eval, validate_loss_fn, args, + device=device, amp_autocast=amp_autocast, log_suffix=' (EMA)', ) @@ -935,6 +958,7 @@ def train_one_epoch( loss_scaler=None, model_ema=None, mixup_fn=None, + num_updates_total=None, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -1026,7 +1050,7 @@ def train_one_epoch( num_updates += 1 optimizer.zero_grad() if model_ema is not None: - model_ema.update(model) + model_ema.update(model, step=num_updates) if args.synchronize_step and device.type == 'cuda': torch.cuda.synchronize()