mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Some RepVit tweaks
* add head dropout to RepVit as all models have that arg * default train to non-distilled head output via distilled_training flag (set_distilled_training) so fine-tune works by default w/o distillation script * camel case naming tweaks to match other models
This commit is contained in:
parent
f6771909ff
commit
c692715388
@ -15,7 +15,7 @@ Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
|
||||
Adapted from official impl at https://github.com/jameslahm/RepViT
|
||||
"""
|
||||
|
||||
__all__ = ['RepViT']
|
||||
__all__ = ['RepVit']
|
||||
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
@ -81,7 +81,7 @@ class NormLinear(nn.Sequential):
|
||||
return m
|
||||
|
||||
|
||||
class RepVGGDW(nn.Module):
|
||||
class RepVggDw(nn.Module):
|
||||
def __init__(self, ed, kernel_size):
|
||||
super().__init__()
|
||||
self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
|
||||
@ -115,7 +115,7 @@ class RepVGGDW(nn.Module):
|
||||
return conv
|
||||
|
||||
|
||||
class RepViTMlp(nn.Module):
|
||||
class RepVitMlp(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, act_layer):
|
||||
super().__init__()
|
||||
self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
|
||||
@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
|
||||
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
|
||||
super(RepViTBlock, self).__init__()
|
||||
|
||||
self.token_mixer = RepVGGDW(in_dim, kernel_size)
|
||||
self.token_mixer = RepVggDw(in_dim, kernel_size)
|
||||
self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
|
||||
self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)
|
||||
self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.token_mixer(x)
|
||||
@ -142,7 +142,7 @@ class RepViTBlock(nn.Module):
|
||||
return identity + x
|
||||
|
||||
|
||||
class RepViTStem(nn.Module):
|
||||
class RepVitStem(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, act_layer):
|
||||
super().__init__()
|
||||
self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
|
||||
@ -154,13 +154,13 @@ class RepViTStem(nn.Module):
|
||||
return self.conv2(self.act1(self.conv1(x)))
|
||||
|
||||
|
||||
class RepViTDownsample(nn.Module):
|
||||
class RepVitDownsample(nn.Module):
|
||||
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
|
||||
super().__init__()
|
||||
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
|
||||
self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
|
||||
self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
|
||||
self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)
|
||||
self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pre_block(x)
|
||||
@ -171,22 +171,25 @@ class RepViTDownsample(nn.Module):
|
||||
return x + identity
|
||||
|
||||
|
||||
class RepViTClassifier(nn.Module):
|
||||
def __init__(self, dim, num_classes, distillation=False):
|
||||
class RepVitClassifier(nn.Module):
|
||||
def __init__(self, dim, num_classes, distillation=False, drop=0.):
|
||||
super().__init__()
|
||||
self.head_drop = nn.Dropout(drop)
|
||||
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.distillation = distillation
|
||||
self.num_classes=num_classes
|
||||
self.distilled_training = False
|
||||
self.num_classes = num_classes
|
||||
if distillation:
|
||||
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.head_drop(x)
|
||||
if self.distillation:
|
||||
x1, x2 = self.head(x), self.head_dist(x)
|
||||
if (not self.training) or torch.jit.is_scripting():
|
||||
return (x1 + x2) / 2
|
||||
else:
|
||||
if self.training and self.distilled_training and not torch.jit.is_scripting():
|
||||
return x1, x2
|
||||
else:
|
||||
return (x1 + x2) / 2
|
||||
else:
|
||||
x = self.head(x)
|
||||
return x
|
||||
@ -207,11 +210,11 @@ class RepViTClassifier(nn.Module):
|
||||
return head
|
||||
|
||||
|
||||
class RepViTStage(nn.Module):
|
||||
class RepVitStage(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
|
||||
super().__init__()
|
||||
if downsample:
|
||||
self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
|
||||
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
|
||||
else:
|
||||
assert in_dim == out_dim
|
||||
self.downsample = nn.Identity()
|
||||
@ -230,7 +233,7 @@ class RepViTStage(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class RepViT(nn.Module):
|
||||
class RepVit(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans=3,
|
||||
@ -243,15 +246,16 @@ class RepViT(nn.Module):
|
||||
num_classes=1000,
|
||||
act_layer=nn.GELU,
|
||||
distillation=True,
|
||||
drop_rate=0.,
|
||||
):
|
||||
super(RepViT, self).__init__()
|
||||
super(RepVit, self).__init__()
|
||||
self.grad_checkpointing = False
|
||||
self.global_pool = global_pool
|
||||
self.embed_dim = embed_dim
|
||||
self.num_classes = num_classes
|
||||
|
||||
in_dim = embed_dim[0]
|
||||
self.stem = RepViTStem(in_chans, in_dim, act_layer)
|
||||
self.stem = RepVitStem(in_chans, in_dim, act_layer)
|
||||
stride = self.stem.stride
|
||||
resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
|
||||
|
||||
@ -263,7 +267,7 @@ class RepViT(nn.Module):
|
||||
for i in range(num_stages):
|
||||
downsample = True if i != 0 else False
|
||||
stages.append(
|
||||
RepViTStage(
|
||||
RepVitStage(
|
||||
in_dim,
|
||||
embed_dim[i],
|
||||
depth[i],
|
||||
@ -281,7 +285,8 @@ class RepViT(nn.Module):
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
self.num_features = embed_dim[-1]
|
||||
self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
@ -304,9 +309,13 @@ class RepViT(nn.Module):
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
self.head = (
|
||||
RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
|
||||
RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_distilled_training(self, enable=True):
|
||||
self.head.distilled_training = enable
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
@ -317,8 +326,9 @@ class RepViT(nn.Module):
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
|
||||
return x if pre_logits else self.head(x)
|
||||
x = x.mean((2, 3), keepdim=False)
|
||||
x = self.head_drop(x)
|
||||
return self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
@ -373,7 +383,9 @@ default_cfgs = generate_default_cfgs(
|
||||
def _create_repvit(variant, pretrained=False, **kwargs):
|
||||
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
|
||||
model = build_model_with_cfg(
|
||||
RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs
|
||||
RepVit, variant, pretrained,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user