mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update EdgeNeXt to use ClassifierHead as per ConvNeXt (#2051)
* Update edgenext.py
This commit is contained in:
parent
711c5dee6d
commit
bbe798317f
@ -18,7 +18,7 @@ from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
|
||||
use_fused_attn
|
||||
use_fused_attn, NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
@ -375,13 +375,23 @@ class EdgeNeXt(nn.Module):
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
self.num_features = dims[-1]
|
||||
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
|
||||
self.head = nn.Sequential(OrderedDict([
|
||||
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
|
||||
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
|
||||
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
|
||||
('drop', nn.Dropout(self.drop_rate)),
|
||||
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
|
||||
if head_norm_first:
|
||||
self.norm_pre = norm_layer(self.num_features)
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=self.drop_rate,
|
||||
)
|
||||
else:
|
||||
self.norm_pre = nn.Identity()
|
||||
self.head = NormMlpClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=self.drop_rate,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
||||
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
|
||||
|
||||
@ -406,10 +416,7 @@ class EdgeNeXt(nn.Module):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes=0, global_pool=None):
|
||||
if global_pool is not None:
|
||||
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
@ -418,12 +425,7 @@ class EdgeNeXt(nn.Module):
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
||||
x = self.head.global_pool(x)
|
||||
x = self.head.norm(x)
|
||||
x = self.head.flatten(x)
|
||||
x = self.head.drop(x)
|
||||
return x if pre_logits else self.head.fc(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user