add compute_mean_std tool

pull/294/head
KaiyangZhou 2019-11-27 16:49:29 +00:00
parent 307e38e519
commit 86012a4571
3 changed files with 65 additions and 5 deletions

View File

@ -0,0 +1,60 @@
import torchreid
import argparse
"""
Compute channel-wise mean and standard deviation of a dataset.
Usage:
$ python compute_mean_std.py DATASET_ROOT DATASET_KEY
- The first argument points to the root path where you put the datasets.
- The second argument is the key (name) of the dataset to process.
For instance, if your datasets are stored under $DATA and you want to
compute the statistics of Market1501, run
$ python compute_mean_std.py $DATA market1501
"""
def _channel_mean_std(loader):
    """Average per-image channel statistics over every batch in ``loader``.

    Args:
        loader: iterable of batches. Element ``0`` of each batch must be a
            tensor whose first dimension indexes samples and whose second
            dimension indexes channels; all remaining dimensions are
            flattened together before the statistics are taken.

    Returns:
        tuple: ``(mean, std)``, each holding one value per channel.

    Raises:
        ValueError: if ``loader`` yields no samples at all.

    Note:
        The returned std is the mean of the per-image standard deviations,
        not the standard deviation of all pixels pooled together.
    """
    mean_sum = 0.
    std_sum = 0.
    n_samples = 0

    for batch in loader:
        images = batch[0]
        batch_size = images.size(0)
        # Collapse everything after the channel axis: (batch, channels, pixels).
        images = images.view(batch_size, images.size(1), -1)
        # Per-image statistics over the pixel axis, summed over the batch.
        mean_sum += images.mean(2).sum(0)
        std_sum += images.std(2).sum(0)
        n_samples += batch_size

    if n_samples == 0:
        # Without this guard the division below fails with an unexplained
        # ZeroDivisionError.
        raise ValueError(
            'Cannot compute mean and std: the data loader yielded no samples'
        )

    return mean_sum / n_samples, std_sum / n_samples


def main():
    """Parse the command line, load the training set and print its statistics.

    Takes two positional command-line arguments: the dataset root directory
    and the dataset key. Prints the channel-wise mean and std computed over
    the training loader.

    Raises:
        ValueError: if the training loader yields no samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'root', type=str, help='root path where the datasets are stored'
    )
    parser.add_argument(
        'sources', type=str, help='key of the dataset to compute statistics for'
    )
    args = parser.parse_args()

    # No extra transforms, zero mean / unit std and a sequential sampler are
    # requested so that (presumably) the loader returns un-normalised images,
    # in a fixed order.
    datamanager = torchreid.data.ImageDataManager(
        root=args.root,
        sources=args.sources,
        targets=None,
        height=256,
        width=128,
        batch_size_train=100,
        batch_size_test=100,
        transforms=None,
        norm_mean=[0., 0., 0.],
        norm_std=[1., 1., 1.],
        train_sampler='SequentialSampler'
    )
    train_loader = datamanager.train_loader

    print('Computing mean and std ...')
    mean, std = _channel_mean_std(train_loader)

    print('Mean: {}'.format(mean))
    print('Std: {}'.format(std))


if __name__ == '__main__':
    main()

View File

@ -100,7 +100,7 @@ class ImageDataManager(DataManager):
workers (int, optional): number of workers. Default is 4.
num_instances (int, optional): number of instances per identity in a batch.
Default is 4.
train_sampler (str, optional): sampler. Default is empty (``RandomSampler``).
train_sampler (str, optional): sampler. Default is RandomSampler.
cuhk03_labeled (bool, optional): use cuhk03 labeled images.
Default is False (default is to use detected images).
cuhk03_classic_split (bool, optional): use the classic split in cuhk03.
@ -148,7 +148,7 @@ class ImageDataManager(DataManager):
batch_size_test=32,
workers=4,
num_instances=4,
train_sampler='',
train_sampler='RandomSampler',
cuhk03_labeled=False,
cuhk03_classic_split=False,
market1501_500k=False
@ -321,7 +321,7 @@ class VideoDataManager(DataManager):
workers (int, optional): number of workers. Default is 4.
num_instances (int, optional): number of instances per identity in a batch.
Default is 4.
train_sampler (str, optional): sampler. Default is empty (``RandomSampler``).
train_sampler (str, optional): sampler. Default is RandomSampler.
seq_len (int, optional): how many images to sample in a tracklet. Default is 15.
sample_method (str, optional): how to sample images in a tracklet. Default is "evenly".
Choices are ["evenly", "random", "all"]. "evenly" and "random" will sample ``seq_len``
@ -372,7 +372,7 @@ class VideoDataManager(DataManager):
batch_size_test=3,
workers=4,
num_instances=4,
train_sampler=None,
train_sampler='RandomSampler',
seq_len=15,
sample_method='evenly'
):

View File

@ -95,7 +95,7 @@ def build_train_sampler(data_source, train_sampler, batch_size=32, num_instances
elif train_sampler == 'SequentialSampler':
sampler = SequentialSampler(data_source)
else:
elif train_sampler == 'RandomSampler':
sampler = RandomSampler(data_source)
return sampler