From 86012a4571d327b7211e0c5191d3e97564ba75ac Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Wed, 27 Nov 2019 16:49:29 +0000 Subject: [PATCH] add compute_mean_std tool --- tools/compute_mean_std.py | 60 +++++++++++++++++++++++++++++++++++ torchreid/data/datamanager.py | 8 ++--- torchreid/data/sampler.py | 2 +- 3 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 tools/compute_mean_std.py diff --git a/tools/compute_mean_std.py b/tools/compute_mean_std.py new file mode 100644 index 0000000..e805113 --- /dev/null +++ b/tools/compute_mean_std.py @@ -0,0 +1,60 @@ +import torchreid +import argparse + + +""" +Compute channel-wise mean and standard deviation of a dataset. + +Usage: +$ python compute_mean_std.py DATASET_ROOT DATASET_KEY + +- The first argument points to the root path where you put the datasets. +- The second argument means the specific dataset key. + +For instance, your datasets are put under $DATA and you wanna +compute the statistics of Market1501, do +$ python compute_mean_std.py $DATA market1501 +""" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('root', type=str) + parser.add_argument('sources', type=str) + args = parser.parse_args() + + datamanager = torchreid.data.ImageDataManager( + root=args.root, + sources=args.sources, + targets=None, + height=256, + width=128, + batch_size_train=100, + batch_size_test=100, + transforms=None, + norm_mean=[0., 0., 0.], + norm_std=[1., 1., 1.], + train_sampler='SequentialSampler' + ) + train_loader = datamanager.train_loader + + print('Computing mean and std ...') + mean = 0. + std = 0. + n_samples = 0. + for data in train_loader: + data = data[0] + batch_size = data.size(0) + data = data.view(batch_size, data.size(1), -1) + mean += data.mean(2).sum(0) + std += data.std(2).sum(0) + n_samples += batch_size + + mean /= n_samples + std /= n_samples + print('Mean: {}'.format(mean)) + print('Std: {}'.format(std)) + + +if __name__ == '__main__': + main() diff --git a/torchreid/data/datamanager.py b/torchreid/data/datamanager.py index c97e107..d5d377c 100644 --- a/torchreid/data/datamanager.py +++ b/torchreid/data/datamanager.py @@ -100,7 +100,7 @@ class ImageDataManager(DataManager): workers (int, optional): number of workers. Default is 4. num_instances (int, optional): number of instances per identity in a batch. Default is 4. - train_sampler (str, optional): sampler. Default is empty (``RandomSampler``). + train_sampler (str, optional): sampler. Default is RandomSampler. cuhk03_labeled (bool, optional): use cuhk03 labeled images. Default is False (defaul is to use detected images). cuhk03_classic_split (bool, optional): use the classic split in cuhk03. @@ -148,7 +148,7 @@ class ImageDataManager(DataManager): batch_size_test=32, workers=4, num_instances=4, - train_sampler='', + train_sampler='RandomSampler', cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False @@ -321,7 +321,7 @@ class VideoDataManager(DataManager): workers (int, optional): number of workers. Default is 4. num_instances (int, optional): number of instances per identity in a batch. Default is 4. - train_sampler (str, optional): sampler. Default is empty (``RandomSampler``). + train_sampler (str, optional): sampler. Default is RandomSampler. seq_len (int, optional): how many images to sample in a tracklet. Default is 15. sample_method (str, optional): how to sample images in a tracklet. Default is "evenly". Choices are ["evenly", "random", "all"]. "evenly" and "random" will sample ``seq_len`` @@ -372,7 +372,7 @@ class VideoDataManager(DataManager): batch_size_test=3, workers=4, num_instances=4, - train_sampler=None, + train_sampler='RandomSampler', seq_len=15, sample_method='evenly' ): diff --git a/torchreid/data/sampler.py b/torchreid/data/sampler.py index 773f07d..42488a2 100644 --- a/torchreid/data/sampler.py +++ b/torchreid/data/sampler.py @@ -95,7 +95,7 @@ def build_train_sampler(data_source, train_sampler, batch_size=32, num_instances elif train_sampler == 'SequentialSampler': sampler = SequentialSampler(data_source) - else: + elif train_sampler == 'RandomSampler': sampler = RandomSampler(data_source) return sampler \ No newline at end of file