mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Make channels for classic resnet configurable
This commit is contained in:
parent
9b2b8014e8
commit
2b3f1a4633
@ -399,6 +399,7 @@ class ResNet(nn.Module):
|
|||||||
block_reduce_first: int = 1,
|
block_reduce_first: int = 1,
|
||||||
down_kernel_size: int = 1,
|
down_kernel_size: int = 1,
|
||||||
avg_down: bool = False,
|
avg_down: bool = False,
|
||||||
|
channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
|
||||||
act_layer: LayerType = nn.ReLU,
|
act_layer: LayerType = nn.ReLU,
|
||||||
norm_layer: LayerType = nn.BatchNorm2d,
|
norm_layer: LayerType = nn.BatchNorm2d,
|
||||||
aa_layer: Optional[Type[nn.Module]] = None,
|
aa_layer: Optional[Type[nn.Module]] = None,
|
||||||
@ -489,7 +490,6 @@ class ResNet(nn.Module):
|
|||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
# Feature Blocks
|
# Feature Blocks
|
||||||
channels = (64, 128, 256, 512)
|
|
||||||
stage_modules, stage_feature_info = make_blocks(
|
stage_modules, stage_feature_info = make_blocks(
|
||||||
block,
|
block,
|
||||||
channels,
|
channels,
|
||||||
@ -513,7 +513,7 @@ class ResNet(nn.Module):
|
|||||||
self.feature_info.extend(stage_feature_info)
|
self.feature_info.extend(stage_feature_info)
|
||||||
|
|
||||||
# Head (Pooling and Classifier)
|
# Head (Pooling and Classifier)
|
||||||
self.num_features = self.head_hidden_size = 512 * block.expansion
|
self.num_features = self.head_hidden_size = channels[-1] * block.expansion
|
||||||
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
self.init_weights(zero_init_last=zero_init_last)
|
self.init_weights(zero_init_last=zero_init_last)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user