From 932e0eace1b5838c209dcda9ad22511ccf665040 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 23 Dec 2021 20:29:05 +0800 Subject: [PATCH] dbg --- ppcls/utils/gallery2fc.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/ppcls/utils/gallery2fc.py b/ppcls/utils/gallery2fc.py index e2f49de3c..224a10786 100644 --- a/ppcls/utils/gallery2fc.py +++ b/ppcls/utils/gallery2fc.py @@ -25,10 +25,11 @@ def build_gallery_layer(configs, feature_extractor): delimiter = configs["IndexProcess"]["delimiter"] gallery_images = [] gallery_docs = [] + gallery_labels = [] with open(data_file, 'r', encoding='utf-8') as f: lines = f.readlines() - for _, ori_line in enumerate(lines): + for ori_line in lines: line = ori_line.strip().split(delimiter) text_num = len(line) assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}" @@ -36,6 +37,7 @@ def build_gallery_layer(configs, feature_extractor): gallery_images.append(image_file) gallery_docs.append(ori_line.strip()) + gallery_labels.append(line[1].strip()) batch_index = 0 gallery_feature = paddle.zeros((len(gallery_images), embedding_size)) for i, image_path in enumerate(gallery_images): @@ -45,12 +47,13 @@ def build_gallery_layer(configs, feature_extractor): input_tensor[batch_index] = image batch_index += 1 if batch_index == batch_size or i == len(gallery_images) - 1: - batch_feature = feature_extractor(input_tensor) + batch_feature = feature_extractor(input_tensor)["features"] for j in range(batch_index): feature = batch_feature[j] - norm_feature = paddle.nn.functional.normalize(feature) - gallery_feature[i + batch_index - j] = norm_feature - gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), weight_attr=gallery_feature, bias_attr=False) + norm_feature = paddle.nn.functional.normalize(feature, axis=0) + gallery_feature[i - batch_index + j] = norm_feature + gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), bias_attr=False) + gallery_layer.set_state_dict({"weight": gallery_feature.T}) return gallery_layer @@ -71,11 +74,12 @@ class FuseModel(paddle.nn.Layer): super().__init__() self.feature_extractor = build_model(configs) load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"]) + self.feature_extractor.eval() self.feature_extractor.head = IdentityHead() self.gallery_layer = build_gallery_layer(configs, self.feature_extractor) def forward(self, x): - x = self.feature_model(x)["features"] + x = self.feature_extractor(x)["features"] x = paddle.nn.functional.normalize(x) x = self.gallery_layer(x) return x @@ -86,7 +90,7 @@ def main(): configs = parse_config(args.config) init_logger(name='gallery2fc') fuse_model = FuseModel(configs) - save_fuse_model(fuse_model) + # save_fuse_model(fuse_model) if __name__ == '__main__':