Some RepVit tweaks

* add head dropout to RepVit as all models have that arg
* default train to non-distilled head output via distilled_training flag (set_distilled_training) so fine-tune works by default w/o distillation script
* camel case naming tweaks to match other models
This commit is contained in:
Ross Wightman 2023-08-09 12:40:26 -07:00
parent f6771909ff
commit c692715388

View File

@ -15,7 +15,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
Adapted from official impl at https://github.com/jameslahm/RepViT
"""
__all__ = ['RepViT']
__all__ = ['RepVit']
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -81,7 +81,7 @@ class NormLinear(nn.Sequential):
return m
class RepVGGDW(nn.Module):
class RepVggDw(nn.Module):
def __init__(self, ed, kernel_size):
super().__init__()
self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
@ -115,7 +115,7 @@ class RepVGGDW(nn.Module):
return conv
class RepViTMlp(nn.Module):
class RepVitMlp(nn.Module):
def __init__(self, in_dim, hidden_dim, act_layer):
super().__init__()
self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
super(RepViTBlock, self).__init__()
self.token_mixer = RepVGGDW(in_dim, kernel_size)
self.token_mixer = RepVggDw(in_dim, kernel_size)
self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)
self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)
def forward(self, x):
x = self.token_mixer(x)
@ -142,7 +142,7 @@ class RepViTBlock(nn.Module):
return identity + x
class RepViTStem(nn.Module):
class RepVitStem(nn.Module):
def __init__(self, in_chs, out_chs, act_layer):
super().__init__()
self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
@ -154,13 +154,13 @@ class RepViTStem(nn.Module):
return self.conv2(self.act1(self.conv1(x)))
class RepViTDownsample(nn.Module):
class RepVitDownsample(nn.Module):
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
super().__init__()
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)
self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
def forward(self, x):
x = self.pre_block(x)
@ -171,22 +171,25 @@ class RepViTDownsample(nn.Module):
return x + identity
class RepViTClassifier(nn.Module):
def __init__(self, dim, num_classes, distillation=False):
class RepVitClassifier(nn.Module):
def __init__(self, dim, num_classes, distillation=False, drop=0.):
super().__init__()
self.head_drop = nn.Dropout(drop)
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.distillation = distillation
self.num_classes=num_classes
self.distilled_training = False
self.num_classes = num_classes
if distillation:
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x):
x = self.head_drop(x)
if self.distillation:
x1, x2 = self.head(x), self.head_dist(x)
if (not self.training) or torch.jit.is_scripting():
return (x1 + x2) / 2
else:
if self.training and self.distilled_training and not torch.jit.is_scripting():
return x1, x2
else:
return (x1 + x2) / 2
else:
x = self.head(x)
return x
@ -207,11 +210,11 @@ class RepViTClassifier(nn.Module):
return head
class RepViTStage(nn.Module):
class RepVitStage(nn.Module):
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
super().__init__()
if downsample:
self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
else:
assert in_dim == out_dim
self.downsample = nn.Identity()
@ -230,7 +233,7 @@ class RepViTStage(nn.Module):
return x
class RepViT(nn.Module):
class RepVit(nn.Module):
def __init__(
self,
in_chans=3,
@ -243,15 +246,16 @@ class RepViT(nn.Module):
num_classes=1000,
act_layer=nn.GELU,
distillation=True,
drop_rate=0.,
):
super(RepViT, self).__init__()
super(RepVit, self).__init__()
self.grad_checkpointing = False
self.global_pool = global_pool
self.embed_dim = embed_dim
self.num_classes = num_classes
in_dim = embed_dim[0]
self.stem = RepViTStem(in_chans, in_dim, act_layer)
self.stem = RepVitStem(in_chans, in_dim, act_layer)
stride = self.stem.stride
resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
@ -263,7 +267,7 @@ class RepViT(nn.Module):
for i in range(num_stages):
downsample = True if i != 0 else False
stages.append(
RepViTStage(
RepVitStage(
in_dim,
embed_dim[i],
depth[i],
@ -281,7 +285,8 @@ class RepViT(nn.Module):
self.stages = nn.Sequential(*stages)
self.num_features = embed_dim[-1]
self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)
self.head_drop = nn.Dropout(drop_rate)
self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)
@torch.jit.ignore
def group_matcher(self, coarse=False):
@ -304,9 +309,13 @@ class RepViT(nn.Module):
if global_pool is not None:
self.global_pool = global_pool
self.head = (
RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
)
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.head.distilled_training = enable
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@ -317,8 +326,9 @@ class RepViT(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
return x if pre_logits else self.head(x)
x = x.mean((2, 3), keepdim=False)
x = self.head_drop(x)
return self.head(x)
def forward(self, x):
x = self.forward_features(x)
@ -373,7 +383,9 @@ default_cfgs = generate_default_cfgs(
def _create_repvit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs
RepVit, variant, pretrained,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs,
)
return model