add feature extract
parent
c941144568
commit
b6a7c53182
ppcls/utils
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import paddle
|
||||
import cv2
|
||||
|
||||
from ppcls.arch import build_model
|
||||
from ppcls.arch.gears.identity_head import IdentityHead
|
||||
|
@ -9,8 +10,8 @@ from ppcls.utils.logger import init_logger
|
|||
from ppcls.data import transform, create_operators
|
||||
|
||||
|
||||
def build_gallery_feature(configs, feature_extractor):
|
||||
transform_configs = configs["Infer"]["transforms"]
|
||||
def build_gallery_layer(configs, feature_extractor):
|
||||
transform_configs = configs["IndexProcess"]["transform_ops"]
|
||||
preprocess_ops = create_operators(transform_configs)
|
||||
|
||||
embedding_size = configs["Arch"]["Head"]["embedding_size"]
|
||||
|
@ -35,13 +36,34 @@ def build_gallery_feature(configs, feature_extractor):
|
|||
|
||||
gallery_images.append(image_file)
|
||||
gallery_docs.append(ori_line.strip())
|
||||
batch_index = 0
|
||||
gallery_feature = paddle.zeros((len(gallery_images), embedding_size))
|
||||
for i, image_path in enumerate(gallery_images):
|
||||
image = cv2.imread(image_path)
|
||||
for op in preprocess_ops:
|
||||
image = op(image)
|
||||
input_tensor[batch_index] = image
|
||||
batch_index += 1
|
||||
if batch_index == batch_size or i == len(gallery_images) - 1:
|
||||
batch_feature = feature_extractor(input_tensor)
|
||||
for j in range(batch_index):
|
||||
feature = batch_feature[j]
|
||||
norm_feature = paddle.nn.functional.normalize(feature)
|
||||
gallery_feature[i + batch_index - j] = norm_feature
|
||||
gallery_layer = paddle.nn.Linear(embedding_size, len(gallery_images), weight_attr=gallery_feature, bias_attr=False)
|
||||
return gallery_layer
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def save_fuse_model(fuse_model):
|
||||
pass
|
||||
def export_fuse_model(model, config):
|
||||
model.eval()
|
||||
model.quanter.save_quantized_model(
|
||||
model.base_model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + config["Global"]["image_shape"],
|
||||
dtype='float32')
|
||||
])
|
||||
|
||||
|
||||
class FuseModel(paddle.nn.Layer):
|
||||
|
@ -50,11 +72,11 @@ class FuseModel(paddle.nn.Layer):
|
|||
self.feature_extractor = build_model(configs)
|
||||
load_dygraph_pretrain(self.feature_extractor, configs["Global"]["pretrained_model"])
|
||||
self.feature_extractor.head = IdentityHead()
|
||||
self.gallery_layer = build_gallery_feature(configs, self.feature_extractor)
|
||||
self.gallery_layer = build_gallery_layer(configs, self.feature_extractor)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.feature_model(x)["features"]
|
||||
x = paddle.norm(x)
|
||||
x = paddle.nn.functional.normalize(x)
|
||||
x = self.gallery_layer(x)
|
||||
return x
|
||||
|
||||
|
|
Loading…
Reference in New Issue