Make channels for classic resnet configurable

This commit is contained in:
Ross Wightman 2024-07-22 10:47:40 -07:00
parent 9b2b8014e8
commit 2b3f1a4633

View File

@ -399,6 +399,7 @@ class ResNet(nn.Module):
block_reduce_first: int = 1,
down_kernel_size: int = 1,
avg_down: bool = False,
channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[Type[nn.Module]] = None,
@ -489,7 +490,6 @@ class ResNet(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks
channels = (64, 128, 256, 512)
stage_modules, stage_feature_info = make_blocks(
block,
channels,
@ -513,7 +513,7 @@ class ResNet(nn.Module):
self.feature_info.extend(stage_feature_info)
# Head (Pooling and Classifier)
self.num_features = self.head_hidden_size = 512 * block.expansion
self.num_features = self.head_hidden_size = channels[-1] * block.expansion
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.init_weights(zero_init_last=zero_init_last)