pull/1599/head
weishengyu 2021-12-23 20:29:05 +08:00
parent b6a7c53182
commit 932e0eace1
1 changed files with 11 additions and 7 deletions
ppcls/utils

View File

@ -25,10 +25,11 @@ def build_gallery_layer(configs, feature_extractor):
delimiter = configs["IndexProcess"]["delimiter"]
gallery_images = []
gallery_docs = []
gallery_labels = []
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for _, ori_line in enumerate(lines):
for ori_line in lines:
line = ori_line.strip().split(delimiter)
text_num = len(line)
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
@ -36,6 +37,7 @@ def build_gallery_layer(configs, feature_extractor):
gallery_images.append(image_file)
gallery_docs.append(ori_line.strip())
gallery_labels.append(line[1].strip())
batch_index = 0
gallery_feature = paddle.zeros((len(gallery_images), embedding_size))
for i, image_path in enumerate(gallery_images):
@ -45,12 +47,13 @@ def build_gallery_layer(configs, feature_extractor):
input_tensor[batch_index] = image
batch_index += 1
if batch_index == batch_size or i == len(gallery_images) - 1:
batch_feature = feature_extractor(input_tensor)
batch_feature = feature_extractor(input_tensor)["features"]
for j in range(batch_index):
feature = batch_feature[j]
norm_feature = paddle.nn.functional.normalize(feature)
gallery_feature[i + batch_index - j] = norm_feature
gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), weight_attr=gallery_feature, bias_attr=False)
norm_feature = paddle.nn.functional.normalize(feature, axis=0)
gallery_feature[i - batch_index + j] = norm_feature
gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), bias_attr=False)
gallery_layer.set_state_dict({"weight": gallery_feature.T})
return gallery_layer
@ -71,11 +74,12 @@ class FuseModel(paddle.nn.Layer):
super().__init__()
self.feature_extractor = build_model(configs)
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
self.feature_extractor.eval()
self.feature_extractor.head = IdentityHead()
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
def forward(self, x):
x = self.feature_model(x)["features"]
x = self.feature_extractor(x)["features"]
x = paddle.nn.functional.normalize(x)
x = self.gallery_layer(x)
return x
@ -86,7 +90,7 @@ def main():
configs = parse_config(args.config)
init_logger(name='gallery2fc')
fuse_model = FuseModel(configs)
save_fuse_model(fuse_model)
# save_fuse_model(fuse_model)
if __name__ == '__main__':