mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
When passing a model_path to the FeatureExtractor, weights are loaded twice: first the pretrained ones and then the given ones. Avoid the unnecessary load.
153 lines
4.5 KiB
Python
153 lines
4.5 KiB
Python
from __future__ import absolute_import
|
|
import numpy as np
|
|
import torch
|
|
import torchvision.transforms as T
|
|
from PIL import Image
|
|
|
|
from torchreid.utils import (
|
|
check_isfile, load_pretrained_weights, compute_model_complexity
|
|
)
|
|
from torchreid.models import build_model
|
|
|
|
|
|
class FeatureExtractor(object):
    """Callable wrapper around a re-id model for feature extraction.

    An instance behaves like a plain python function mapping images to
    feature vectors. Accepted input types:
        - a list of strings (image paths)
        - a list of numpy.ndarray each with shape (H, W, C)
        - a single string (image path)
        - a single numpy.ndarray with shape (H, W, C)
        - a torch.Tensor with shape (B, C, H, W) or (C, H, W)

    The return value is a torch tensor with shape (B, D) where D is the
    feature dimension of the underlying model.

    Args:
        model_name (str): model name.
        model_path (str): path to model weights.
        image_size (sequence or int): image height and width.
        pixel_mean (list): pixel mean for normalization.
        pixel_std (list): pixel std for normalization.
        pixel_norm (bool): whether to normalize pixels.
        device (str): 'cpu' or 'cuda' (could be specific gpu devices).
        verbose (bool): show model details.

    Examples::

        from torchreid.utils import FeatureExtractor

        extractor = FeatureExtractor(
            model_name='osnet_x1_0',
            model_path='a/b/c/model.pth.tar',
            device='cuda'
        )

        image_list = [
            'a/b/c/image001.jpg',
            'a/b/c/image002.jpg',
            'a/b/c/image003.jpg',
            'a/b/c/image004.jpg',
            'a/b/c/image005.jpg'
        ]

        features = extractor(image_list)
        print(features.shape) # output (5, 512)
    """

    def __init__(
        self,
        model_name='',
        model_path='',
        image_size=(256, 128),
        pixel_mean=[0.485, 0.456, 0.406],
        pixel_std=[0.229, 0.224, 0.225],
        pixel_norm=True,
        device='cuda',
        verbose=True
    ):
        # Weights come from exactly one place: either downloaded
        # pretrained weights or the file at model_path — never both,
        # so nothing is loaded twice.
        use_custom_weights = bool(model_path) and check_isfile(model_path)

        # Build model
        model = build_model(
            model_name,
            num_classes=1,
            pretrained=not use_custom_weights,
            use_gpu=device.startswith('cuda')
        )
        model.eval()

        if verbose:
            # Report cost for a single image at the working resolution.
            num_params, flops = compute_model_complexity(
                model, (1, 3, image_size[0], image_size[1])
            )
            print('Model: {}'.format(model_name))
            print('- params: {:,}'.format(num_params))
            print('- flops: {:,}'.format(flops))

        if use_custom_weights:
            load_pretrained_weights(model, model_path)

        # Build transform functions applied to every PIL image.
        transform_steps = [T.Resize(image_size), T.ToTensor()]
        if pixel_norm:
            transform_steps.append(
                T.Normalize(mean=pixel_mean, std=pixel_std)
            )

        target_device = torch.device(device)
        model.to(target_device)

        # Class attributes
        self.model = model
        self.preprocess = T.Compose(transform_steps)
        self.to_pil = T.ToPILImage()
        self.device = target_device

    def __call__(self, input):
        """Normalize *input* into a (B, C, H, W) batch and run the model."""
        if isinstance(input, list):
            batch = []
            for element in input:
                if isinstance(element, str):
                    pil_image = Image.open(element).convert('RGB')
                elif isinstance(element, np.ndarray):
                    pil_image = self.to_pil(element)
                else:
                    raise TypeError(
                        'Type of each element must belong to [str | numpy.ndarray]'
                    )
                batch.append(self.preprocess(pil_image))
            images = torch.stack(batch, dim=0).to(self.device)

        elif isinstance(input, str):
            pil_image = Image.open(input).convert('RGB')
            images = self.preprocess(pil_image).unsqueeze(0).to(self.device)

        elif isinstance(input, np.ndarray):
            pil_image = self.to_pil(input)
            images = self.preprocess(pil_image).unsqueeze(0).to(self.device)

        elif isinstance(input, torch.Tensor):
            # Promote a single (C, H, W) image to a batch of one.
            if input.dim() == 3:
                input = input.unsqueeze(0)
            images = input.to(self.device)

        else:
            raise NotImplementedError

        # Inference only; no gradients needed.
        with torch.no_grad():
            features = self.model(images)

        return features
|