deep-person-reid/torchreid/utils/feature_extractor.py

from __future__ import absolute_import
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from torchreid.utils import (
    check_isfile, load_pretrained_weights, compute_model_complexity
)
from torchreid.models import build_model


class FeatureExtractor(object):
"""A simple API for feature extraction.
FeatureExtractor can be used like a python function, which
accepts input of the following types:
- a list of strings (image paths)
- a list of numpy.ndarray each with shape (H, W, C)
- a single string (image path)
- a single numpy.ndarray with shape (H, W, C)
- a torch.Tensor with shape (B, C, H, W) or (C, H, W)
Returned is a torch tensor with shape (B, D) where D is the
feature dimension.
Args:
model_name (str): model name.
model_path (str): path to model weights.
image_size (sequence or int): image height and width.
pixel_mean (list): pixel mean for normalization.
pixel_std (list): pixel std for normalization.
pixel_norm (bool): whether to normalize pixels.
device (str): 'cpu' or 'cuda' (could be specific gpu devices).
verbose (bool): show model details.
Examples::
from torchreid.utils import FeatureExtractor
extractor = FeatureExtractor(
model_name='osnet_x1_0',
model_path='a/b/c/model.pth.tar',
device='cuda'
)
image_list = [
'a/b/c/image001.jpg',
'a/b/c/image002.jpg',
'a/b/c/image003.jpg',
'a/b/c/image004.jpg',
'a/b/c/image005.jpg'
]
features = extractor(image_list)
print(features.shape) # output (5, 512)
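
    A single image path, a (H, W, C) numpy.ndarray, or a (C, H, W) /
    (B, C, H, W) torch.Tensor can also be passed directly. For example
    (the output shape below assumes osnet_x1_0, whose feature
    dimension is 512)::

        import numpy as np

        image = np.random.randint(0, 255, (256, 128, 3), dtype=np.uint8)
        features = extractor(image)
        print(features.shape) # output (1, 512)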
"""
    def __init__(
        self,
        model_name='',
        model_path='',
        image_size=(256, 128),
        pixel_mean=[0.485, 0.456, 0.406],
        pixel_std=[0.229, 0.224, 0.225],
        pixel_norm=True,
        device='cuda',
        verbose=True
    ):
        # Build model
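        # (num_classes is a placeholder: torchreid models return features,
        # not logits, in eval mode. ImageNet-pretrained weights are only
        # requested when no valid local checkpoint is supplied, since a
        # checkpoint given via model_path is loaded further below.)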
        model = build_model(
            model_name,
            num_classes=1,
            pretrained=not (model_path and check_isfile(model_path)),
            use_gpu=device.startswith('cuda')
        )
        model.eval()

        if verbose:
            # Report parameter count and FLOPs for a single 3-channel
            # input at the working resolution
            num_params, flops = compute_model_complexity(
                model, (1, 3, image_size[0], image_size[1])
            )
            print('Model: {}'.format(model_name))
            print('- params: {:,}'.format(num_params))
            print('- flops: {:,}'.format(flops))

        if model_path and check_isfile(model_path):
            load_pretrained_weights(model, model_path)

        # Build transform functions (resize -> tensor -> optional
        # normalization; the default mean/std are ImageNet statistics)
        transforms = []
        transforms += [T.Resize(image_size)]
        transforms += [T.ToTensor()]
        if pixel_norm:
            transforms += [T.Normalize(mean=pixel_mean, std=pixel_std)]
        preprocess = T.Compose(transforms)

        to_pil = T.ToPILImage()

        device = torch.device(device)
        model.to(device)

        # Class attributes
        self.model = model
        self.preprocess = preprocess
        self.to_pil = to_pil
        self.device = device

    def __call__(self, input):
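        """Extract features for ``input``.

        Accepts the input types listed in the class docstring and returns
        a (B, D) torch.Tensor on ``self.device`` (B is 1 for single inputs).
        """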
        if isinstance(input, list):
            images = []

            for element in input:
                if isinstance(element, str):
                    image = Image.open(element).convert('RGB')

                elif isinstance(element, np.ndarray):
                    image = self.to_pil(element)

                else:
                    raise TypeError(
                        'Type of each element must belong to [str | numpy.ndarray]'
                    )

                image = self.preprocess(image)
                images.append(image)

            images = torch.stack(images, dim=0)
            images = images.to(self.device)

        elif isinstance(input, str):
            image = Image.open(input).convert('RGB')
            image = self.preprocess(image)
            images = image.unsqueeze(0).to(self.device)

        elif isinstance(input, np.ndarray):
            # Arrays are interpreted as (H, W, C) images in RGB channel
            # order; BGR arrays (e.g. from cv2.imread) should be converted
            # by the caller first
            image = self.to_pil(input)
            image = self.preprocess(image)
            images = image.unsqueeze(0).to(self.device)

        elif isinstance(input, torch.Tensor):
            # Tensors are assumed to be preprocessed already and are fed
            # to the model as-is, after adding a batch dimension if needed
            if input.dim() == 3:
                input = input.unsqueeze(0)
            images = input.to(self.device)

        else:
            raise NotImplementedError

        with torch.no_grad():
            features = self.model(images)

        return features
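

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal,
# self-contained check of the numpy input path, assuming torchreid is
# installed. The empty model_path falls back to ImageNet-pretrained weights,
# and the (1, 512) output shape assumes osnet_x1_0.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    extractor = FeatureExtractor(
        model_name='osnet_x1_0',
        model_path='',  # no checkpoint -> build_model downloads ImageNet weights
        device='cpu',
        verbose=False
    )

    # A random (H, W, C) uint8 array at the working resolution stands in
    # for a real image here
    dummy = np.random.randint(0, 255, (256, 128, 3), dtype=np.uint8)
    features = extractor(dummy)
    print(features.shape)  # torch.Size([1, 512])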