Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
inception_next dilation support, weights on hf hub, classifier reset / global pool / no head fixes
commit af9f56f3bf (parent 2d33b9df6c)
timm/models/inceptionnext.py

@@ -8,7 +8,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, to_2tuple
+from timm.layers import trunc_normal_, DropPath, to_2tuple, create_conv2d, get_padding, SelectAdaptivePool2d
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -23,16 +23,23 @@ class InceptionDWConv2d(nn.Module):
             in_chs,
             square_kernel_size=3,
             band_kernel_size=11,
-            branch_ratio=0.125
+            branch_ratio=0.125,
+            dilation=1,
     ):
         super().__init__()

         gc = int(in_chs * branch_ratio)  # channel numbers of a convolution branch
-        self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc)
+        square_padding = get_padding(square_kernel_size, dilation=dilation)
+        band_padding = get_padding(band_kernel_size, dilation=dilation)
+        self.dwconv_hw = nn.Conv2d(
+            gc, gc, square_kernel_size,
+            padding=square_padding, dilation=dilation, groups=gc)
         self.dwconv_w = nn.Conv2d(
-            gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc)
+            gc, gc, (1, band_kernel_size),
+            padding=(0, band_padding), dilation=(1, dilation), groups=gc)
         self.dwconv_h = nn.Conv2d(
-            gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc)
+            gc, gc, (band_kernel_size, 1),
+            padding=(band_padding, 0), dilation=(dilation, 1), groups=gc)
         self.split_indexes = (in_chs - 3 * gc, gc, gc, gc)

     def forward(self, x):
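Padding for the depthwise branches is now computed with `get_padding`, so the square and band convolutions stay shape-preserving when a dilation is applied. A minimal sketch (not part of the commit) illustrating the arithmetic with timm's `get_padding` helper:

```python
# Illustrative sketch: with padding from get_padding(), a dilated band conv keeps the
# spatial size, so dilation can substitute for stride when reducing output stride.
import torch
import torch.nn as nn
from timm.layers import get_padding

k, d = 11, 2                                    # band kernel size, dilation
pad = get_padding(k, dilation=d)                # dilation * (k - 1) // 2 = 10 at stride 1
dwconv_w = nn.Conv2d(8, 8, (1, k), padding=(0, pad), dilation=(1, d), groups=8)
x = torch.randn(1, 8, 14, 14)
print(dwconv_w(x).shape)                        # torch.Size([1, 8, 14, 14]) -- unchanged
```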
@@ -89,6 +96,7 @@ class MlpClassifierHead(nn.Module):
             self,
             dim,
             num_classes=1000,
+            pool_type='avg',
             mlp_ratio=3,
             act_layer=nn.GELU,
             norm_layer=partial(nn.LayerNorm, eps=1e-6),
@@ -96,15 +104,17 @@ class MlpClassifierHead(nn.Module):
             bias=True
     ):
         super().__init__()
-        hidden_features = int(mlp_ratio * dim)
-        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
+        in_features = dim * self.global_pool.feat_mult()
+        hidden_features = int(mlp_ratio * in_features)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
         self.act = act_layer()
         self.norm = norm_layer(hidden_features)
         self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
         self.drop = nn.Dropout(drop)

     def forward(self, x):
-        x = x.mean((2, 3))  # global average pooling
+        x = self.global_pool(x)
         x = self.fc1(x)
         x = self.act(x)
         x = self.norm(x)
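`MlpClassifierHead` now owns its pooling via `SelectAdaptivePool2d` and sizes `fc1` from `feat_mult()`, which matters when a concatenating pool type is selected. A small illustration (assumed usage, not taken from the commit):

```python
# Illustrative sketch: feat_mult() reports how many feature copies the pool produces,
# e.g. 'catavgmax' concatenates avg and max pooling and doubles fc1's input width.
import torch
from timm.layers import SelectAdaptivePool2d

feats = torch.randn(2, 768, 7, 7)
for pool_type in ('avg', 'max', 'catavgmax'):
    pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
    in_features = 768 * pool.feat_mult()        # 768 for avg/max, 1536 for catavgmax
    print(pool_type, tuple(pool(feats).shape), in_features)
```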
@@ -124,7 +134,8 @@ class MetaNeXtBlock(nn.Module):
     def __init__(
             self,
             dim,
-            token_mixer=nn.Identity,
+            dilation=1,
+            token_mixer=InceptionDWConv2d,
             norm_layer=nn.BatchNorm2d,
             mlp_layer=ConvMlp,
             mlp_ratio=4,
@@ -134,7 +145,7 @@ class MetaNeXtBlock(nn.Module):

     ):
         super().__init__()
-        self.token_mixer = token_mixer(dim)
+        self.token_mixer = token_mixer(dim, dilation=dilation)
         self.norm = norm_layer(dim)
         self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None
@@ -156,21 +167,28 @@ class MetaNeXtStage(nn.Module):
             self,
             in_chs,
             out_chs,
-            ds_stride=2,
+            stride=2,
             depth=2,
+            dilation=(1, 1),
             drop_path_rates=None,
             ls_init_value=1.0,
-            token_mixer=nn.Identity,
+            token_mixer=InceptionDWConv2d,
             act_layer=nn.GELU,
             norm_layer=None,
             mlp_ratio=4,
     ):
         super().__init__()
         self.grad_checkpointing = False
-        if ds_stride > 1:
+        if stride > 1 or dilation[0] != dilation[1]:
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
-                nn.Conv2d(in_chs, out_chs, kernel_size=ds_stride, stride=ds_stride),
+                nn.Conv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=2,
+                    stride=stride,
+                    dilation=dilation[0],
+                ),
             )
         else:
             self.downsample = nn.Identity()
@@ -180,6 +198,7 @@ class MetaNeXtStage(nn.Module):
         for i in range(depth):
             stage_blocks.append(MetaNeXtBlock(
                 dim=out_chs,
+                dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 token_mixer=token_mixer,
@@ -221,10 +240,11 @@ class MetaNeXt(nn.Module):
             self,
             in_chans=3,
             num_classes=1000,
             global_pool='avg',
+            output_stride=32,
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 768),
-            token_mixers=nn.Identity,
+            token_mixers=InceptionDWConv2d,
             norm_layer=nn.BatchNorm2d,
             act_layer=nn.GELU,
             mlp_ratios=(4, 4, 4, 3),
@@ -241,6 +261,7 @@ class MetaNeXt(nn.Module):
         if not isinstance(mlp_ratios, (list, tuple)):
            mlp_ratios = [mlp_ratios] * num_stage
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.drop_rate = drop_rate
         self.feature_info = []

@@ -266,7 +287,8 @@ class MetaNeXt(nn.Module):
             self.stages.append(MetaNeXtStage(
                 prev_chs,
                 out_chs,
-                ds_stride=2 if i > 0 else 1,
+                stride=stride if i > 0 else 1,
+                dilation=(first_dilation, dilation),
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
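With `stride` and `dilation` now plumbed through the stages, reduced output strides can be expressed via dilation. A hedged sketch, assuming the new `output_stride` argument is forwarded through `timm.create_model` kwargs to the `MetaNeXt` constructor:

```python
# Sketch (assumption: output_stride passes through create_model to MetaNeXt):
# at output_stride=16 the final stage should use dilation instead of stride 2,
# giving a 14x14 feature map for a 224x224 input instead of 7x7.
import torch
import timm

model = timm.create_model('inception_next_tiny', pretrained=False, output_stride=16).eval()
feats = model.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)    # expected: torch.Size([1, 768, 14, 14])
```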
@@ -278,7 +300,15 @@ class MetaNeXt(nn.Module):
             prev_chs = out_chs
             self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
         self.num_features = prev_chs
-        self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
+        if self.num_classes > 0:
+            if issubclass(head_fn, MlpClassifierHead):
+                assert self.global_pool, 'Cannot disable global pooling with MLP head present.'
+            self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate)
+        else:
+            if self.global_pool:
+                self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()
         self.apply(self._init_weights)

     def _init_weights(self, m):
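The head is now selected from `num_classes` and `global_pool`, which is what enables the "no head" configurations mentioned in the commit message. A sketch of the expected behaviour (shapes assume `inception_next_tiny` defaults):

```python
# Sketch: num_classes=0 drops the MLP head; global_pool='' additionally drops pooling.
import torch
import timm

x = torch.randn(1, 3, 224, 224)

pooled = timm.create_model('inception_next_tiny', pretrained=False, num_classes=0)
print(pooled(x).shape)      # expected (1, 768): globally pooled features, no classifier

unpooled = timm.create_model('inception_next_tiny', pretrained=False, num_classes=0, global_pool='')
print(unpooled(x).shape)    # expected (1, 768, 7, 7): raw feature map (head is nn.Identity)
```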
@@ -301,9 +331,18 @@ class MetaNeXt(nn.Module):
     def get_classifier(self):
         return self.head.fc2

-    def reset_classifier(self, num_classes=0, global_pool=None):
-        # FIXME
-        self.head.reset(num_classes, global_pool)
+    def reset_classifier(self, num_classes=0, global_pool=None, head_fn=MlpClassifierHead):
+        if global_pool is not None:
+            self.global_pool = global_pool
+        if num_classes > 0:
+            if issubclass(head_fn, MlpClassifierHead):
+                assert self.global_pool, 'Cannot disable global pooling with MLP head present.'
+            self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=self.drop_rate)
+        else:
+            if self.global_pool:
+                self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()

     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
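`reset_classifier` now mirrors the construction-time head logic instead of the previous FIXME stub. Expected usage, as a short sketch:

```python
# Sketch: reset_classifier rebuilds or removes the head with the construction-time rules.
import timm

model = timm.create_model('inception_next_tiny', pretrained=False)
model.reset_classifier(num_classes=10)                   # fresh 10-class MLP head
model.reset_classifier(num_classes=0)                    # head -> global pooling only
model.reset_classifier(num_classes=0, global_pool='')    # head -> nn.Identity()
```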
@@ -319,9 +358,12 @@ class MetaNeXt(nn.Module):
         x = self.stages(x)
         return x

-    def forward_head(self, x):
-        x = self.head(x)
-        return x
+    def forward_head(self, x, pre_logits: bool = False):
+        if pre_logits:
+            if hasattr(self.head, 'global_pool'):
+                x = self.head.global_pool(x)
+            return x
+        return self.head(x)

     def forward(self, x):
         x = self.forward_features(x)
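`forward_head` gains a `pre_logits` flag so pooled, pre-classifier features can be pulled from the MLP head. A short usage sketch:

```python
# Sketch: pre_logits=True returns the head's pooled features without running the MLP.
import torch
import timm

model = timm.create_model('inception_next_tiny', pretrained=False).eval()
feats = model.forward_features(torch.randn(1, 3, 224, 224))
pre_logits = model.forward_head(feats, pre_logits=True)
print(pre_logits.shape)     # expected (1, 768)
```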
@@ -342,18 +384,22 @@ def _cfg(url='', **kwargs):

 default_cfgs = generate_default_cfgs({
     'inception_next_tiny.sail_in1k': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
     ),
     'inception_next_small.sail_in1k': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
     ),
     'inception_next_base.sail_in1k': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
         crop_pct=0.95,
     ),
     'inception_next_base.sail_in1k_384': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
-        input_size=(3, 384, 384), crop_pct=1.0,
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
+        input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0,
     ),
 })
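With `hf_hub_id='timm/'` set (and the GitHub release URLs kept only as comments), pretrained weights now resolve from the Hugging Face Hub:

```python
# Sketch: pretrained weights download from the Hugging Face Hub under the timm org.
import timm

model = timm.create_model('inception_next_tiny.sail_in1k', pretrained=True)
```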