diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index cd34953d..df6bb448 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -221,6 +221,7 @@ class MetaNeXt(nn.Module): self, in_chans=3, num_classes=1000, + output_stride=32, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), token_mixers=nn.Identity, @@ -239,22 +240,30 @@ class MetaNeXt(nn.Module): token_mixers = [token_mixers] * num_stage if not isinstance(mlp_ratios, (list, tuple)): mlp_ratios = [mlp_ratios] * num_stage - self.num_classes = num_classes self.drop_rate = drop_rate + self.feature_info = [] + self.stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), norm_layer(dims[0]) ) - self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] - stages = [] prev_chs = dims[0] + curr_stride = 4 + dilation = 1 # feature resolution stages, each consisting of multiple residual blocks + self.stages = nn.Sequential() for i in range(num_stage): + stride = 2 if curr_stride == 2 or i > 0 else 1 + if curr_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + curr_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 out_chs = dims[i] - stages.append(MetaNeXtStage( + self.stages.append(MetaNeXtStage( prev_chs, out_chs, ds_stride=2 if i > 0 else 1, @@ -267,7 +276,7 @@ class MetaNeXt(nn.Module): mlp_ratio=mlp_ratios[i], )) prev_chs = out_chs - self.stages = nn.Sequential(*stages) + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.num_features = prev_chs self.head = head_fn(self.num_features, num_classes, drop=drop_rate) self.apply(self._init_weights) @@ -353,7 +362,8 @@ def _create_inception_next(variant, pretrained=False, **kwargs): model = build_model_with_cfg( MetaNeXt, variant, pretrained, feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), - **kwargs) + **kwargs, + ) return model