pull/1599/head
weishengyu 2021-12-23 20:29:05 +08:00
parent b6a7c53182
commit 932e0eace1
1 changed file with 11 additions and 7 deletions


@@ -25,10 +25,11 @@ def build_gallery_layer(configs, feature_extractor):
     delimiter = configs["IndexProcess"]["delimiter"]
     gallery_images = []
     gallery_docs = []
+    gallery_labels = []
     with open(data_file, 'r', encoding='utf-8') as f:
         lines = f.readlines()
-        for _, ori_line in enumerate(lines):
+        for ori_line in lines:
             line = ori_line.strip().split(delimiter)
             text_num = len(line)
             assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
@@ -36,6 +37,7 @@ def build_gallery_layer(configs, feature_extractor):
             gallery_images.append(image_file)
             gallery_docs.append(ori_line.strip())
+            gallery_labels.append(line[1].strip())
     batch_index = 0
     gallery_feature = paddle.zeros((len(gallery_images), embedding_size))
     for i, image_path in enumerate(gallery_images):
@@ -45,12 +47,13 @@ def build_gallery_layer(configs, feature_extractor):
         input_tensor[batch_index] = image
         batch_index += 1
         if batch_index == batch_size or i == len(gallery_images) - 1:
-            batch_feature = feature_extractor(input_tensor)
+            batch_feature = feature_extractor(input_tensor)["features"]
             for j in range(batch_index):
                 feature = batch_feature[j]
-                norm_feature = paddle.nn.functional.normalize(feature)
-                gallery_feature[i + batch_index - j] = norm_feature
+                norm_feature = paddle.nn.functional.normalize(feature, axis=0)
+                gallery_feature[i - batch_index + j] = norm_feature
-    gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), weight_attr=gallery_feature, bias_attr=False)
+    gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), bias_attr=False)
+    gallery_layer.set_state_dict({"weight": gallery_feature.T})
     return gallery_layer
@@ -71,11 +74,12 @@ class FuseModel(paddle.nn.Layer):
         super().__init__()
         self.feature_extractor = build_model(configs)
         load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
+        self.feature_extractor.eval()
         self.feature_extractor.head = IdentityHead()
         self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)

     def forward(self, x):
-        x = self.feature_model(x)["features"]
+        x = self.feature_extractor(x)["features"]
         x = paddle.nn.functional.normalize(x)
         x = self.gallery_layer(x)
         return x
@@ -86,7 +90,7 @@ def main():
     configs = parse_config(args.config)
     init_logger(name='gallery2fc')
     fuse_model = FuseModel(configs)
-    save_fuse_model(fuse_model)
+    # save_fuse_model(fuse_model)
 if __name__ == '__main__':
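
Note (not part of the commit): a minimal sketch of the idea behind the Linear/set_state_dict change in build_gallery_layer, assuming Paddle is installed. The toy shapes and variable names (num_gallery, query, scores) are illustrative only. With each gallery feature L2-normalized and written into the weight of a bias-free paddle.nn.Linear (transposed, since Paddle stores Linear weights as (in_features, out_features)), the layer's output for a normalized query is its cosine similarity to every gallery image.

import paddle

# Toy dimensions, chosen only for illustration.
num_gallery, embedding_size = 4, 8

# Stand-in for the gallery feature matrix: one L2-normalized row per gallery image.
gallery_feature = paddle.nn.functional.normalize(
    paddle.rand((num_gallery, embedding_size)), axis=1)

# Same construction as in the diff: bias-free Linear, weight set to the
# transposed gallery matrix.
gallery_layer = paddle.nn.Linear(embedding_size, num_gallery, bias_attr=False)
gallery_layer.set_state_dict({"weight": gallery_feature.T})

# A normalized query feature; each output element is a cosine similarity,
# so argmax picks the closest gallery image.
query = paddle.nn.functional.normalize(paddle.rand((1, embedding_size)), axis=1)
scores = gallery_layer(query)            # shape (1, num_gallery)
closest = paddle.argmax(scores, axis=1)  # index of the best match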