micro_block -> layer_type
parent
81de331e00
commit
f6df698c4f
|
@ -164,18 +164,18 @@ class BottleneckBlock(TheseusLayer):
|
|||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
micro_block=ConvBNLayer,
|
||||
layer=ConvBNLayer,
|
||||
lr_mult=1.0,
|
||||
data_format="NCHW"):
|
||||
super().__init__()
|
||||
self.conv0 = micro_block(
|
||||
self.conv0 = layer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
act="relu",
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.conv1 = micro_block(
|
||||
self.conv1 = layer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
|
@ -183,7 +183,7 @@ class BottleneckBlock(TheseusLayer):
|
|||
act="relu",
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.conv2 = micro_block(
|
||||
self.conv2 = layer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters * 4,
|
||||
filter_size=1,
|
||||
|
@ -226,13 +226,13 @@ class BasicBlock(TheseusLayer):
|
|||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
micro_block=ConvBNLayer,
|
||||
layer=ConvBNLayer,
|
||||
lr_mult=1.0,
|
||||
data_format="NCHW"):
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride
|
||||
self.conv0 = micro_block(
|
||||
self.conv0 = layer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
|
@ -240,7 +240,7 @@ class BasicBlock(TheseusLayer):
|
|||
act="relu",
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.conv1 = micro_block(
|
||||
self.conv1 = layer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
|
@ -296,7 +296,7 @@ class ResNet(TheseusLayer):
|
|||
input_image_channel=3,
|
||||
return_patterns=None,
|
||||
return_stages=None,
|
||||
micro_block="ConvBNLayer",
|
||||
layer_type="ConvBNLayer",
|
||||
use_first_short_conv=True,
|
||||
**kargs):
|
||||
super().__init__()
|
||||
|
@ -312,10 +312,10 @@ class ResNet(TheseusLayer):
|
|||
self.num_channels = self.cfg["num_channels"]
|
||||
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
|
||||
|
||||
if micro_block == "ConvBNLayer":
|
||||
micro_block = ConvBNLayer
|
||||
elif micro_block == "DiverseBranchBlock":
|
||||
micro_block = DiverseBranchBlock
|
||||
if layer_type == "ConvBNLayer":
|
||||
layer = ConvBNLayer
|
||||
elif layer_type == "DiverseBranchBlock":
|
||||
layer = DiverseBranchBlock
|
||||
else:
|
||||
raise Exception()
|
||||
|
||||
|
@ -377,7 +377,7 @@ class ResNet(TheseusLayer):
|
|||
if i == 0 and block_idx != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block_idx == i == 0 if version == "vd" else True,
|
||||
micro_block=micro_block,
|
||||
layer=layer,
|
||||
lr_mult=self.lr_mult_list[block_idx + 1],
|
||||
data_format=data_format))
|
||||
shortcut = True
|
||||
|
|
|
@ -19,7 +19,7 @@ Global:
|
|||
Arch:
|
||||
name: ResNet18
|
||||
class_num: 1000
|
||||
micro_block: DiverseBranchBlock
|
||||
layer_type: DiverseBranchBlock
|
||||
use_first_short_conv: False
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue