Merge branch 'seefun-efficientvit'

This commit is contained in:
Ross Wightman 2023-11-21 14:06:27 -08:00
commit fa06f6c481

View File

@ -8,13 +8,15 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit
__all__ = ['EfficientVit']
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, create_conv2d
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
@ -70,7 +72,7 @@ class ConvNormAct(nn.Module):
bias=bias,
)
self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
self.act = act_layer(inplace=True) if act_layer else nn.Identity()
self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
def forward(self, x):
x = self.dropout(x)
@ -121,6 +123,50 @@ class DSConv(nn.Module):
return x
class ConvBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
stride=1,
mid_channels=None,
expand_ratio=1,
use_bias=False,
norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
act_layer=(nn.ReLU6, None),
):
super(ConvBlock, self).__init__()
use_bias = val2tuple(use_bias, 2)
norm_layer = val2tuple(norm_layer, 2)
act_layer = val2tuple(act_layer, 2)
mid_channels = mid_channels or round(in_channels * expand_ratio)
self.conv1 = ConvNormAct(
in_channels,
mid_channels,
kernel_size,
stride,
norm_layer=norm_layer[0],
act_layer=act_layer[0],
bias=use_bias[0],
)
self.conv2 = ConvNormAct(
mid_channels,
out_channels,
kernel_size,
1,
norm_layer=norm_layer[1],
act_layer=act_layer[1],
bias=use_bias[1],
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class MBConv(nn.Module):
def __init__(
self,
@ -175,8 +221,53 @@ class MBConv(nn.Module):
return x
class LiteMSA(nn.Module):
"""Lightweight multi-scale attention"""
class FusedMBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
stride=1,
mid_channels=None,
expand_ratio=6,
groups=1,
use_bias=False,
norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
act_layer=(nn.ReLU6, None),
):
super(FusedMBConv, self).__init__()
use_bias = val2tuple(use_bias, 2)
norm_layer = val2tuple(norm_layer, 2)
act_layer = val2tuple(act_layer, 2)
mid_channels = mid_channels or round(in_channels * expand_ratio)
self.spatial_conv = ConvNormAct(
in_channels,
mid_channels,
kernel_size,
stride=stride,
groups=groups,
norm_layer=norm_layer[0],
act_layer=act_layer[0],
bias=use_bias[0],
)
self.point_conv = ConvNormAct(
mid_channels,
out_channels,
1,
norm_layer=norm_layer[1],
act_layer=act_layer[1],
bias=use_bias[1],
)
def forward(self, x):
x = self.spatial_conv(x)
x = self.point_conv(x)
return x
class LiteMLA(nn.Module):
"""Lightweight multi-scale linear attention"""
def __init__(
self,
@ -192,7 +283,7 @@ class LiteMSA(nn.Module):
scales=(5,),
eps=1e-5,
):
super(LiteMSA, self).__init__()
super(LiteMLA, self).__init__()
self.eps = eps
heads = heads or int(in_channels // dim * heads_ratio)
total_dim = heads * dim
@ -271,7 +362,7 @@ class LiteMSA(nn.Module):
return out
register_notrace_module(LiteMSA)
register_notrace_module(LiteMLA)
class EfficientVitBlock(nn.Module):
@ -286,7 +377,7 @@ class EfficientVitBlock(nn.Module):
):
super(EfficientVitBlock, self).__init__()
self.context_module = ResidualBlock(
LiteMSA(
LiteMLA(
in_channels=in_channels,
out_channels=in_channels,
heads_ratio=heads_ratio,
@ -340,31 +431,54 @@ def build_local_block(
norm_layer: str,
act_layer: str,
fewer_norm: bool = False,
block_type: str = "default",
):
assert block_type in ["default", "large", "fused"]
if expand_ratio == 1:
block = DSConv(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
use_bias=(True, False) if fewer_norm else False,
norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
act_layer=(act_layer, None),
)
if block_type == "default":
block = DSConv(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
use_bias=(True, False) if fewer_norm else False,
norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
act_layer=(act_layer, None),
)
else:
block = ConvBlock(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
use_bias=(True, False) if fewer_norm else False,
norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
act_layer=(act_layer, None),
)
else:
block = MBConv(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
expand_ratio=expand_ratio,
use_bias=(True, True, False) if fewer_norm else False,
norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
act_layer=(act_layer, act_layer, None),
)
if block_type == "default":
block = MBConv(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
expand_ratio=expand_ratio,
use_bias=(True, True, False) if fewer_norm else False,
norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
act_layer=(act_layer, act_layer, None),
)
else:
block = FusedMBConv(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
expand_ratio=expand_ratio,
use_bias=(True, False) if fewer_norm else False,
norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
act_layer=(act_layer, None),
)
return block
class Stem(nn.Sequential):
def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer):
def __init__(self, in_chs, out_chs, depth, norm_layer, act_layer, block_type='default'):
super().__init__()
self.stride = 2
@ -385,6 +499,7 @@ class Stem(nn.Sequential):
expand_ratio=1,
norm_layer=norm_layer,
act_layer=act_layer,
block_type=block_type,
),
nn.Identity(),
))
@ -451,6 +566,69 @@ class EfficientVitStage(nn.Module):
return self.blocks(x)
class EfficientVitLargeStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
depth,
norm_layer,
act_layer,
head_dim,
vit_stage=False,
fewer_norm=False,
):
super(EfficientVitLargeStage, self).__init__()
blocks = [ResidualBlock(
build_local_block(
in_channels=in_chs,
out_channels=out_chs,
stride=2,
expand_ratio=24 if vit_stage else 16,
norm_layer=norm_layer,
act_layer=act_layer,
fewer_norm=vit_stage or fewer_norm,
block_type='default' if fewer_norm else 'fused',
),
None,
)]
in_chs = out_chs
if vit_stage:
# for stage 4
for _ in range(depth):
blocks.append(
EfficientVitBlock(
in_channels=in_chs,
head_dim=head_dim,
expand_ratio=6,
norm_layer=norm_layer,
act_layer=act_layer,
)
)
else:
# for stage 1, 2, 3
for i in range(depth):
blocks.append(ResidualBlock(
build_local_block(
in_channels=in_chs,
out_channels=out_chs,
stride=1,
expand_ratio=4,
norm_layer=norm_layer,
act_layer=act_layer,
fewer_norm=fewer_norm,
block_type='default' if fewer_norm else 'fused',
),
nn.Identity(),
))
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
return self.blocks(x)
class ClassifierHead(nn.Module):
def __init__(
self,
@ -461,14 +639,15 @@ class ClassifierHead(nn.Module):
norm_layer=nn.BatchNorm2d,
act_layer=nn.Hardswish,
global_pool='avg',
norm_eps=1e-5,
):
super(ClassifierHead, self).__init__()
self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
self.classifier = nn.Sequential(
nn.Linear(widths[0], widths[1], bias=False),
nn.LayerNorm(widths[1]),
act_layer(inplace=True),
nn.LayerNorm(widths[1], eps=norm_eps),
act_layer(inplace=True) if act_layer is not None else nn.Identity(),
nn.Dropout(dropout, inplace=False),
nn.Linear(widths[1], n_classes, bias=True),
)
@ -596,6 +775,125 @@ class EfficientVit(nn.Module):
return x
class EfficientVitLarge(nn.Module):
def __init__(
self,
in_chans=3,
widths=(),
depths=(),
head_dim=32,
norm_layer=nn.BatchNorm2d,
act_layer=GELUTanh,
global_pool='avg',
head_widths=(),
drop_rate=0.0,
num_classes=1000,
norm_eps=1e-7,
):
super(EfficientVitLarge, self).__init__()
self.grad_checkpointing = False
self.global_pool = global_pool
self.num_classes = num_classes
self.norm_eps = norm_eps
norm_layer = partial(norm_layer, eps=self.norm_eps)
# input stem
self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large')
stride = self.stem.stride
# stages
self.feature_info = []
self.stages = nn.Sequential()
in_channels = widths[0]
for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
self.stages.append(EfficientVitLargeStage(
in_channels,
w,
depth=d,
norm_layer=norm_layer,
act_layer=act_layer,
head_dim=head_dim,
vit_stage=i >= 3,
fewer_norm=i >= 2,
))
stride *= 2
in_channels = w
self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
self.num_features = in_channels
self.head_widths = head_widths
self.head_dropout = drop_rate
if num_classes > 0:
self.head = ClassifierHead(
self.num_features,
self.head_widths,
n_classes=num_classes,
dropout=self.head_dropout,
global_pool=self.global_pool,
act_layer=act_layer,
norm_eps=self.norm_eps,
)
else:
if self.global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
else:
self.head = nn.Identity()
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.classifier[-1]
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
if num_classes > 0:
self.head = ClassifierHead(
self.num_features,
self.head_widths,
n_classes=num_classes,
dropout=self.head_dropout,
global_pool=self.global_pool,
norm_eps=self.norm_eps
)
else:
if self.global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
else:
self.head = nn.Identity()
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _cfg(url='', **kwargs):
return {
'url': url,
@ -648,6 +946,57 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
'efficientvit_l1.r224_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=1.0,
),
'efficientvit_l2.r224_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=1.0,
),
'efficientvit_l2.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_l2.r288_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
'efficientvit_l2.r384_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'efficientvit_l3.r224_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=1.0,
),
'efficientvit_l3.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_l3.r320_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
),
'efficientvit_l3.r384_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
# 'efficientvit_l0_sam.sam': _cfg(
# # hf_hub_id='timm/',
# input_size=(3, 512, 512), crop_pct=1.0,
# num_classes=0,
# ),
# 'efficientvit_l1_sam.sam': _cfg(
# # hf_hub_id='timm/',
# input_size=(3, 512, 512), crop_pct=1.0,
# num_classes=0,
# ),
# 'efficientvit_l2_sam.sam': _cfg(
# # hf_hub_id='timm/',f
# input_size=(3, 512, 512), crop_pct=1.0,
# num_classes=0,
# ),
})
@ -663,6 +1012,18 @@ def _create_efficientvit(variant, pretrained=False, **kwargs):
return model
def _create_efficientvit_large(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
EfficientVitLarge,
variant,
pretrained,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
return model
@register_model
def efficientvit_b0(pretrained=False, **kwargs):
model_args = dict(
@ -689,3 +1050,49 @@ def efficientvit_b3(pretrained=False, **kwargs):
model_args = dict(
widths=(32, 64, 128, 256, 512), depths=(1, 4, 6, 6, 9), head_dim=32, head_widths=(2304, 2560))
return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_l1(pretrained=False, **kwargs):
model_args = dict(
widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, head_widths=(3072, 3200))
return _create_efficientvit_large('efficientvit_l1', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_l2(pretrained=False, **kwargs):
model_args = dict(
widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(3072, 3200))
return _create_efficientvit_large('efficientvit_l2', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def efficientvit_l3(pretrained=False, **kwargs):
model_args = dict(
widths=(64, 128, 256, 512, 1024), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(6144, 6400))
return _create_efficientvit_large('efficientvit_l3', pretrained=pretrained, **dict(model_args, **kwargs))
# FIXME will wait for v2 SAM models which are pending
# @register_model
# def efficientvit_l0_sam(pretrained=False, **kwargs):
# # only backbone for segment-anything-model weights
# model_args = dict(
# widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6)
# return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))
#
#
# @register_model
# def efficientvit_l1_sam(pretrained=False, **kwargs):
# # only backbone for segment-anything-model weights
# model_args = dict(
# widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6)
# return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))
#
#
# @register_model
# def efficientvit_l2_sam(pretrained=False, **kwargs):
# # only backbone for segment-anything-model weights
# model_args = dict(
# widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6)
# return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))