dbg
parent
b6a7c53182
commit
932e0eace1
ppcls/utils
|
@ -25,10 +25,11 @@ def build_gallery_layer(configs, feature_extractor):
|
|||
delimiter = configs["IndexProcess"]["delimiter"]
|
||||
gallery_images = []
|
||||
gallery_docs = []
|
||||
gallery_labels = []
|
||||
|
||||
with open(data_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
for _, ori_line in enumerate(lines):
|
||||
for ori_line in lines:
|
||||
line = ori_line.strip().split(delimiter)
|
||||
text_num = len(line)
|
||||
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
|
||||
|
@ -36,6 +37,7 @@ def build_gallery_layer(configs, feature_extractor):
|
|||
|
||||
gallery_images.append(image_file)
|
||||
gallery_docs.append(ori_line.strip())
|
||||
gallery_labels.append(line[1].strip())
|
||||
batch_index = 0
|
||||
gallery_feature = paddle.zeros((len(gallery_images), embedding_size))
|
||||
for i, image_path in enumerate(gallery_images):
|
||||
|
@ -45,12 +47,13 @@ def build_gallery_layer(configs, feature_extractor):
|
|||
input_tensor[batch_index] = image
|
||||
batch_index += 1
|
||||
if batch_index == batch_size or i == len(gallery_images) - 1:
|
||||
batch_feature = feature_extractor(input_tensor)
|
||||
batch_feature = feature_extractor(input_tensor)["features"]
|
||||
for j in range(batch_index):
|
||||
feature = batch_feature[j]
|
||||
norm_feature = paddle.nn.functional.normalize(feature)
|
||||
gallery_feature[i + batch_index - j] = norm_feature
|
||||
gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), weight_attr=gallery_feature, bias_attr=False)
|
||||
norm_feature = paddle.nn.functional.normalize(feature, axis=0)
|
||||
gallery_feature[i - batch_index + j] = norm_feature
|
||||
gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), bias_attr=False)
|
||||
gallery_layer.set_state_dict({"weight": gallery_feature.T})
|
||||
return gallery_layer
|
||||
|
||||
|
||||
|
@ -71,11 +74,12 @@ class FuseModel(paddle.nn.Layer):
|
|||
super().__init__()
|
||||
self.feature_extractor = build_model(configs)
|
||||
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
|
||||
self.feature_extractor.eval()
|
||||
self.feature_extractor.head = IdentityHead()
|
||||
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.feature_model(x)["features"]
|
||||
x = self.feature_extractor(x)["features"]
|
||||
x = paddle.nn.functional.normalize(x)
|
||||
x = self.gallery_layer(x)
|
||||
return x
|
||||
|
@ -86,7 +90,7 @@ def main():
|
|||
configs = parse_config(args.config)
|
||||
init_logger(name='gallery2fc')
|
||||
fuse_model = FuseModel(configs)
|
||||
save_fuse_model(fuse_model)
|
||||
# save_fuse_model(fuse_model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue