mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
remove transition
This commit is contained in:
parent
b38b0f388f
commit
9ab8639c3d
@ -94,47 +94,6 @@ class Layer1(TheseusLayer):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
class TransitionLayer(TheseusLayer):
|
|
||||||
def __init__(self, in_channels, out_channels, name=None):
|
|
||||||
super(TransitionLayer, self).__init__()
|
|
||||||
|
|
||||||
num_in = len(in_channels)
|
|
||||||
num_out = len(out_channels)
|
|
||||||
out = []
|
|
||||||
self.conv_bn_func_list = []
|
|
||||||
for i in range(num_out):
|
|
||||||
residual = None
|
|
||||||
if i < num_in:
|
|
||||||
if in_channels[i] != out_channels[i]:
|
|
||||||
residual = self.add_sublayer(
|
|
||||||
"transition_{}_layer_{}".format(name, i + 1),
|
|
||||||
ConvBNLayer(
|
|
||||||
num_channels=in_channels[i],
|
|
||||||
num_filters=out_channels[i],
|
|
||||||
filter_size=3))
|
|
||||||
else:
|
|
||||||
residual = self.add_sublayer(
|
|
||||||
"transition_{}_layer_{}".format(name, i + 1),
|
|
||||||
ConvBNLayer(
|
|
||||||
num_channels=in_channels[-1],
|
|
||||||
num_filters=out_channels[i],
|
|
||||||
filter_size=3,
|
|
||||||
stride=2))
|
|
||||||
self.conv_bn_func_list.append(residual)
|
|
||||||
|
|
||||||
def forward(self, x, res_dict=None):
|
|
||||||
outs = []
|
|
||||||
for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
|
|
||||||
if conv_bn_func is None:
|
|
||||||
outs.append(x[idx])
|
|
||||||
else:
|
|
||||||
if idx < len(x):
|
|
||||||
outs.append(conv_bn_func(x[idx]))
|
|
||||||
else:
|
|
||||||
outs.append(conv_bn_func(x[-1]))
|
|
||||||
return outs
|
|
||||||
|
|
||||||
|
|
||||||
class Branches(TheseusLayer):
|
class Branches(TheseusLayer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
block_num,
|
block_num,
|
||||||
@ -537,8 +496,16 @@ class HRNet(TheseusLayer):
|
|||||||
|
|
||||||
self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2")
|
self.la1 = Layer1(num_channels=64, has_se=has_se, name="layer2")
|
||||||
|
|
||||||
self.tr1 = TransitionLayer(
|
self.tr1_1 = BasicBlock(
|
||||||
in_channels=[256], out_channels=channels_2, name="tr1")
|
num_channels=256,
|
||||||
|
num_filters=width,
|
||||||
|
has_se=has_se,
|
||||||
|
name="tr1_1")
|
||||||
|
self.tr1_2 = BasicBlock(
|
||||||
|
num_channels=width,
|
||||||
|
num_filters=width * 2,
|
||||||
|
has_se=has_se,
|
||||||
|
name="tr1_2")
|
||||||
|
|
||||||
self.st2 = Stage(
|
self.st2 = Stage(
|
||||||
num_channels=channels_2,
|
num_channels=channels_2,
|
||||||
@ -547,8 +514,11 @@ class HRNet(TheseusLayer):
|
|||||||
has_se=self.has_se,
|
has_se=self.has_se,
|
||||||
name="st2")
|
name="st2")
|
||||||
|
|
||||||
self.tr2 = TransitionLayer(
|
self.tr2 = BasicBlock(
|
||||||
in_channels=channels_2, out_channels=channels_3, name="tr2")
|
num_channels=width * 2,
|
||||||
|
num_filters=width * 4,
|
||||||
|
has_se=has_se,
|
||||||
|
name="tr2")
|
||||||
self.st3 = Stage(
|
self.st3 = Stage(
|
||||||
num_channels=channels_3,
|
num_channels=channels_3,
|
||||||
num_modules=num_modules_3,
|
num_modules=num_modules_3,
|
||||||
@ -556,8 +526,12 @@ class HRNet(TheseusLayer):
|
|||||||
has_se=self.has_se,
|
has_se=self.has_se,
|
||||||
name="st3")
|
name="st3")
|
||||||
|
|
||||||
self.tr3 = TransitionLayer(
|
self.tr3 = BasicBlock(
|
||||||
in_channels=channels_3, out_channels=channels_4, name="tr3")
|
num_channels=width * 4,
|
||||||
|
num_filters=width * 8,
|
||||||
|
has_se=has_se,
|
||||||
|
name="tr3")
|
||||||
|
|
||||||
self.st4 = Stage(
|
self.st4 = Stage(
|
||||||
num_channels=channels_4,
|
num_channels=channels_4,
|
||||||
num_modules=num_modules_4,
|
num_modules=num_modules_4,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user