Some RepVit tweaks

* add head dropout to RepVit, as all models have that arg
* default training to the non-distilled head output via a distilled_training flag (set_distilled_training) so fine-tuning works by default w/o a distillation script (usage sketch below)
* camel case naming tweaks to match other models
Ross Wightman 2023-08-09 12:40:26 -07:00
parent f6771909ff
commit c692715388
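
A minimal usage sketch of the two new knobs (not part of the commit; the repvit_m1 variant name is illustrative, any registered RepVit variant behaves the same): drop_rate is forwarded to the head dropout, and set_distilled_training(True) is only needed by a distillation-aware training script.

import timm
import torch

# drop_rate is a new constructor arg added by this commit; it feeds the head dropout.
model = timm.create_model('repvit_m1', pretrained=False, num_classes=10, drop_rate=0.1)

# Default fine-tuning path: distilled_training stays False, so even in train mode the
# forward pass returns a single logits tensor and a plain training loop just works.
model.train()
logits = model(torch.randn(2, 3, 224, 224))      # Tensor of shape (2, 10)

# A distillation script opts in explicitly and receives both head outputs.
model.set_distilled_training(True)
cls_logits, dist_logits = model(torch.randn(2, 3, 224, 224))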


@@ -15,7 +15,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
 Adapted from official impl at https://github.com/jameslahm/RepViT
 """

-__all__ = ['RepViT']
+__all__ = ['RepVit']

 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -81,7 +81,7 @@ class NormLinear(nn.Sequential):
         return m


-class RepVGGDW(nn.Module):
+class RepVggDw(nn.Module):
     def __init__(self, ed, kernel_size):
         super().__init__()
         self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
@@ -115,7 +115,7 @@ class RepVGGDW(nn.Module):
         return conv


-class RepViTMlp(nn.Module):
+class RepVitMlp(nn.Module):
     def __init__(self, in_dim, hidden_dim, act_layer):
         super().__init__()
         self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
@@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
     def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
         super(RepViTBlock, self).__init__()

-        self.token_mixer = RepVGGDW(in_dim, kernel_size)
+        self.token_mixer = RepVggDw(in_dim, kernel_size)
         self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
-        self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)
+        self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)

     def forward(self, x):
         x = self.token_mixer(x)
@@ -142,7 +142,7 @@ class RepViTBlock(nn.Module):
         return identity + x


-class RepViTStem(nn.Module):
+class RepVitStem(nn.Module):
     def __init__(self, in_chs, out_chs, act_layer):
         super().__init__()
         self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
@@ -154,13 +154,13 @@ class RepViTStem(nn.Module):
         return self.conv2(self.act1(self.conv1(x)))


-class RepViTDownsample(nn.Module):
+class RepVitDownsample(nn.Module):
     def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
         super().__init__()
         self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
         self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
         self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
-        self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)
+        self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)

     def forward(self, x):
         x = self.pre_block(x)
@@ -171,22 +171,25 @@ class RepViTDownsample(nn.Module):
         return x + identity


-class RepViTClassifier(nn.Module):
-    def __init__(self, dim, num_classes, distillation=False):
+class RepVitClassifier(nn.Module):
+    def __init__(self, dim, num_classes, distillation=False, drop=0.):
         super().__init__()
+        self.head_drop = nn.Dropout(drop)
         self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
         self.distillation = distillation
-        self.num_classes=num_classes
+        self.distilled_training = False
+        self.num_classes = num_classes
         if distillation:
             self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()

     def forward(self, x):
+        x = self.head_drop(x)
         if self.distillation:
             x1, x2 = self.head(x), self.head_dist(x)
-            if (not self.training) or torch.jit.is_scripting():
-                return (x1 + x2) / 2
-            else:
+            if self.training and self.distilled_training and not torch.jit.is_scripting():
                 return x1, x2
+            else:
+                return (x1 + x2) / 2
         else:
             x = self.head(x)
             return x
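
To make the new head behaviour concrete, a small sketch of the classifier in isolation (assumes RepVitClassifier can be imported from timm.models.repvit; the feature dim is arbitrary): with distilled_training left at its default of False, train mode still yields the averaged logits, so ordinary fine-tuning is unaffected, and only flipping the flag exposes both heads.

import torch
from timm.models.repvit import RepVitClassifier  # internal class, import path assumed

head = RepVitClassifier(dim=384, num_classes=1000, distillation=True, drop=0.1)
head.train()
x = torch.randn(4, 384)

out = head(x)                    # default: averaged logits, a single (4, 1000) tensor

head.distilled_training = True   # what RepVit.set_distilled_training(True) toggles
cls_out, dist_out = head(x)      # now a tuple of two (4, 1000) tensors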
@@ -207,11 +210,11 @@ class RepViTClassifier(nn.Module):
         return head


-class RepViTStage(nn.Module):
+class RepVitStage(nn.Module):
     def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
         super().__init__()
         if downsample:
-            self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
+            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
         else:
             assert in_dim == out_dim
             self.downsample = nn.Identity()
@@ -230,7 +233,7 @@ class RepViTStage(nn.Module):
         return x


-class RepViT(nn.Module):
+class RepVit(nn.Module):
     def __init__(
             self,
             in_chans=3,
@@ -243,15 +246,16 @@ class RepViT(nn.Module):
             num_classes=1000,
             act_layer=nn.GELU,
             distillation=True,
+            drop_rate=0.,
     ):
-        super(RepViT, self).__init__()
+        super(RepVit, self).__init__()

         self.grad_checkpointing = False
         self.global_pool = global_pool
         self.embed_dim = embed_dim
         self.num_classes = num_classes

         in_dim = embed_dim[0]
-        self.stem = RepViTStem(in_chans, in_dim, act_layer)
+        self.stem = RepVitStem(in_chans, in_dim, act_layer)
         stride = self.stem.stride
         resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
@@ -263,7 +267,7 @@ class RepViT(nn.Module):
         for i in range(num_stages):
             downsample = True if i != 0 else False
             stages.append(
-                RepViTStage(
+                RepVitStage(
                     in_dim,
                     embed_dim[i],
                     depth[i],
@@ -281,7 +285,8 @@ class RepViT(nn.Module):
         self.stages = nn.Sequential(*stages)

         self.num_features = embed_dim[-1]
-        self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)
+        self.head_drop = nn.Dropout(drop_rate)
+        self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)

     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -304,9 +309,13 @@ class RepViT(nn.Module):
         if global_pool is not None:
             self.global_pool = global_pool
         self.head = (
-            RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
+            RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
         )

+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.head.distilled_training = enable
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -317,8 +326,9 @@ class RepViT(nn.Module):

     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool == 'avg':
-            x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
-        return x if pre_logits else self.head(x)
+            x = x.mean((2, 3), keepdim=False)
+        x = self.head_drop(x)
+        return self.head(x)

     def forward(self, x):
         x = self.forward_features(x)
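
The pooling swap in forward_head is numerically equivalent for global_pool='avg': a mean over the spatial dims produces the same (N, C) tensor as the previous adaptive_avg_pool2d call, as this quick check (not from the commit) shows.

import torch
import torch.nn as nn

x = torch.randn(2, 64, 7, 7)
old = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)   # previous pooling
new = x.mean((2, 3), keepdim=False)                        # replacement in this commit
assert old.shape == new.shape == (2, 64)
assert torch.allclose(old, new, atol=1e-6)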
@@ -373,7 +383,9 @@ default_cfgs = generate_default_cfgs(

 def _create_repvit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
-        RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs
+        RepVit, variant, pretrained,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs,
     )
     return model