Merge branch 'main' of https://github.com/jameslahm/pytorch-image-models into jameslahm-main
commit
5309424d5e
|
@ -82,19 +82,30 @@ class NormLinear(nn.Sequential):
|
|||
|
||||
|
||||
class RepVggDw(nn.Module):
|
||||
def __init__(self, ed, kernel_size):
|
||||
def __init__(self, ed, kernel_size, legacy=False):
|
||||
super().__init__()
|
||||
self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
|
||||
self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
|
||||
if legacy:
|
||||
self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
|
||||
# Make torchscript happy.
|
||||
self.bn = nn.Identity()
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
|
||||
self.bn = nn.BatchNorm2d(ed)
|
||||
self.dim = ed
|
||||
self.legacy = legacy
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x) + self.conv1(x) + x
|
||||
return self.bn(self.conv(x) + self.conv1(x) + x)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
conv = self.conv.fuse()
|
||||
conv1 = self.conv1.fuse()
|
||||
|
||||
if self.legacy:
|
||||
conv1 = self.conv1.fuse()
|
||||
else:
|
||||
conv1 = self.conv1
|
||||
|
||||
conv_w = conv.weight
|
||||
conv_b = conv.bias
|
||||
|
@ -112,6 +123,14 @@ class RepVggDw(nn.Module):
|
|||
|
||||
conv.weight.data.copy_(final_conv_w)
|
||||
conv.bias.data.copy_(final_conv_b)
|
||||
|
||||
if not self.legacy:
|
||||
bn = self.bn
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = conv.weight * w[:, None, None, None]
|
||||
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
conv.weight.data.copy_(w)
|
||||
conv.bias.data.copy_(b)
|
||||
return conv
|
||||
|
||||
|
||||
|
@ -127,10 +146,10 @@ class RepVitMlp(nn.Module):
|
|||
|
||||
|
||||
class RepViTBlock(nn.Module):
|
||||
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
|
||||
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False):
|
||||
super(RepViTBlock, self).__init__()
|
||||
|
||||
self.token_mixer = RepVggDw(in_dim, kernel_size)
|
||||
self.token_mixer = RepVggDw(in_dim, kernel_size, legacy)
|
||||
self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
|
||||
self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)
|
||||
|
||||
|
@ -155,9 +174,9 @@ class RepVitStem(nn.Module):
|
|||
|
||||
|
||||
class RepVitDownsample(nn.Module):
|
||||
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
|
||||
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy=False):
|
||||
super().__init__()
|
||||
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
|
||||
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer, legacy=legacy)
|
||||
self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
|
||||
self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
|
||||
self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
|
||||
|
@ -172,7 +191,7 @@ class RepVitDownsample(nn.Module):
|
|||
|
||||
|
||||
class RepVitClassifier(nn.Module):
|
||||
def __init__(self, dim, num_classes, distillation=False, drop=0.):
|
||||
def __init__(self, dim, num_classes, distillation=False, drop=0.0):
|
||||
super().__init__()
|
||||
self.head_drop = nn.Dropout(drop)
|
||||
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
@ -211,10 +230,10 @@ class RepVitClassifier(nn.Module):
|
|||
|
||||
|
||||
class RepVitStage(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
|
||||
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False):
|
||||
super().__init__()
|
||||
if downsample:
|
||||
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
|
||||
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy)
|
||||
else:
|
||||
assert in_dim == out_dim
|
||||
self.downsample = nn.Identity()
|
||||
|
@ -222,7 +241,7 @@ class RepVitStage(nn.Module):
|
|||
blocks = []
|
||||
use_se = True
|
||||
for _ in range(depth):
|
||||
blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer))
|
||||
blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy))
|
||||
use_se = not use_se
|
||||
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
@ -246,7 +265,8 @@ class RepVit(nn.Module):
|
|||
num_classes=1000,
|
||||
act_layer=nn.GELU,
|
||||
distillation=True,
|
||||
drop_rate=0.,
|
||||
drop_rate=0.0,
|
||||
legacy=False,
|
||||
):
|
||||
super(RepVit, self).__init__()
|
||||
self.grad_checkpointing = False
|
||||
|
@ -275,6 +295,7 @@ class RepVit(nn.Module):
|
|||
act_layer=act_layer,
|
||||
kernel_size=kernel_size,
|
||||
downsample=downsample,
|
||||
legacy=legacy,
|
||||
)
|
||||
)
|
||||
stage_stride = 2 if downsample else 1
|
||||
|
@ -290,12 +311,9 @@ class RepVit(nn.Module):
|
|||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
matcher = dict(
|
||||
stem=r'^stem', # stem and embed
|
||||
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
|
||||
)
|
||||
matcher = dict(stem=r'^stem', blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]) # stem and embed
|
||||
return matcher
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
@ -369,15 +387,42 @@ default_cfgs = generate_default_cfgs(
|
|||
{
|
||||
'repvit_m1.dist_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
|
||||
),
|
||||
'repvit_m2.dist_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
|
||||
),
|
||||
'repvit_m3.dist_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
|
||||
),
|
||||
'repvit_m0_9.dist_in1k_300e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_300e_timm.pth'
|
||||
),
|
||||
'repvit_m0_9.dist_in1k_450e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_450e_timm.pth'
|
||||
),
|
||||
'repvit_m1_0.dist_in1k_300e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_300e_timm.pth'
|
||||
),
|
||||
'repvit_m1_0.dist_in1k_450e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_450e_timm.pth'
|
||||
),
|
||||
'repvit_m1_1.dist_in1k_300e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_300e_timm.pth'
|
||||
),
|
||||
'repvit_m1_1.dist_in1k_450e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_450e_timm.pth'
|
||||
),
|
||||
'repvit_m1_5.dist_in1k_300e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_300e_timm.pth'
|
||||
),
|
||||
'repvit_m1_5.dist_in1k_450e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_450e_timm.pth'
|
||||
),
|
||||
'repvit_m2_3.dist_in1k_300e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_300e_timm.pth'
|
||||
),
|
||||
'repvit_m2_3.dist_in1k_450e': _cfg(
|
||||
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_450e_timm.pth'
|
||||
),
|
||||
}
|
||||
)
|
||||
|
@ -386,7 +431,9 @@ default_cfgs = generate_default_cfgs(
|
|||
def _create_repvit(variant, pretrained=False, **kwargs):
|
||||
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
|
||||
model = build_model_with_cfg(
|
||||
RepVit, variant, pretrained,
|
||||
RepVit,
|
||||
variant,
|
||||
pretrained,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -398,7 +445,7 @@ def repvit_m1(pretrained=False, **kwargs):
|
|||
"""
|
||||
Constructs a RepViT-M1 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
|
||||
model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
|
||||
return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
|
@ -407,7 +454,7 @@ def repvit_m2(pretrained=False, **kwargs):
|
|||
"""
|
||||
Constructs a RepViT-M2 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
|
||||
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
|
||||
return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
|
@ -416,5 +463,50 @@ def repvit_m3(pretrained=False, **kwargs):
|
|||
"""
|
||||
Constructs a RepViT-M3 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2))
|
||||
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
|
||||
return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def repvit_m0_9(pretrained=False, **kwargs):
|
||||
"""
|
||||
Constructs a RepViT-M0.9 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
|
||||
return _create_repvit('repvit_m0_9', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def repvit_m1_0(pretrained=False, **kwargs):
|
||||
"""
|
||||
Constructs a RepViT-M1.0 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
|
||||
return _create_repvit('repvit_m1_0', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def repvit_m1_1(pretrained=False, **kwargs):
|
||||
"""
|
||||
Constructs a RepViT-M1.1 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
|
||||
return _create_repvit('repvit_m1_1', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def repvit_m1_5(pretrained=False, **kwargs):
|
||||
"""
|
||||
Constructs a RepViT-M1.5 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
|
||||
return _create_repvit('repvit_m1_5', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def repvit_m2_3(pretrained=False, **kwargs):
|
||||
"""
|
||||
Constructs a RepViT-M2.3 model
|
||||
"""
|
||||
model_args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
|
||||
return _create_repvit('repvit_m2_3', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
|
Loading…
Reference in New Issue