Release the inference code of ISE (ReID-CVPR2022)
parent
bf12dffcd3
commit
12d1b5f9a9
docs
en/algorithm_introduction
images
zh_CN/algorithm_introduction
ppcls
arch/gears
configs
data
dataloader
|
@ -0,0 +1,60 @@
|
|||
# ISE
|
||||
---
|
||||
## Catalogue
|
||||
|
||||
- [1. Introduction](#1)
|
||||
- [2. Performance on Market1501 and MSMT17](#2)
|
||||
- [3. Test](#3)
|
||||
- [4. Reference](#4)
|
||||
|
||||
<a name='1'></a>
|
||||
## 1. Introduction
|
||||
|
||||
ISE (Implicit Sample Extension) is a simple, efficient, and effective learning algorithm for unsupervised person Re-ID. ISE generates what we call support samples around the cluster boundaries. The sample generation process in ISE depends on two critical mechanisms, i.e., a progressive linear interpolation strategy and a label-preserving loss function. The generated support samples from ISE provide complementary information, which can nicely handle the "sub and mixed" clustering errors. ISE achieves superior performance than other unsupervised methods on Market1501 and MSMT17 datasets.
|
||||
|
||||
> [**Implicit Sample Extension for Unsupervised Person Re-Identification**](https://arxiv.org/abs/2204.06892v1)<br>
|
||||
> Xinyu Zhang, Dongdong Li, Zhigang Wang, Jian Wang, Errui Ding, Javen Qinfeng Shi, Zhaoxiang Zhang, Jingdong Wang<br>
|
||||
> CVPR2022
|
||||
|
||||

|
||||
|
||||
<a name='2'></a>
|
||||
## 2. Performance on Market1501 and MSMT17
|
||||
|
||||
The main results on Market1501 (M) and MSMT17 (MS). PIL denotes the progressive linear interpolation strategy. LP represents the label-preserving loss function.
|
||||
|
||||
| Methods | M | Link | MS | Link |
|
||||
| --- | -- | -- | -- | - |
|
||||
| Baseline | 82.5 (92.5) | - | 30.1 (58.6) | - |
|
||||
| ISE (+PIL) | 83.9 (93.9) | - | 33.5 (63.9) | - |
|
||||
| ISE (+LP) | 83.6 (92.7) | - | 31.4 (59.9) | - |
|
||||
| ISE (Ours) (+PIL+LP) | **84.7 (94.0)** | [ISE_M](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_M_model.pdparams) | **35.0 (64.7)** | [ISE_MS](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_MS_model.pdparams) |
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. Test
|
||||
The training code is coming soon. We first release the test code with the pretrained models.
|
||||
|
||||
**Test:** You can simply run the following script for the evaluation.
|
||||
|
||||
```
|
||||
python tools/eval.py -c ./ppcls/configs/Person/ResNet50_UReID_infer.yaml
|
||||
```
|
||||
**Steps:**
|
||||
1. Download the pretrained model first, and put the model into: ```./pd_model_trace/ISE/```.
|
||||
2. Change the dataset name in: ```./ppcls/configs/Person/ResNet50_UReID_infer.yaml```.
|
||||
3. Run the above script.
|
||||
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. Reference
|
||||
|
||||
If you find ISE useful in your research, please kindly consider citing our paper:
|
||||
|
||||
```
|
||||
@inproceedings{zhang2022Implicit,
|
||||
title={Implicit Sample Extension for Unsupervised Person Re-Identification},
|
||||
author={Xinyu Zhang, Dongdong Li, Zhigang Wang, Jian Wang, Errui Ding, Javen Qinfeng Shi, Zhaoxiang Zhang, Jingdong Wang},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year={2022}
|
||||
}
|
||||
```
|
Binary file not shown.
After Width: | Height: | Size: 178 KiB |
|
@ -0,0 +1,62 @@
|
|||
# ISE
|
||||
---
|
||||
## 目录
|
||||
|
||||
- [1. 介绍](#1)
|
||||
- [2. 在Market1501和MSMT17上的结果](#2)
|
||||
- [3. 测试](#3)
|
||||
- [4. 引用](#4)
|
||||
|
||||
<a name='1'></a>
|
||||
## 1. 介绍
|
||||
|
||||
ISE (Implicit Sample Extension)是一种简单、高效、有效的无监督行人再识别学习算法。ISE在聚类蔟边界周围生成样本,我们称之为支持样本。ISE的样本生成过程依赖于两个关键机制,即渐进线性插值策略(progressive linear interpolation)和标签保留的损失函数(label-preserving loss function)。ISE生成的支持样本提供了额外补充信息,可以很好地处理“子类和混合”的聚类错误。ISE在Market1501和MSMT17数据集上取得了优于其他无监督方法的性能。
|
||||
|
||||
> [**Implicit Sample Extension for Unsupervised Person Re-Identification**](https://arxiv.org/abs/2204.06892v1)<br>
|
||||
> Xinyu Zhang, Dongdong Li, Zhigang Wang, Jian Wang, Errui Ding, Javen Qinfeng Shi, Zhaoxiang Zhang, Jingdong Wang<br>
|
||||
> CVPR2022
|
||||
|
||||

|
||||
|
||||
|
||||
<a name='2'></a>
|
||||
## 2. 在Market1501和MSMT17上的结果
|
||||
|
||||
在Market1501和MSMT17上的主要结果。“PIL”表示渐进线性插值策略。“LP”表示标签保留的损失函数。
|
||||
|
||||
| 方法 | Market1501 | 下载链接 | MSMT17 | 下载链接 |
|
||||
| --- | -- | -- | -- | - |
|
||||
| Baseline | 82.5 (92.5) | - | 30.1 (58.6) | - |
|
||||
| ISE (+PIL) | 83.9 (93.9) | - | 33.5 (63.9) | - |
|
||||
| ISE (+LP) | 83.6 (92.7) | - | 31.4 (59.9) | - |
|
||||
| ISE (Ours) (+PIL+LP) | **84.7 (94.0)** | [ISE_M](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_M_model.pdparams) | **35.0 (64.7)** | [ISE_MS](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_MS_model.pdparams) |
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 测试
|
||||
我们很快会提供训练代码,首先我们提供了测试代码和模型。
|
||||
|
||||
**测试:** 可简使用如下脚本进行模型评估。
|
||||
|
||||
```
|
||||
python tools/eval.py -c ./ppcls/configs/Person/ResNet50_UReID_infer.yaml
|
||||
```
|
||||
**步骤:**
|
||||
1. 首先下载模型,并放入:```./pd_model_trace/ISE/```。
|
||||
2. 改变```./ppcls/configs/Person/ResNet50_UReID_infer.yaml```中的数据集名称。
|
||||
3. 运行上述脚本。
|
||||
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 引用
|
||||
|
||||
如果ISE在您的研究中有启发,请考虑引用我们的论文:
|
||||
|
||||
```
|
||||
@inproceedings{zhang2022Implicit,
|
||||
title={Implicit Sample Extension for Unsupervised Person Re-Identification},
|
||||
author={Xinyu Zhang, Dongdong Li, Zhigang Wang, Jian Wang, Errui Ding, Javen Qinfeng Shi, Zhaoxiang Zhang, Jingdong Wang},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year={2022}
|
||||
}
|
||||
```
|
|
@ -18,13 +18,15 @@ from .circlemargin import CircleMargin
|
|||
from .fc import FC
|
||||
from .vehicle_neck import VehicleNeck
|
||||
from paddle.nn import Tanh
|
||||
from .bnneck import BNNeck
|
||||
|
||||
__all__ = ['build_gear']
|
||||
|
||||
|
||||
def build_gear(config):
|
||||
support_dict = [
|
||||
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh'
|
||||
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
|
||||
'BNNeck'
|
||||
]
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception(
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
|
||||
class BNNeck(nn.Layer):
|
||||
def __init__(self, num_features):
|
||||
super().__init__()
|
||||
weight_attr = paddle.ParamAttr(
|
||||
initializer=paddle.nn.initializer.Constant(value=1.0))
|
||||
bias_attr = paddle.ParamAttr(
|
||||
initializer=paddle.nn.initializer.Constant(value=0.0),
|
||||
trainable=False)
|
||||
self.feat_bn = nn.BatchNorm1D(
|
||||
num_features,
|
||||
momentum=0.9,
|
||||
epsilon=1e-05,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr)
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.flatten(x)
|
||||
x = self.feat_bn(x)
|
||||
return x
|
|
@ -0,0 +1,152 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
# pretrained_model: "./pd_model_trace/ISE/ISE_M_model" # pretrained ISE model for Market1501
|
||||
# pretrained_model: "./pd_model_trace/ISE/ISE_MS_model" # pretrained ISE model for MSMT17
|
||||
output_dir: "./output/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 120
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 128, 256]
|
||||
save_inference_dir: "./inference"
|
||||
eval_mode: "retrieval"
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: "RecModel"
|
||||
infer_output_key: "features"
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: "ResNet50_last_stage_stride1"
|
||||
pretrained: True
|
||||
BackboneStopLayer:
|
||||
name: "avg_pool"
|
||||
Neck:
|
||||
name: "BNNeck"
|
||||
num_features: 2048
|
||||
Head:
|
||||
name: "FC"
|
||||
embedding_size: 2048
|
||||
class_num: 751
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
- SupConLoss:
|
||||
weight: 1.0
|
||||
views: 2
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.04
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.0005
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: "Market1501" # ["Market1501", "MSMT17"]
|
||||
image_root: "./dataset"
|
||||
cls_label_path: "bounding_box_train"
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: [128, 256]
|
||||
interpolation: 'bicubic'
|
||||
backend: 'pil'
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- Pad:
|
||||
padding: 10
|
||||
fill: 0
|
||||
- RandomCrop:
|
||||
size: [128, 256]
|
||||
pad_if_needed: False
|
||||
- NormalizeImage:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- RandomErasing:
|
||||
EPSILON: 0.5
|
||||
sl: 0.02
|
||||
sh: 0.4
|
||||
r1: 0.3
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
|
||||
sampler:
|
||||
name: PKSampler
|
||||
batch_size: 16
|
||||
sample_per_id: 4
|
||||
drop_last: True
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: True
|
||||
Eval:
|
||||
Query:
|
||||
dataset:
|
||||
name: "Market1501" # ["Market1501", "MSMT17"]
|
||||
image_root: "./dataset"
|
||||
cls_label_path: "query"
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: [128, 256]
|
||||
interpolation: 'bicubic'
|
||||
backend: 'pil'
|
||||
- NormalizeImage:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: True
|
||||
|
||||
Gallery:
|
||||
dataset:
|
||||
name: "Market1501" # ["Market1501", "MSMT17"]
|
||||
image_root: "./dataset"
|
||||
cls_label_path: "bounding_box_test"
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: [128, 256]
|
||||
interpolation: 'bicubic'
|
||||
backend: 'pil'
|
||||
- NormalizeImage:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: True
|
||||
|
||||
Metric:
|
||||
Eval:
|
||||
- Recallk:
|
||||
topk: [1, 5]
|
||||
- mAP: {}
|
||||
|
|
@ -28,6 +28,7 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
|
|||
from ppcls.data.dataloader.logo_dataset import LogoDataset
|
||||
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
|
||||
from ppcls.data.dataloader.mix_dataset import MixDataset
|
||||
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
|
||||
|
||||
# sampler
|
||||
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
|
||||
|
|
|
@ -7,3 +7,4 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
|
|||
from ppcls.data.dataloader.mix_dataset import MixDataset
|
||||
from ppcls.data.dataloader.mix_sampler import MixSampler
|
||||
from ppcls.data.dataloader.pk_sampler import PKSampler
|
||||
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
|
||||
|
|
|
@ -0,0 +1,217 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.io import Dataset
|
||||
import os
|
||||
import cv2
|
||||
|
||||
from ppcls.data import preprocess
|
||||
from ppcls.data.preprocess import transform
|
||||
from ppcls.utils import logger
|
||||
from .common_dataset import create_operators
|
||||
import os.path as osp
|
||||
import glob
|
||||
import re
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Market1501(Dataset):
|
||||
"""
|
||||
Market1501
|
||||
Reference:
|
||||
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
|
||||
URL: http://www.liangzheng.org/Project/project_reid.html
|
||||
|
||||
Dataset statistics:
|
||||
# identities: 1501 (+1 for background)
|
||||
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
|
||||
"""
|
||||
_dataset_dir = 'market1501/Market-1501-v15.09.15'
|
||||
|
||||
def __init__(self, image_root, cls_label_path, transform_ops=None):
|
||||
self._img_root = image_root
|
||||
self._cls_path = cls_label_path # the sub folder in the dataset
|
||||
self._dataset_dir = osp.join(image_root, self._dataset_dir,
|
||||
self._cls_path)
|
||||
self._check_before_run()
|
||||
if transform_ops:
|
||||
self._transform_ops = create_operators(transform_ops)
|
||||
self._dtype = paddle.get_default_dtype()
|
||||
self._load_anno(relabel=True if 'train' in self._cls_path else False)
|
||||
|
||||
def _check_before_run(self):
|
||||
"""Check if the file is available before going deeper"""
|
||||
if not osp.exists(self._dataset_dir):
|
||||
raise RuntimeError("'{}' is not available".format(
|
||||
self._dataset_dir))
|
||||
|
||||
def _load_anno(self, relabel=False):
|
||||
img_paths = glob.glob(osp.join(self._dataset_dir, '*.jpg'))
|
||||
pattern = re.compile(r'([-\d]+)_c(\d)')
|
||||
|
||||
self.images = []
|
||||
self.labels = []
|
||||
self.cameras = []
|
||||
pid_container = set()
|
||||
|
||||
for img_path in sorted(img_paths):
|
||||
pid, _ = map(int, pattern.search(img_path).groups())
|
||||
if pid == -1: continue # junk images are just ignored
|
||||
pid_container.add(pid)
|
||||
pid2label = {pid: label for label, pid in enumerate(pid_container)}
|
||||
|
||||
for img_path in sorted(img_paths):
|
||||
pid, camid = map(int, pattern.search(img_path).groups())
|
||||
if pid == -1: continue # junk images are just ignored
|
||||
assert 0 <= pid <= 1501 # pid == 0 means background
|
||||
assert 1 <= camid <= 6
|
||||
camid -= 1 # index starts from 0
|
||||
if relabel: pid = pid2label[pid]
|
||||
self.images.append(img_path)
|
||||
self.labels.append(pid)
|
||||
self.cameras.append(camid)
|
||||
|
||||
self.num_pids, self.num_imgs, self.num_cams = get_imagedata_info(
|
||||
self.images, self.labels, self.cameras, subfolder=self._cls_path)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
img = Image.open(self.images[idx]).convert('RGB')
|
||||
img = np.array(img, dtype="float32").astype(np.uint8)
|
||||
if self._transform_ops:
|
||||
img = transform(img, self._transform_ops)
|
||||
img = img.transpose((2, 0, 1))
|
||||
return (img, self.labels[idx], self.cameras[idx])
|
||||
except Exception as ex:
|
||||
logger.error("Exception occured when parse line: {} with msg: {}".
|
||||
format(self.images[idx], ex))
|
||||
rnd_idx = np.random.randint(self.__len__())
|
||||
return self.__getitem__(rnd_idx)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
@property
|
||||
def class_num(self):
|
||||
return len(set(self.labels))
|
||||
|
||||
|
||||
class MSMT17(Dataset):
|
||||
"""
|
||||
MSMT17
|
||||
|
||||
Reference:
|
||||
Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
|
||||
|
||||
URL: http://www.pkuvmc.com/publications/msmt17.html
|
||||
|
||||
Dataset statistics:
|
||||
# identities: 4101
|
||||
# images: 32621 (train) + 11659 (query) + 82161 (gallery)
|
||||
# cameras: 15
|
||||
"""
|
||||
_dataset_dir = 'msmt17/MSMT17_V1'
|
||||
|
||||
def __init__(self, image_root, cls_label_path, transform_ops=None):
|
||||
self._img_root = image_root
|
||||
self._cls_path = cls_label_path # the sub folder in the dataset
|
||||
self._dataset_dir = osp.join(image_root, self._dataset_dir,
|
||||
self._cls_path)
|
||||
self._check_before_run()
|
||||
if transform_ops:
|
||||
self._transform_ops = create_operators(transform_ops)
|
||||
self._dtype = paddle.get_default_dtype()
|
||||
self._load_anno(relabel=True if 'train' in self._cls_path else False)
|
||||
|
||||
def _check_before_run(self):
|
||||
"""Check if the file is available before going deeper"""
|
||||
if not osp.exists(self._dataset_dir):
|
||||
raise RuntimeError("'{}' is not available".format(
|
||||
self._dataset_dir))
|
||||
|
||||
def _load_anno(self, relabel=False):
|
||||
img_paths = glob.glob(osp.join(self._dataset_dir, '*.jpg'))
|
||||
pattern = re.compile(r'([-\d]+)_c(\d+)')
|
||||
|
||||
self.images = []
|
||||
self.labels = []
|
||||
self.cameras = []
|
||||
pid_container = set()
|
||||
|
||||
for img_path in img_paths:
|
||||
pid, _ = map(int, pattern.search(img_path).groups())
|
||||
if pid == -1:
|
||||
continue # junk images are just ignored
|
||||
pid_container.add(pid)
|
||||
pid2label = {pid: label for label, pid in enumerate(pid_container)}
|
||||
|
||||
for img_path in img_paths:
|
||||
pid, camid = map(int, pattern.search(img_path).groups())
|
||||
if pid == -1:
|
||||
continue # junk images are just ignored
|
||||
assert 1 <= camid <= 15
|
||||
camid -= 1 # index starts from 0
|
||||
if relabel:
|
||||
pid = pid2label[pid]
|
||||
self.images.append(img_path)
|
||||
self.labels.append(pid)
|
||||
self.cameras.append(camid)
|
||||
|
||||
self.num_pids, self.num_imgs, self.num_cams = get_imagedata_info(
|
||||
self.images, self.labels, self.cameras, subfolder=self._cls_path)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
img = Image.open(self.images[idx]).convert('RGB')
|
||||
img = np.array(img, dtype="float32").astype(np.uint8)
|
||||
if self._transform_ops:
|
||||
img = transform(img, self._transform_ops)
|
||||
img = img.transpose((2, 0, 1))
|
||||
return (img, self.labels[idx], self.cameras[idx])
|
||||
except Exception as ex:
|
||||
logger.error("Exception occured when parse line: {} with msg: {}".
|
||||
format(self.images[idx], ex))
|
||||
rnd_idx = np.random.randint(self.__len__())
|
||||
return self.__getitem__(rnd_idx)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
@property
|
||||
def class_num(self):
|
||||
return len(set(self.labels))
|
||||
|
||||
|
||||
def get_imagedata_info(data, labels, cameras, subfolder='train'):
|
||||
pids, cams = [], []
|
||||
for _, pid, camid in zip(data, labels, cameras):
|
||||
pids += [pid]
|
||||
cams += [camid]
|
||||
pids = set(pids)
|
||||
cams = set(cams)
|
||||
num_pids = len(pids)
|
||||
num_cams = len(cams)
|
||||
num_imgs = len(data)
|
||||
print("Dataset statistics:")
|
||||
print(" ----------------------------------------")
|
||||
print(" subset | # ids | # images | # cameras")
|
||||
print(" ----------------------------------------")
|
||||
print(" {} | {:5d} | {:8d} | {:9d}".format(subfolder, num_pids,
|
||||
num_imgs, num_cams))
|
||||
print(" ----------------------------------------")
|
||||
return num_pids, num_imgs, num_cams
|
Loading…
Reference in New Issue