pull/1599/head
weishengyu 2021-12-27 19:37:00 +08:00
parent b662ed34ac
commit 87f508e9f9
1 changed files with 3 additions and 2 deletions

View File

@ -39,6 +39,7 @@ class GalleryLayer(paddle.nn.Layer):
gallery_docs.append(ori_line.strip())
gallery_labels.append(line[1].strip())
self.gallery_layer = paddle.nn.Linear(embedding_size, len(self.gallery_images), bias_attr=False)
self.gallery_layer.skip_quant = True
def forward(self, x, label=None):
x = paddle.nn.functional.normalize(x)
@ -63,8 +64,8 @@ class GalleryLayer(paddle.nn.Layer):
for j in range(batch_index):
feature = batch_feature[j]
norm_feature = paddle.nn.functional.normalize(feature, axis=0)
gallery_feature[i - batch_index + j] = norm_feature
self.gallery_layer.set_state_dict({"weight": gallery_feature.T})
gallery_feature[i - batch_index + j + 1] = norm_feature
self.gallery_layer.set_state_dict({"_layer.weight": gallery_feature.T})
def export_fuse_model(configs):