mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove the experimental downsample-in-block support from ConvNeXt. It needs further experimentation before being kept in.
This commit is contained in:
parent
bfc0dccb0e
commit
06307b8b41
@ -17,7 +17,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_notrace_module
|
||||
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
|
||||
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
|
||||
from .registry import register_model
|
||||
@ -124,7 +123,6 @@ class ConvNeXtBlock(nn.Module):
|
||||
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
|
||||
mlp_layer = ConvMlp if conv_mlp else Mlp
|
||||
self.use_conv_mlp = conv_mlp
|
||||
self.shortcut_after_dw = stride > 1
|
||||
|
||||
self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
|
||||
self.norm = norm_layer(dim_out)
|
||||
@ -135,9 +133,6 @@ class ConvNeXtBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
shortcut = x
|
||||
x = self.conv_dw(x)
|
||||
if self.shortcut_after_dw:
|
||||
shortcut = x
|
||||
|
||||
if self.use_conv_mlp:
|
||||
x = self.norm(x)
|
||||
x = self.mlp(x)
|
||||
@ -150,7 +145,6 @@ class ConvNeXtBlock(nn.Module):
|
||||
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
||||
|
||||
x = self.drop_path(x) + shortcut
|
||||
#print('b', x.shape)
|
||||
return x
|
||||
|
||||
|
||||
@ -164,7 +158,6 @@ class ConvNeXtStage(nn.Module):
|
||||
depth=2,
|
||||
drop_path_rates=None,
|
||||
ls_init_value=1.0,
|
||||
downsample_block=False,
|
||||
conv_mlp=False,
|
||||
conv_bias=True,
|
||||
norm_layer=None,
|
||||
@ -173,14 +166,14 @@ class ConvNeXtStage(nn.Module):
|
||||
super().__init__()
|
||||
self.grad_checkpointing = False
|
||||
|
||||
if downsample_block or (in_chs == out_chs and stride == 1):
|
||||
self.downsample = nn.Identity()
|
||||
else:
|
||||
if in_chs != out_chs or stride > 1:
|
||||
self.downsample = nn.Sequential(
|
||||
norm_layer(in_chs),
|
||||
nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
|
||||
)
|
||||
in_chs = out_chs
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
drop_path_rates = drop_path_rates or [0.] * depth
|
||||
stage_blocks = []
|
||||
@ -188,7 +181,6 @@ class ConvNeXtStage(nn.Module):
|
||||
stage_blocks.append(ConvNeXtBlock(
|
||||
dim=in_chs,
|
||||
dim_out=out_chs,
|
||||
stride=stride if downsample_block and i == 0 else 1,
|
||||
drop_path=drop_path_rates[i],
|
||||
ls_init_value=ls_init_value,
|
||||
conv_mlp=conv_mlp,
|
||||
@ -236,7 +228,6 @@ class ConvNeXt(nn.Module):
|
||||
stem_stride=4,
|
||||
head_init_scale=1.,
|
||||
head_norm_first=False,
|
||||
downsample_block=False,
|
||||
conv_mlp=False,
|
||||
conv_bias=True,
|
||||
norm_layer=None,
|
||||
@ -291,7 +282,6 @@ class ConvNeXt(nn.Module):
|
||||
depth=depths[i],
|
||||
drop_path_rates=dp_rates[i],
|
||||
ls_init_value=ls_init_value,
|
||||
downsample_block=downsample_block,
|
||||
conv_mlp=conv_mlp,
|
||||
conv_bias=conv_bias,
|
||||
norm_layer=norm_layer,
|
||||
@ -418,7 +408,7 @@ def convnext_nano_hnf(pretrained=False, **kwargs):
|
||||
@register_model
|
||||
def convnext_nano_ols(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True,
|
||||
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True,
|
||||
conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs)
|
||||
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
@ -426,7 +416,8 @@ def convnext_nano_ols(pretrained=False, **kwargs):
|
||||
|
||||
@register_model
|
||||
def convnext_tiny_hnf(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
|
||||
model_args = dict(
|
||||
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user