micro_block -> layer_type

pull/2587/head
gaotingquan 2023-01-06 10:05:43 +00:00 committed by Tingquan Gao
parent 81de331e00
commit f6df698c4f
2 changed files with 14 additions and 14 deletions

View File

@ -164,18 +164,18 @@ class BottleneckBlock(TheseusLayer):
stride,
shortcut=True,
if_first=False,
micro_block=ConvBNLayer,
layer=ConvBNLayer,
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.conv0 = micro_block(
self.conv0 = layer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act="relu",
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = micro_block(
self.conv1 = layer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
@ -183,7 +183,7 @@ class BottleneckBlock(TheseusLayer):
act="relu",
lr_mult=lr_mult,
data_format=data_format)
self.conv2 = micro_block(
self.conv2 = layer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
@ -226,13 +226,13 @@ class BasicBlock(TheseusLayer):
stride,
shortcut=True,
if_first=False,
micro_block=ConvBNLayer,
layer=ConvBNLayer,
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.stride = stride
self.conv0 = micro_block(
self.conv0 = layer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
@ -240,7 +240,7 @@ class BasicBlock(TheseusLayer):
act="relu",
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = micro_block(
self.conv1 = layer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
@ -296,7 +296,7 @@ class ResNet(TheseusLayer):
input_image_channel=3,
return_patterns=None,
return_stages=None,
micro_block="ConvBNLayer",
layer_type="ConvBNLayer",
use_first_short_conv=True,
**kargs):
super().__init__()
@ -312,10 +312,10 @@ class ResNet(TheseusLayer):
self.num_channels = self.cfg["num_channels"]
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
if micro_block == "ConvBNLayer":
micro_block = ConvBNLayer
elif micro_block == "DiverseBranchBlock":
micro_block = DiverseBranchBlock
if layer_type == "ConvBNLayer":
layer = ConvBNLayer
elif layer_type == "DiverseBranchBlock":
layer = DiverseBranchBlock
else:
raise Exception()
@ -377,7 +377,7 @@ class ResNet(TheseusLayer):
if i == 0 and block_idx != 0 else 1,
shortcut=shortcut,
if_first=block_idx == i == 0 if version == "vd" else True,
micro_block=micro_block,
layer=layer,
lr_mult=self.lr_mult_list[block_idx + 1],
data_format=data_format))
shortcut = True

View File

@ -19,7 +19,7 @@ Global:
Arch:
name: ResNet18
class_num: 1000
micro_block: DiverseBranchBlock
layer_type: DiverseBranchBlock
use_first_short_conv: False