Update EdgeNeXt to use ClassifierHead as per ConvNeXt (#2051)

* Update edgenext.py
Fredo Guan 2023-12-11 12:17:19 -08:00 committed by GitHub
parent 711c5dee6d
commit bbe798317f

timm/models/edgenext.py

@@ -18,7 +18,7 @@ from torch import nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
-    use_fused_attn
+    use_fused_attn, NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._manipulate import named_apply, checkpoint_seq
@@ -375,13 +375,23 @@ class EdgeNeXt(nn.Module):
         self.stages = nn.Sequential(*stages)
         self.num_features = dims[-1]
-        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
-        self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
-            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
+        if head_norm_first:
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+            )
 
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
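
The split mirrors ConvNeXt: with head_norm_first the final norm runs in norm_pre and the head is a plain ClassifierHead (pool, drop, fc); otherwise NormMlpClassifierHead applies the norm inside the head, after global pooling. A minimal sketch of the two paths, assuming only the timm.layers constructors imported above (channel count and feature-map shape are illustrative):

import torch
from timm.layers import ClassifierHead, LayerNorm2d, NormMlpClassifierHead

num_features, num_classes = 304, 1000      # illustrative sizes
x = torch.randn(2, num_features, 7, 7)     # NCHW feature map as produced by forward_features

# head_norm_first=True: norm applied before the head
norm_pre = LayerNorm2d(num_features)
head = ClassifierHead(num_features, num_classes, pool_type='avg', drop_rate=0.)
logits = head(norm_pre(x))                 # -> (2, 1000)

# head_norm_first=False (default): norm lives inside the head, applied after pooling
head = NormMlpClassifierHead(
    num_features, num_classes, pool_type='avg', drop_rate=0., norm_layer=LayerNorm2d)
logits = head(x)                           # -> (2, 1000)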
@@ -406,10 +416,7 @@ class EdgeNeXt(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
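
Both head classes expose a reset(num_classes, pool_type) method, so the model-level reset_classifier now simply delegates to the head. A hedged usage sketch, assuming the edgenext_small entry in the timm model registry:

import timm

# pretrained=False avoids a weight download; the model name is an assumption for illustration
model = timm.create_model('edgenext_small', pretrained=False, num_classes=1000)
model.reset_classifier(num_classes=10)   # head.reset() swaps head.fc for a 10-class Linear
model.reset_classifier(num_classes=0)    # head.fc becomes Identity; the model returns pooled features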
@@ -418,12 +425,7 @@ class EdgeNeXt(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
-        x = self.head.global_pool(x)
-        x = self.head.norm(x)
-        x = self.head.flatten(x)
-        x = self.head.drop(x)
-        return x if pre_logits else self.head.fc(x)
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
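
Because both heads accept a pre_logits flag, forward_head can pass it straight through instead of stepping over the old nn.Sequential sub-modules. A short sketch of the two call paths (model name and 256x256 input size are assumptions, not taken from the diff):

import timm
import torch

model = timm.create_model('edgenext_small', pretrained=False)
x = torch.randn(1, 3, 256, 256)

feats = model.forward_features(x)                    # NCHW feature map
pooled = model.forward_head(feats, pre_logits=True)  # pooled (and normed) features before fc
logits = model.forward_head(feats)                   # class logits, shape (1, 1000)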