dbg
parent
b6a7c53182
commit
932e0eace1
|
@ -25,10 +25,11 @@ def build_gallery_layer(configs, feature_extractor):
|
||||||
delimiter = configs["IndexProcess"]["delimiter"]
|
delimiter = configs["IndexProcess"]["delimiter"]
|
||||||
gallery_images = []
|
gallery_images = []
|
||||||
gallery_docs = []
|
gallery_docs = []
|
||||||
|
gallery_labels = []
|
||||||
|
|
||||||
with open(data_file, 'r', encoding='utf-8') as f:
|
with open(data_file, 'r', encoding='utf-8') as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
for _, ori_line in enumerate(lines):
|
for ori_line in lines:
|
||||||
line = ori_line.strip().split(delimiter)
|
line = ori_line.strip().split(delimiter)
|
||||||
text_num = len(line)
|
text_num = len(line)
|
||||||
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
|
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
|
||||||
|
@ -36,6 +37,7 @@ def build_gallery_layer(configs, feature_extractor):
|
||||||
|
|
||||||
gallery_images.append(image_file)
|
gallery_images.append(image_file)
|
||||||
gallery_docs.append(ori_line.strip())
|
gallery_docs.append(ori_line.strip())
|
||||||
|
gallery_labels.append(line[1].strip())
|
||||||
batch_index = 0
|
batch_index = 0
|
||||||
gallery_feature = paddle.zeros((len(gallery_images), embedding_size))
|
gallery_feature = paddle.zeros((len(gallery_images), embedding_size))
|
||||||
for i, image_path in enumerate(gallery_images):
|
for i, image_path in enumerate(gallery_images):
|
||||||
|
@ -45,12 +47,13 @@ def build_gallery_layer(configs, feature_extractor):
|
||||||
input_tensor[batch_index] = image
|
input_tensor[batch_index] = image
|
||||||
batch_index += 1
|
batch_index += 1
|
||||||
if batch_index == batch_size or i == len(gallery_images) - 1:
|
if batch_index == batch_size or i == len(gallery_images) - 1:
|
||||||
batch_feature = feature_extractor(input_tensor)
|
batch_feature = feature_extractor(input_tensor)["features"]
|
||||||
for j in range(batch_index):
|
for j in range(batch_index):
|
||||||
feature = batch_feature[j]
|
feature = batch_feature[j]
|
||||||
norm_feature = paddle.nn.functional.normalize(feature)
|
norm_feature = paddle.nn.functional.normalize(feature, axis=0)
|
||||||
gallery_feature[i + batch_index - j] = norm_feature
|
gallery_feature[i - batch_index + j] = norm_feature
|
||||||
gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), weight_attr=gallery_feature, bias_attr=False)
|
gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), bias_attr=False)
|
||||||
|
gallery_layer.set_state_dict({"weight": gallery_feature.T})
|
||||||
return gallery_layer
|
return gallery_layer
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,11 +74,12 @@ class FuseModel(paddle.nn.Layer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.feature_extractor = build_model(configs)
|
self.feature_extractor = build_model(configs)
|
||||||
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
|
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
|
||||||
|
self.feature_extractor.eval()
|
||||||
self.feature_extractor.head = IdentityHead()
|
self.feature_extractor.head = IdentityHead()
|
||||||
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
|
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.feature_model(x)["features"]
|
x = self.feature_extractor(x)["features"]
|
||||||
x = paddle.nn.functional.normalize(x)
|
x = paddle.nn.functional.normalize(x)
|
||||||
x = self.gallery_layer(x)
|
x = self.gallery_layer(x)
|
||||||
return x
|
return x
|
||||||
|
@ -86,7 +90,7 @@ def main():
|
||||||
configs = parse_config(args.config)
|
configs = parse_config(args.config)
|
||||||
init_logger(name='gallery2fc')
|
init_logger(name='gallery2fc')
|
||||||
fuse_model = FuseModel(configs)
|
fuse_model = FuseModel(configs)
|
||||||
save_fuse_model(fuse_model)
|
# save_fuse_model(fuse_model)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue