mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
update pp_hgnet.py
This commit is contained in:
parent
8a760fb85f
commit
713dd6f9eb
@ -90,7 +90,7 @@ class ESEModule(TheseusLayer):
|
|||||||
return paddle.multiply(x=identity, y=x)
|
return paddle.multiply(x=identity, y=x)
|
||||||
|
|
||||||
|
|
||||||
class _HG_Block(TheseusLayer):
|
class HG_Block(TheseusLayer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels,
|
in_channels,
|
||||||
@ -140,7 +140,7 @@ class _HG_Block(TheseusLayer):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class _HG_Stage(TheseusLayer):
|
class HG_Stage(TheseusLayer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels,
|
in_channels,
|
||||||
mid_channels,
|
mid_channels,
|
||||||
@ -161,7 +161,7 @@ class _HG_Stage(TheseusLayer):
|
|||||||
|
|
||||||
blocks_list = []
|
blocks_list = []
|
||||||
blocks_list.append(
|
blocks_list.append(
|
||||||
_HG_Block(
|
HG_Block(
|
||||||
in_channels,
|
in_channels,
|
||||||
mid_channels,
|
mid_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -169,7 +169,7 @@ class _HG_Stage(TheseusLayer):
|
|||||||
identity=False))
|
identity=False))
|
||||||
for _ in range(block_num - 1):
|
for _ in range(block_num - 1):
|
||||||
blocks_list.append(
|
blocks_list.append(
|
||||||
_HG_Block(
|
HG_Block(
|
||||||
out_channels,
|
out_channels,
|
||||||
mid_channels,
|
mid_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -228,7 +228,7 @@ class PPHGNet(TheseusLayer):
|
|||||||
in_channels, mid_channels, out_channels, block_num, downsample = stage_config[
|
in_channels, mid_channels, out_channels, block_num, downsample = stage_config[
|
||||||
k]
|
k]
|
||||||
self.stages.append(
|
self.stages.append(
|
||||||
_HG_Stage(in_channels, mid_channels, out_channels, block_num,
|
HG_Stage(in_channels, mid_channels, out_channels, block_num,
|
||||||
layer_num, downsample))
|
layer_num, downsample))
|
||||||
|
|
||||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user