PaddleClas/ppcls/utils/gallery2fc.py

96 lines
3.3 KiB
Python
Raw Normal View History

2021-12-14 10:34:40 +08:00
import os
2021-12-08 14:40:03 +08:00
import paddle
2021-12-15 10:39:26 +08:00
import cv2
2021-12-14 10:34:40 +08:00
2021-12-08 19:47:43 +08:00
from ppcls.arch import build_model
2021-12-14 10:34:40 +08:00
from ppcls.arch.gears.identity_head import IdentityHead
2021-12-08 19:48:59 +08:00
from ppcls.utils.config import parse_config, parse_args
2021-12-08 19:56:30 +08:00
from ppcls.utils.save_load import load_dygraph_pretrain
2021-12-08 20:02:13 +08:00
from ppcls.utils.logger import init_logger
2021-12-14 10:34:40 +08:00
from ppcls.data import transform, create_operators
2021-12-08 19:56:30 +08:00
2021-12-15 10:39:26 +08:00
def _read_gallery_image_paths(data_file, image_root, delimiter):
    """Parse the gallery list file into image paths joined onto ``image_root``.

    Every non-blank line must split (after stripping) on ``delimiter`` into at
    least two fields; the first field is the image path relative to
    ``image_root``. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """
    image_paths = []
    with open(data_file, 'r', encoding='utf-8') as f:
        for ori_line in f:
            stripped = ori_line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. an empty line at end of file).
                continue
            fields = stripped.split(delimiter)
            if len(fields) < 2:
                raise ValueError(
                    f"line({ori_line}) must be splitted into at least 2 parts, "
                    f"but got {len(fields)}")
            image_paths.append(os.path.join(image_root, fields[0]))
    return image_paths


def build_gallery_layer(configs, feature_extractor):
    """Build a bias-free Linear layer whose weights are gallery embeddings.

    Reads the gallery list described by ``configs["IndexProcess"]``, runs each
    image through the configured preprocessing ops and ``feature_extractor``
    in batches, L2-normalizes every feature and stores the result as the
    weight of ``Linear(embedding_size, num_gallery_images)``. Applied to an
    L2-normalized query feature, the layer therefore outputs one cosine
    similarity per gallery image, in list-file order.

    Args:
        configs: parsed config. Reads ``IndexProcess.{transform_ops,
            batch_size, image_root, data_file, delimiter}``,
            ``Arch.Head.embedding_size`` and ``Global.image_shape``.
        feature_extractor: callable that maps a batched image tensor to a
            dict containing a ``"features"`` entry.

    Returns:
        paddle.nn.Linear: the gallery layer (no bias).

    Raises:
        ValueError: a list line is malformed or the gallery is empty.
        FileNotFoundError: an image listed in the gallery cannot be read.
    """
    index_cfg = configs["IndexProcess"]
    preprocess_ops = create_operators(index_cfg["transform_ops"])
    embedding_size = configs["Arch"]["Head"]["embedding_size"]
    batch_size = index_cfg["batch_size"]

    gallery_images = _read_gallery_image_paths(
        index_cfg["data_file"], index_cfg["image_root"], index_cfg["delimiter"])
    num_images = len(gallery_images)
    if num_images == 0:
        raise ValueError(
            f"no gallery images found in {index_cfg['data_file']}")

    # Build a new list so the shared config's image_shape is not mutated.
    input_tensor = paddle.zeros(
        [batch_size] + list(configs["Global"]["image_shape"]))
    gallery_feature = paddle.zeros((num_images, embedding_size))

    batch_index = 0
    # Inference only: no autograd graph is needed for the gallery features.
    with paddle.no_grad():
        for i, image_path in enumerate(gallery_images):
            image = cv2.imread(image_path)
            if image is None:
                raise FileNotFoundError(
                    f"failed to read gallery image: {image_path}")
            # cv2 decodes to BGR; PaddleClas preprocessing expects RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            for op in preprocess_ops:
                image = op(image)
            input_tensor[batch_index] = image
            batch_index += 1
            if batch_index == batch_size or i == num_images - 1:
                # In a final partial batch, slots >= batch_index still hold
                # stale images; their outputs are simply ignored below.
                batch_feature = feature_extractor(input_tensor)["features"]
                # The current batch holds images (i - batch_index + 1) .. i.
                start = i - batch_index + 1
                for j in range(batch_index):
                    # assumes batch_feature[j] is a 1-D embedding (hence axis=0)
                    gallery_feature[start + j] = paddle.nn.functional.normalize(
                        batch_feature[j], axis=0)
                batch_index = 0

    gallery_layer = paddle.nn.Linear(
        embedding_size, num_images, bias_attr=False)
    # Linear weight is [in_features, out_features] = [embedding, num_images].
    gallery_layer.set_state_dict({"weight": gallery_feature.T})
    return gallery_layer
2021-12-23 20:51:15 +08:00
class GalleryLayer(paddle.nn.Layer):
    """Head that scores query features against a fixed image gallery.

    The wrapped Linear layer (built by ``build_gallery_layer``) holds the
    L2-normalized gallery embeddings; ``forward`` L2-normalizes the query so
    each output is a cosine similarity to one gallery image.
    """

    def __init__(self, configs, feature_extractor):
        super().__init__()
        # NOTE: the attribute name determines exported parameter names.
        self.gallery_layer = build_gallery_layer(configs, feature_extractor)

    def forward(self, x, label=None):
        """Return similarity scores of ``x`` against every gallery image.

        Args:
            x: query features; normalized along axis 1 (paddle's default),
                presumably shaped (batch, embedding_size) — TODO confirm.
            label: accepted and ignored. NOTE(review): ppcls recognition
                models appear to call their head as ``head(features, label)``;
                this keeps the layer usable as a drop-in head.
        """
        x = paddle.nn.functional.normalize(x)
        return self.gallery_layer(x)
2021-12-23 20:35:10 +08:00
def export_fuse_model(configs):
    """Export the model with its head replaced by a gallery-matching layer.

    Builds the model from ``configs``, loads ``Global.pretrained_model``,
    swaps the head for a ``GalleryLayer`` computed with the model itself as
    feature extractor, and saves an inference model to
    ``Global.save_inference_dir``. If the model carries a quantizer
    (``quanter``), the quantized-export path is used; otherwise a plain
    ``paddle.jit`` static export is written.
    """
    fuse_model = build_model(configs)
    load_dygraph_pretrain(fuse_model, configs["Global"]["pretrained_model"])
    # Eval mode before the model is used to extract gallery features.
    fuse_model.eval()
    fuse_model.head = GalleryLayer(configs, fuse_model)
    # The freshly created head defaults to training mode; make it consistent.
    fuse_model.eval()

    save_path = configs["Global"]["save_inference_dir"]
    input_spec = [
        paddle.static.InputSpec(
            shape=[None] + list(configs["Global"]["image_shape"]),
            dtype='float32')
    ]
    quanter = getattr(fuse_model, "quanter", None)
    if quanter is not None:
        quanter.save_quantized_model(
            fuse_model, save_path, input_spec=input_spec)
    else:
        # No quantizer configured: fall back to a regular static export.
        static_model = paddle.jit.to_static(fuse_model, input_spec=input_spec)
        paddle.jit.save(static_model, save_path)
2021-12-08 14:40:03 +08:00
def main():
    """Command-line entry point: load the config and export the fused model."""
    # Local import keeps this fix self-contained; same module as parse_config.
    from ppcls.utils.config import override_config

    args = parse_args()
    configs = parse_config(args.config)
    # NOTE(review): ppcls parse_args() appears to collect `-o key=value`
    # overrides in `args.override`; apply them instead of dropping them.
    configs = override_config(configs, getattr(args, "override", None))
    init_logger(name='gallery2fc')
    export_fuse_model(configs)


if __name__ == '__main__':
    main()