PaddleClas/ppcls/arch/gears/bnneck.py

17 lines
411 B
Python
Raw Normal View History

2022-04-07 18:31:45 +08:00
import paddle
class BNNeck(paddle.nn.Layer):
    """Batch-normalization "neck" layer wrapping a single ``BatchNorm1D``.

    Args:
        num_filters: number of feature channels handed to
            ``paddle.nn.BatchNorm1D``; also stored as ``self.num_filters``.
        trainable: when falsy (the default), the BN layer's bias parameter
            is frozen by setting its ``trainable`` flag to False. When truthy,
            the bias is left exactly as ``BatchNorm1D`` created it. Only the
            bias is touched here; the BN weight is never modified.
    """

    def __init__(self, num_filters, trainable=False):
        super().__init__()
        self.num_filters = num_filters
        self.bn = paddle.nn.BatchNorm1D(self.num_filters)

        # Freezing is opt-out: only a falsy ``trainable`` changes the bias.
        freeze_bias = not trainable
        if freeze_bias:
            self.bn.bias.trainable = False

    def forward(self, input, label=None):
        """Return ``input`` passed through the wrapped BatchNorm1D layer.

        Args:
            input: feature tensor. Presumably shaped (batch, num_filters);
                TODO confirm against callers.
            label: accepted but ignored. This is likely kept so the layer
                shares a call signature with other gears; verify against
                callers before removing.

        Returns:
            The output of ``self.bn(input)``.
        """
        return self.bn(input)