mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add features_only support to inception_next
This commit is contained in:
parent
3d8d7450ad
commit
2d33b9df6c
@ -221,6 +221,7 @@ class MetaNeXt(nn.Module):
|
||||
self,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
output_stride=32,
|
||||
depths=(3, 3, 9, 3),
|
||||
dims=(96, 192, 384, 768),
|
||||
token_mixers=nn.Identity,
|
||||
@ -239,22 +240,30 @@ class MetaNeXt(nn.Module):
|
||||
token_mixers = [token_mixers] * num_stage
|
||||
if not isinstance(mlp_ratios, (list, tuple)):
|
||||
mlp_ratios = [mlp_ratios] * num_stage
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
self.feature_info = []
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
||||
norm_layer(dims[0])
|
||||
)
|
||||
|
||||
self.stages = nn.Sequential()
|
||||
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
stages = []
|
||||
prev_chs = dims[0]
|
||||
curr_stride = 4
|
||||
dilation = 1
|
||||
# feature resolution stages, each consisting of multiple residual blocks
|
||||
self.stages = nn.Sequential()
|
||||
for i in range(num_stage):
|
||||
stride = 2 if curr_stride == 2 or i > 0 else 1
|
||||
if curr_stride >= output_stride and stride > 1:
|
||||
dilation *= stride
|
||||
stride = 1
|
||||
curr_stride *= stride
|
||||
first_dilation = 1 if dilation in (1, 2) else 2
|
||||
out_chs = dims[i]
|
||||
stages.append(MetaNeXtStage(
|
||||
self.stages.append(MetaNeXtStage(
|
||||
prev_chs,
|
||||
out_chs,
|
||||
ds_stride=2 if i > 0 else 1,
|
||||
@ -267,7 +276,7 @@ class MetaNeXt(nn.Module):
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
))
|
||||
prev_chs = out_chs
|
||||
self.stages = nn.Sequential(*stages)
|
||||
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
|
||||
self.num_features = prev_chs
|
||||
self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
|
||||
self.apply(self._init_weights)
|
||||
@ -353,7 +362,8 @@ def _create_inception_next(variant, pretrained=False, **kwargs):
|
||||
model = build_model_with_cfg(
|
||||
MetaNeXt, variant, pretrained,
|
||||
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user