mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Make channels for classic resnet configurable
This commit is contained in:
parent
9b2b8014e8
commit
2b3f1a4633
@ -399,6 +399,7 @@ class ResNet(nn.Module):
|
||||
block_reduce_first: int = 1,
|
||||
down_kernel_size: int = 1,
|
||||
avg_down: bool = False,
|
||||
channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
|
||||
act_layer: LayerType = nn.ReLU,
|
||||
norm_layer: LayerType = nn.BatchNorm2d,
|
||||
aa_layer: Optional[Type[nn.Module]] = None,
|
||||
@ -489,7 +490,6 @@ class ResNet(nn.Module):
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
# Feature Blocks
|
||||
channels = (64, 128, 256, 512)
|
||||
stage_modules, stage_feature_info = make_blocks(
|
||||
block,
|
||||
channels,
|
||||
@ -513,7 +513,7 @@ class ResNet(nn.Module):
|
||||
self.feature_info.extend(stage_feature_info)
|
||||
|
||||
# Head (Pooling and Classifier)
|
||||
self.num_features = self.head_hidden_size = 512 * block.expansion
|
||||
self.num_features = self.head_hidden_size = channels[-1] * block.expansion
|
||||
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
|
||||
self.init_weights(zero_init_last=zero_init_last)
|
||||
|
Loading…
x
Reference in New Issue
Block a user