Fix two default args in DenseNet blocks... fix #1427

more_vit
Ross Wightman 2022-08-25 15:00:35 -07:00
parent 527f9a4cb2
commit 1d8d6f6072
1 changed file with 2 additions and 2 deletions


@@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict):
     _version = 2
     def __init__(
-            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
+            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
             drop_rate=0., memory_efficient=False):
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
@@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict):
 class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
+    def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None):
         super(DenseTransition, self).__init__()
         self.add_module('norm', norm_layer(num_input_features))
         self.add_module('conv', nn.Conv2d(
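
Note (not part of the commit): both constructors instantiate the default as norm_layer(num_input_features), and the DenseNet blocks in this file appear to rely on that layer also applying the activation, so nn.ReLU (not a norm layer at all) and nn.BatchNorm2d (no fused activation) were unusable defaults. A minimal sketch of the difference, assuming the 2022-era timm.models.layers import path for BatchNormAct2d:

# Minimal sketch illustrating why the old default args broke.
# Assumption: BatchNormAct2d is importable from timm.models.layers in this era of timm.
import torch
import torch.nn as nn
from timm.models.layers import BatchNormAct2d

x = torch.randn(2, 64, 8, 8)

norm_act = BatchNormAct2d(64)          # new default: fused BatchNorm2d + ReLU
print(norm_act(x).min() >= 0)          # tensor(True): activation is applied after the norm

plain_bn = nn.BatchNorm2d(64)          # old DenseTransition default: norm only, no activation
print(plain_bn(x).min() >= 0)          # tensor(False): negative values pass through

relu_only = nn.ReLU(64)                # old DenseBlock default: not a norm layer at all;
print(relu_only(x).shape)              # the 64 merely (mis)fills ReLU's `inplace` flag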