More fastvit & mobileone updates, ready for weight upload

dynamic_resize_vit
Ross Wightman 2023-08-22 14:42:43 -07:00 committed by Ross Wightman
parent 8474508d07
commit 8470eb1cb5
3 changed files with 180 additions and 167 deletions

timm/models/__init__.py

@@ -20,6 +20,7 @@ from .efficientnet import *
 from .efficientvit_mit import *
 from .efficientvit_msra import *
 from .eva import *
+from .fastvit import *
 from .focalnet import *
 from .gcvit import *
 from .ghostnet import *
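Note: the star import above is what exposes the new FastViT variants through the registry, since importing timm.models.fastvit runs its @register_model decorators. A minimal sketch of the effect (assuming this commit is installed as timm):

    import timm

    # the new variants become discoverable and creatable by name
    print(timm.list_models('fastvit*'))
    model = timm.create_model('fastvit_t8')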

timm/models/byobnet.py

@@ -589,7 +589,7 @@ class MobileOneBlock(nn.Module):
         self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
         # Re-parameterizable conv branches
-        convs = list()
+        convs = []
         for _ in range(self.num_conv_branches):
             convs.append(layers.conv_norm_act(
                 in_chs, out_chs, kernel_size=kernel_size,
@@ -624,8 +624,8 @@ class MobileOneBlock(nn.Module):
         # Other branches
         out = scale_out + identity_out
-        for ix in range(self.num_conv_branches):
-            out += self.conv_kxk[ix](x)
+        for ck in self.conv_kxk:
+            out += ck(x)
         return self.act(self.attn(out))
@@ -1643,11 +1643,15 @@ model_cfgs = dict(
     ),
 )
+# FIXME temporary for mobileone remap
+from .fastvit import checkpoint_filter_fn
 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
         model_cfg=model_cfgs[variant],
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)
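Note: pretrained_filter_fn is applied to the raw checkpoint before load_state_dict, which is how the unfused Apple MobileOne weights get remapped into timm's layout. A simplified sketch of the flow (not the actual build_model_with_cfg internals):

    import torch

    def load_pretrained_sketch(model, checkpoint_path, filter_fn=None):
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        if filter_fn is not None:
            state_dict = filter_fn(state_dict, model)  # remap keys/shapes
        model.load_state_dict(state_dict)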
@@ -1805,6 +1809,27 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+    'mobileone_s0': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar',
+        crop_pct=0.875,
+    ),
+    'mobileone_s1': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
+    'mobileone_s2': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
+    'mobileone_s3': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
+    'mobileone_s4': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar',
+        crop_pct=0.9,
+    ),
 })
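Note: with these cfgs in place, the crop_pct values flow into eval preprocessing. A hedged usage sketch (assumes the Apple URL downloads resolve through timm's standard loader):

    import timm
    from timm.data import resolve_data_config, create_transform

    model = timm.create_model('mobileone_s1', pretrained=True)
    cfg = resolve_data_config({}, model=model)
    transform = create_transform(**cfg)  # resize uses crop_pct=0.9 from the cfg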

timm/models/fastvit.py

@@ -13,7 +13,8 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
-from ._registry import register_model
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
 def num_groups(group_size, channels):
@@ -58,7 +59,7 @@ class MobileOneBlock(nn.Module):
         kernel_size: Size of the convolution kernel.
         stride: Stride size.
         dilation: Kernel dilation factor.
-        groups: Group number.
+        group_size: Convolution group size.
         inference_mode: If True, instantiates model in inference mode.
         use_se: Whether to use SE-ReLU activations.
         use_act: Whether to use activation. Default: ``True``
@@ -76,7 +77,7 @@ class MobileOneBlock(nn.Module):
         self.num_conv_branches = num_conv_branches
         # Check if SE-ReLU is requested
-        self.se = SqueezeExcite(out_chs) if use_se else nn.Identity()
+        self.se = SqueezeExcite(out_chs, rd_divisor=1) if use_se else nn.Identity()
         if inference_mode:
             self.reparam_conv = create_conv2d(
@@ -100,17 +101,16 @@ class MobileOneBlock(nn.Module):
             # Re-parameterizable conv branches
             if num_conv_branches > 0:
-                rbr_conv = list()
-                for _ in range(self.num_conv_branches):
-                    rbr_conv.append(ConvNormAct(
+                self.rbr_conv = nn.ModuleList([
+                    ConvNormAct(
                         self.in_chs,
                         self.out_chs,
                         kernel_size=kernel_size,
                         stride=self.stride,
                         groups=self.groups,
                         apply_act=False,
-                    ))
-                self.rbr_conv = nn.ModuleList(rbr_conv)
+                    ) for _ in range(self.num_conv_branches)
+                ])
             else:
                 self.rbr_conv = None
@@ -148,8 +148,8 @@ class MobileOneBlock(nn.Module):
         # Other branches
         out = scale_out + identity_out
         if self.rbr_conv is not None:
-            for ix in range(self.num_conv_branches):
-                out += self.rbr_conv[ix](x)
+            for rc in self.rbr_conv:
+                out += rc(x)
         return self.act(self.se(out))
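Note: a quick way to validate refactors like the branch-iteration change above is to check that re-parameterization leaves outputs unchanged. A sketch, assuming block is a MobileOneBlock built with inference_mode=False:

    import torch

    def check_reparam(block, shape=(1, 64, 32, 32)):
        block.eval()
        x = torch.randn(shape)
        with torch.no_grad():
            y_train = block(x)
            block.reparameterize()  # fuse branches into reparam_conv
            y_fused = block(x)
        return torch.allclose(y_train, y_fused, atol=1e-5)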
@@ -159,8 +159,9 @@ class MobileOneBlock(nn.Module):
         architecture used at training time to obtain a plain CNN-like structure
         for inference.
         """
-        if self.inference_mode:
+        if self.reparam_conv is not None:
             return
         kernel, bias = self._get_kernel_bias()
         self.reparam_conv = create_conv2d(
             in_channels=self.in_chs,
@@ -177,6 +178,7 @@ class MobileOneBlock(nn.Module):
         # Delete un-used branches
         for para in self.parameters():
             para.detach_()
         self.__delattr__("rbr_conv")
         self.__delattr__("rbr_scale")
         if hasattr(self, "rbr_skip"):
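Note: reparameterize() is per block; fusing a whole model before export is a simple module walk. A hedged sketch (assumes every re-parameterizable block exposes a reparameterize() method, as above):

    import torch.nn as nn

    def reparameterize_model(model: nn.Module) -> nn.Module:
        for module in model.modules():
            if hasattr(module, 'reparameterize'):
                module.reparameterize()
        return model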
@@ -339,7 +341,6 @@ class ReparamLargeKernelConv(nn.Module):
         self.act = act_layer() if act_layer is not None else nn.Identity()
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Apply forward pass."""
         if self.lkb_reparam is not None:
             out = self.lkb_reparam(x)
         else:
@@ -1076,14 +1077,17 @@ def basic_blocks(
 class FastVit(nn.Module):
+    fork_feat: torch.jit.Final[bool]
     """
     This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
     """
     def __init__(
             self,
-            layers,
-            token_mixers: Tuple[str, ...],
+            in_chans=3,
+            layers=(2, 2, 6, 2),
+            token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
             embed_dims=None,
             mlp_ratios=None,
             downsamples=None,
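Note: the torch.jit.Final annotation added above lets TorchScript treat fork_feat as a compile-time constant, so the `if self.fork_feat:` branches in forward (which return different types) can be pruned at script time. A minimal sketch of the pattern, relying on TorchScript constant-folding Final attributes:

    import torch

    class M(torch.nn.Module):
        fork_feat: torch.jit.Final[bool]

        def __init__(self, fork_feat: bool = False):
            super().__init__()
            self.fork_feat = fork_feat

        def forward(self, x):
            if self.fork_feat:
                return [x]  # dead branch is eliminated when fork_feat=False
            return x

    scripted = torch.jit.script(M())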
@@ -1099,24 +1103,19 @@ class FastVit(nn.Module):
             use_layer_scale=True,
             layer_scale_init_value=1e-5,
             fork_feat=False,
-            init_cfg=None,
-            pretrained=None,
             cls_ratio=2.0,
             inference_mode=False,
-            **kwargs,
     ) -> None:
         super().__init__()
-        if not fork_feat:
-            self.num_classes = num_classes
+        self.num_classes = 0 if fork_feat else num_classes
         self.fork_feat = fork_feat
         if pos_embs is None:
             pos_embs = [None] * len(layers)
         # Convolutional stem
-        self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)
+        self.patch_embed = convolutional_stem(
+            in_chans, embed_dims[0], inference_mode)
         # Build the main stages of the network architecture
         network = []
@@ -1192,14 +1191,9 @@ class FastVit(nn.Module):
             else nn.Identity()
         )
-        self.apply(self.cls_init_weights)
-        self.init_cfg = copy.deepcopy(init_cfg)
+        self.apply(self._init_weights)
-        # load pre-trained model
-        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
-            self.init_weights()
-    def cls_init_weights(self, m: nn.Module) -> None:
+    def _init_weights(self, m: nn.Module) -> None:
         """Init. for classification"""
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=0.02)
@@ -1224,10 +1218,11 @@ class FastVit(nn.Module):
         outs = []
         for idx, block in enumerate(self.network):
             x = block(x)
-            if self.fork_feat and idx in self.out_indices:
-                norm_layer = getattr(self, f"norm{idx}")
-                x_out = norm_layer(x)
-                outs.append(x_out)
+            if self.fork_feat:
+                if idx in self.out_indices:
+                    norm_layer = getattr(self, f"norm{idx}")
+                    x_out = norm_layer(x)
+                    outs.append(x_out)
         if self.fork_feat:
             # output the features of four stages for dense prediction
             return outs
@@ -1256,7 +1251,7 @@ def _cfg(url="", **kwargs):
         "num_classes": 1000,
         "input_size": (3, 256, 256),
         "pool_size": None,
-        "crop_pct": 0.95,
+        "crop_pct": 0.9,
         "interpolation": "bicubic",
         "mean": IMAGENET_DEFAULT_MEAN,
         "std": IMAGENET_DEFAULT_STD,
@@ -1265,174 +1260,166 @@ def _cfg(url="", **kwargs):
     }
-default_cfgs = {
-    "fastvit_t": _cfg(crop_pct=0.9),
-    "fastvit_s": _cfg(crop_pct=0.9),
-    "fastvit_m": _cfg(crop_pct=0.95),
-    'fastvit_t8': _cfg(
-        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar')
-}
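Note on the crop_pct values: at eval time timm resizes to roughly input_size / crop_pct before center-cropping, so a lower crop_pct means a larger resize relative to the crop. A worked example:

    input_size = 256
    resize_at_090 = int(256 / 0.9)   # 284, then center crop back to 256
    resize_at_095 = int(256 / 0.95)  # 269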
+default_cfgs = generate_default_cfgs({
+    "fastvit_t8.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar'
+    ),
+    "fastvit_t12.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t12.pth.tar'
+    ),
+    "fastvit_s12.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_s12.pth.tar'),
+    "fastvit_sa12.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa12.pth.tar'),
+    "fastvit_sa24.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa24.pth.tar'),
+    "fastvit_sa36.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_sa36.pth.tar'),
+    "fastvit_ma36.apple_in1k": _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_ma36.pth.tar',
+        crop_pct=0.95
+    ),
+    # "fastvit_t8.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t8.pth.tar'
+    # ),
+    # "fastvit_t12.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_t12.pth.tar'
+    # ),
+    #
+    # "fastvit_s12.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12.pth.tar'),
+    # "fastvit_sa12.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa12.pth.tar'),
+    # "fastvit_sa24.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa24.pth.tar'),
+    # "fastvit_sa36.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_sa36.pth.tar'),
+    #
+    # "fastvit_ma36.apple_dist_in1k": _cfg(
+    #     url='https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_ma36.pth.tar',
+    #     crop_pct=0.95
+    # ),
+})
+def checkpoint_filter_fn(state_dict, model):
+    # FIXME temporary for remapping
+    state_dict = state_dict.get('state_dict', state_dict)
+    msd = model.state_dict()
+    out_dict = {}
+    for ka, kb, va, vb in zip(msd.keys(), state_dict.keys(), msd.values(), state_dict.values()):
+        if va.ndim == 4 and vb.ndim == 2:
+            vb = vb[:, :, None, None]
+        out_dict[ka] = vb
+    return out_dict
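Note: this remap pairs keys positionally via zip (kb is never consulted), so it relies on the model and checkpoint state dicts enumerating matching parameters in the same order; 2-d linear weights are unsqueezed to 4-d where the timm module is a 1x1 conv. A hedged sketch of applying it by hand to a locally downloaded checkpoint (hypothetical path):

    import torch

    model = FastVit(
        layers=(2, 2, 4, 2), embed_dims=(48, 96, 192, 384),
        mlp_ratios=(3, 3, 3, 3), downsamples=(True, True, True, True))
    ckpt = torch.load('fastvit_t8.pth.tar', map_location='cpu')
    model.load_state_dict(checkpoint_filter_fn(ckpt, model))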
+def _create_fastvit(variant, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
+    model = build_model_with_cfg(
+        FastVit,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs
+    )
+    return model
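Note: popping out_indices into feature_cfg means feature extraction should work through the standard interface. A sketch (uncertain whether features_only is fully wired up at this commit):

    import timm
    import torch

    model = timm.create_model('fastvit_sa12', features_only=True)
    feats = model(torch.randn(1, 3, 256, 256))
    print([f.shape for f in feats])  # one map per stage in out_indices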
 @register_model
 def fastvit_t8(pretrained=False, **kwargs):
     """Instantiate FastViT-T8 model variant."""
-    layers = [2, 2, 4, 2]
-    embed_dims = [48, 96, 192, 384]
-    mlp_ratios = [3, 3, 3, 3]
-    downsamples = [True, True, True, True]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 4, 2),
+        embed_dims=(48, 96, 192, 384),
+        mlp_ratios=(3, 3, 3, 3),
+        downsamples=(True, True, True, True),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer")
     )
-    model.default_cfg = default_cfgs["fastvit_t"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
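Note: with the builder in place, pretrained weights load through checkpoint_filter_fn transparently. A minimal sketch:

    import timm
    import torch

    model = timm.create_model('fastvit_t8', pretrained=True).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 256, 256))
    assert logits.shape == (1, 1000)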
 @register_model
 def fastvit_t12(pretrained=False, **kwargs):
     """Instantiate FastViT-T12 model variant."""
-    layers = [2, 2, 6, 2]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [3, 3, 3, 3]
-    downsamples = [True, True, True, True]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 6, 2),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(3, 3, 3, 3),
+        downsamples=(True, True, True, True),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
-    model.default_cfg = default_cfgs["fastvit_t"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
 def fastvit_s12(pretrained=False, **kwargs):
     """Instantiate FastViT-S12 model variant."""
-    layers = [2, 2, 6, 2]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 6, 2),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
-    model.default_cfg = default_cfgs["fastvit_s"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
 def fastvit_sa12(pretrained=False, **kwargs):
     """Instantiate FastViT-SA12 model variant."""
-    layers = [2, 2, 6, 2]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(2, 2, 6, 2),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
-    model.default_cfg = default_cfgs["fastvit_s"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_sa12', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
 def fastvit_sa24(pretrained=False, **kwargs):
     """Instantiate FastViT-SA24 model variant."""
-    layers = [4, 4, 12, 4]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        token_mixers=token_mixers,
-        embed_dims=embed_dims,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        **kwargs,
+    model_args = dict(
+        layers=(4, 4, 12, 4),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
-    model.default_cfg = default_cfgs["fastvit_s"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_sa24', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
 def fastvit_sa36(pretrained=False, **kwargs):
     """Instantiate FastViT-SA36 model variant."""
-    layers = [6, 6, 18, 6]
-    embed_dims = [64, 128, 256, 512]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        embed_dims=embed_dims,
-        token_mixers=token_mixers,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        layer_scale_init_value=1e-6,
-        **kwargs,
+    model_args = dict(
+        layers=(6, 6, 18, 6),
+        embed_dims=(64, 128, 256, 512),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
-    model.default_cfg = default_cfgs["fastvit_m"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_sa36', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
 def fastvit_ma36(pretrained=False, **kwargs):
     """Instantiate FastViT-MA36 model variant."""
-    layers = [6, 6, 18, 6]
-    embed_dims = [76, 152, 304, 608]
-    mlp_ratios = [4, 4, 4, 4]
-    downsamples = [True, True, True, True]
-    pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
-    token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
-    model = FastVit(
-        layers,
-        embed_dims=embed_dims,
-        token_mixers=token_mixers,
-        pos_embs=pos_embs,
-        mlp_ratios=mlp_ratios,
-        downsamples=downsamples,
-        layer_scale_init_value=1e-6,
-        **kwargs,
+    model_args = dict(
+        layers=(6, 6, 18, 6),
+        embed_dims=(76, 152, 304, 608),
+        mlp_ratios=(4, 4, 4, 4),
+        downsamples=(True, True, True, True),
+        pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention")
    )
-    model.default_cfg = default_cfgs["fastvit_m"]
-    if pretrained:
-        raise ValueError("Functionality not implemented.")
-    return model
+    return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
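Note on the **dict(model_args, **kwargs) merge used by every variant above: caller kwargs take precedence over the preset args, which is plain dict semantics:

    model_args = dict(layers=(6, 6, 18, 6), mlp_ratios=(4, 4, 4, 4))
    kwargs = dict(mlp_ratios=(3, 3, 3, 3))
    merged = dict(model_args, **kwargs)
    assert merged['mlp_ratios'] == (3, 3, 3, 3)  # caller override wins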