diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 6893bc5b..7eabc1fd 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -399,6 +399,7 @@ class ResNet(nn.Module): block_reduce_first: int = 1, down_kernel_size: int = 1, avg_down: bool = False, + channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512), act_layer: LayerType = nn.ReLU, norm_layer: LayerType = nn.BatchNorm2d, aa_layer: Optional[Type[nn.Module]] = None, @@ -489,7 +490,6 @@ class ResNet(nn.Module): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks - channels = (64, 128, 256, 512) stage_modules, stage_feature_info = make_blocks( block, channels, @@ -513,7 +513,7 @@ class ResNet(nn.Module): self.feature_info.extend(stage_feature_info) # Head (Pooling and Classifier) - self.num_features = self.head_hidden_size = 512 * block.expansion + self.num_features = self.head_hidden_size = channels[-1] * block.expansion self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) self.init_weights(zero_init_last=zero_init_last)