2022-04-07 18:31:45 +08:00
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
|
|
class BNNeck(paddle.nn.Layer):
    """Batch-normalization "neck" that wraps a single ``BatchNorm1D`` layer.

    Args:
        num_filters: Number of features normalized by the wrapped
            ``paddle.nn.BatchNorm1D`` (its ``num_features`` argument).
            Also kept on the instance as ``self.num_filters``.
        trainable: When falsy (the default), the wrapped layer's bias
            parameter is marked non-trainable. When truthy, the bias is
            left exactly as ``BatchNorm1D`` created it. Only the bias is
            touched here; the weight keeps ``BatchNorm1D``'s defaults.
    """

    def __init__(self, num_filters, trainable=False):
        super().__init__()
        self.num_filters = num_filters

        norm_layer = paddle.nn.BatchNorm1D(num_filters)
        self.bn = norm_layer

        # Freeze only the BN shift (bias) unless the caller opts in.
        freeze_bias = not trainable
        if freeze_bias:
            norm_layer.bias.trainable = False

    def forward(self, input, label=None):
        """Apply the wrapped batch normalization to ``input``.

        Args:
            input: Feature tensor passed straight to ``self.bn``.
                Presumably shaped (batch, num_filters) or
                (batch, num_filters, length). TODO: confirm against callers.
            label: Accepted for interface compatibility and ignored.

        Returns:
            Whatever ``self.bn(input)`` returns.
        """
        normalized = self.bn(input)
        return normalized
|