# mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os

import numpy as np
import torch
from PIL import Image, ImageFile

from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.utils.misc import deprecated

from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
from .builder import PREDICTORS
from .classifier import ClassificationPredictor
@PREDICTORS.register_module()
class ReIDPredictor(ClassificationPredictor):
    """Predictor for ReID (person re-identification) feature extraction.

    Runs the model in ``extract`` mode and returns L2-normalized embedding
    features. When ``inputs`` is a dataset directory, camera ids and person
    labels are additionally parsed from Market-1501 style file names
    (e.g. ``0002_c1s1_000451_03.jpg``).

    Args:
        model_path (str): Path of model path.
        config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
        batch_size (int): batch size for forward.
        device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
        topk (int): Return top-k results. Default: 1.
        pil_input (bool): Whether use PIL image. If processor need PIL input, set true, default false.
        label_map_path (str): File path of saving labels list.
        input_processor_threads (int): Number of processes to process inputs.
        mode (str): The image mode into the model.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 pipelines=None,
                 topk=1,
                 pil_input=True,
                 label_map_path=None,
                 input_processor_threads=8,
                 mode='BGR',
                 *args,
                 **kwargs):
        # PIL decodes images as RGB, so force the model input mode to match.
        if pil_input:
            mode = 'RGB'
        super(ReIDPredictor, self).__init__(
            model_path,
            config_file=config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            input_processor_threads=input_processor_threads,
            mode=mode,
            topk=topk,
            pil_input=pil_input,
            label_map_path=label_map_path,
            *args,
            **kwargs)

    def model_forward(self, inputs):
        """Model forward.

        Runs the model in ``extract`` mode (feature extraction, no
        classification head). If you need refactor model forward, you need
        to reimplement it.
        """
        with torch.no_grad():
            outputs = self.model(**inputs, mode='extract')
        return outputs

    def get_id(self, img_path):
        """Parse camera ids and person labels from image file names.

        Assumes Market-1501 style names such as ``0002_c1s1_000451_03.jpg``:
        the first 4 characters are the person label (a leading ``-1`` marks
        junk/distractor images) and the digit right after the first ``c``
        is the camera id.

        Args:
            img_path (list[str]): Image file paths.

        Returns:
            tuple(list[int], list[int]): ``(camera_id, labels)``.
        """
        camera_id = []
        labels = []
        for path in img_path:
            filename = os.path.basename(path)
            label = filename[0:4]
            camera = filename.split('c')[1]
            if label[0:2] == '-1':
                # Distractor/junk images carry label -1 in Market-1501.
                labels.append(-1)
            else:
                labels.append(int(label))
            # NOTE(review): only the first digit after 'c' is used, which
            # assumes single-digit camera ids (true for Market-1501).
            camera_id.append(int(camera[0]))
        return camera_id, labels

    def file_name_walk(self, file_dir):
        """Recursively collect all ``.jpg`` files under ``file_dir``.

        Args:
            file_dir (str): Root directory to walk.

        Returns:
            list[str]: Full paths of all ``.jpg`` images found.
        """
        image_path_list = []
        # `dirs` is not needed; use `_` for the unused walk component.
        for root, _, files in os.walk(file_dir):
            for name in files:
                if name.endswith('.jpg'):
                    image_path_list.append(os.path.join(root, name))
        return image_path_list

    def __call__(self, inputs, keep_inputs=False):
        """Extract L2-normalized ReID features for ``inputs``.

        Args:
            inputs (str | list): A dataset root directory (walked for
                ``.jpg`` files, with camera ids/labels parsed from the
                file names) or a list of images/paths.
            keep_inputs (bool): Whether to attach each raw input to its
                corresponding result.

        Returns:
            dict: ``img_feature`` (N x D row-normalized feature tensor),
            ``img_cam`` and ``img_label`` (lists when a directory was
            given, otherwise None).
        """
        if not isinstance(inputs, list):
            # A non-list input is treated as a dataset root directory.
            inputs = self.file_name_walk(inputs)
            img_cam, img_label = self.get_id(inputs)
        else:
            img_cam, img_label = None, None

        if self.input_processor is None:
            self.input_processor = self.get_input_processor()

        # TODO: fault tolerance
        if isinstance(inputs, (str, np.ndarray, ImageFile.ImageFile)):
            inputs = [inputs]

        results_list = []
        for i in range(0, len(inputs), self.batch_size):
            # Slicing clamps at len(inputs), so no explicit min() is needed.
            batch_inputs = inputs[i:i + self.batch_size]
            batch_outputs = self.input_processor(batch_inputs)
            batch_outputs = self._to_device(batch_outputs)
            batch_outputs = self.model_forward(batch_outputs)
            results = batch_outputs['neck']
            if keep_inputs:
                # Use a distinct index `j` so the outer batch offset `i`
                # is not shadowed (the original reused `i` here).
                # NOTE(review): this assumes each row of `results` is
                # dict-like; confirm against the model's 'neck' output.
                for j in range(len(batch_inputs)):
                    results[j].update({'inputs': batch_inputs[j]})
            if isinstance(results, list):
                results_list.extend(results)
            else:
                results_list.append(results)

        # L2-normalize each feature row: x / ||x||_2.
        image_feature = torch.cat(results_list, 0)
        image_feature_norm = torch.norm(
            image_feature, p=2, dim=1, keepdim=True)
        image_feature = image_feature.div(
            image_feature_norm.expand_as(image_feature))

        # TODO: support append to file
        if self.save_results:
            self.dump(image_feature, self.save_path)

        return {
            'img_feature': image_feature,
            'img_cam': img_cam,
            'img_label': img_label
        }