# EasyCV/easycv/predictors/reid_predictor.py
#
# 143 lines
# 5.0 KiB
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import numpy as np
import torch
from PIL import Image, ImageFile
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.utils.misc import deprecated
from .base import InputProcessor, OutputProcessor, Predictor, PredictorV2
from .builder import PREDICTORS
from .classifier import ClassificationPredictor
@PREDICTORS.register_module()
class ReIDPredictor(ClassificationPredictor):
    """Predictor for reid.

    Runs the model in ``'extract'`` mode and returns L2-normalized features
    taken from the ``'neck'`` entry of the model output.

    Args:
        model_path (str): Path of the model file.
        config_file (Optional[str]): Config file path for model and processor
            to init. Defaults to None.
        batch_size (int): Batch size for forward. Default: 1.
        device (str): Support 'cuda' or 'cpu', if is None, detect device
            automatically.
        save_results (bool): Whether to save predict results.
        save_path (str): File path for saving results, only valid when
            `save_results` is True.
        pipelines (list[dict]): Data pipeline configs.
        topk (int): Return top-k results. Default: 1.
        pil_input (bool): Whether to use PIL images. If the processor needs
            PIL input, set it to True. Default: True.
        label_map_path (str): File path of saving labels list.
        input_processor_threads (int): Number of processes to process inputs.
            Default: 8.
        mode (str): The image mode into the model. Default: 'BGR'. It is
            overridden with 'RGB' whenever `pil_input` is True.
    """

    def __init__(self,
                 model_path,
                 config_file=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
                 save_path=None,
                 pipelines=None,
                 topk=1,
                 pil_input=True,
                 label_map_path=None,
                 input_processor_threads=8,
                 mode='BGR',
                 *args,
                 **kwargs):
        # PIL inputs are always handled as RGB, whatever `mode` was requested.
        if pil_input:
            mode = 'RGB'
        super(ReIDPredictor, self).__init__(
            model_path,
            *args,
            config_file=config_file,
            batch_size=batch_size,
            device=device,
            save_results=save_results,
            save_path=save_path,
            pipelines=pipelines,
            input_processor_threads=input_processor_threads,
            mode=mode,
            topk=topk,
            pil_input=pil_input,
            label_map_path=label_map_path,
            **kwargs)

    def model_forward(self, inputs):
        """Model forward.

        Calls the model with ``mode='extract'`` under ``torch.no_grad()``.
        If you need refactor model forward, you need to reimplement it.

        Args:
            inputs (dict): Keyword arguments forwarded to the model.

        Returns:
            The raw output of the model call.
        """
        with torch.no_grad():
            outputs = self.model(**inputs, mode='extract')
        return outputs

    def get_id(self, img_path):
        """Parse camera ids and person labels from image file names.

        The first four characters of the base name are taken as the label
        (a name starting with ``'-1'`` gets label -1), and the character
        right after the first ``'c'`` in the base name is taken as the camera
        id, e.g. ``0001_c1s1_000151_00.jpg`` -> label 1, camera 1.
        NOTE(review): this looks like the Market-1501 naming scheme; confirm
        against the dataset in use.

        Args:
            img_path (Iterable[str]): Image file paths.

        Returns:
            tuple[list[int], list[int]]: ``(camera_id, labels)``, both in the
            same order as `img_path`.

        Raises:
            IndexError: If a base name contains no ``'c'``.
            Exception: Whatever ``int()`` raises when the label or camera
                characters are not numeric.
        """
        camera_id = []
        labels = []
        for path in img_path:
            filename = os.path.basename(path)
            label = filename[0:4]
            camera = filename.split('c')[1]
            labels.append(-1 if label[0:2] == '-1' else int(label))
            camera_id.append(int(camera[0]))
        return camera_id, labels

    def file_name_walk(self, file_dir):
        """Recursively collect the paths of all ``.jpg`` files in a directory.

        Args:
            file_dir (str): Root directory to search.

        Returns:
            list[str]: Paths of the files whose name ends with ``'.jpg'``
            (case-sensitive), in a deterministic, sorted traversal order.
        """
        image_path_list = []
        for root, dirs, files in os.walk(file_dir):
            # os.walk yields entries in arbitrary filesystem order; sorting
            # `dirs` in place and the file names keeps the feature order
            # reproducible across runs and machines.
            dirs.sort()
            image_path_list.extend(
                os.path.join(root, name) for name in sorted(files)
                if name.endswith('.jpg'))
        return image_path_list

    def __call__(self, inputs, keep_inputs=False):
        """Extract L2-normalized features for the given images.

        Args:
            inputs: One of

                - a directory path: all ``.jpg`` files below it are used and
                  camera ids / labels are parsed from their names via
                  `get_id`;
                - a single image (``str`` path, ``np.ndarray`` or PIL
                  ``ImageFile``);
                - a list of such images.
            keep_inputs (bool): If True, the (list of) inputs that were
                processed is returned under the ``'inputs'`` key.

        Returns:
            dict: ``'img_feature'`` holds the concatenated features,
            normalized to unit L2 norm along dim 1. ``'img_cam'`` and
            ``'img_label'`` hold the parsed camera ids and labels for
            directory input and are None otherwise. ``'inputs'`` is only
            present when `keep_inputs` is True.

        Raises:
            ValueError: If there is nothing to process (e.g. the directory
                contains no ``.jpg`` file or the list is empty).
        """
        img_cam, img_label = None, None
        if isinstance(inputs, (str, os.PathLike)) and os.path.isdir(inputs):
            # Only a real directory is walked; ids are parsed from file names.
            inputs = self.file_name_walk(inputs)
            img_cam, img_label = self.get_id(inputs)
        elif isinstance(inputs, (str, np.ndarray, ImageFile.ImageFile)):
            # A single image is processed as a batch of one.
            inputs = [inputs]

        if self.input_processor is None:
            self.input_processor = self.get_input_processor()

        # Fail early with a clear message; torch.cat on an empty list would
        # otherwise raise a much less helpful error further down.
        if not inputs:
            raise ValueError(
                'No input images found for ReID feature extraction.')

        # TODO: fault tolerance
        results_list = []
        for start in range(0, len(inputs), self.batch_size):
            batch_inputs = inputs[start:start + self.batch_size]
            batch_outputs = self.input_processor(batch_inputs)
            batch_outputs = self._to_device(batch_outputs)
            batch_outputs = self.model_forward(batch_outputs)
            results = batch_outputs['neck']
            if isinstance(results, list):
                results_list.extend(results)
            else:
                results_list.append(results)

        image_feature = torch.cat(results_list, 0)
        image_feature_norm = torch.norm(
            image_feature, p=2, dim=1, keepdim=True)
        image_feature = image_feature.div(
            image_feature_norm.expand_as(image_feature))

        # TODO: support append to file
        if self.save_results:
            self.dump(image_feature, self.save_path)

        outputs = {
            'img_feature': image_feature,
            'img_cam': img_cam,
            'img_label': img_label
        }
        # The features are tensors and cannot carry per-sample metadata, so
        # the inputs are returned alongside them rather than attached to them.
        if keep_inputs:
            outputs['inputs'] = inputs
        return outputs