add compute_mean_std tool
parent
307e38e519
commit
86012a4571
|
@ -0,0 +1,60 @@
|
|||
import torchreid
|
||||
import argparse
|
||||
|
||||
|
||||
"""
|
||||
Compute channel-wise mean and standard deviation of a dataset.
|
||||
|
||||
Usage:
|
||||
$ python compute_mean_std.py DATASET_ROOT DATASET_KEY
|
||||
|
||||
- The first argument points to the root path where you put the datasets.
|
||||
- The second argument means the specific dataset key.
|
||||
|
||||
For instance, your datasets are put under $DATA and you wanna
|
||||
compute the statistics of Market1501, do
|
||||
$ python compute_mean_std.py $DATA market1501
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('root', type=str)
|
||||
parser.add_argument('sources', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root=args.root,
|
||||
sources=args.sources,
|
||||
targets=None,
|
||||
height=256,
|
||||
width=128,
|
||||
batch_size_train=100,
|
||||
batch_size_test=100,
|
||||
transforms=None,
|
||||
norm_mean=[0., 0., 0.],
|
||||
norm_std=[1., 1., 1.],
|
||||
train_sampler='SequentialSampler'
|
||||
)
|
||||
train_loader = datamanager.train_loader
|
||||
|
||||
print('Computing mean and std ...')
|
||||
mean = 0.
|
||||
std = 0.
|
||||
n_samples = 0.
|
||||
for data in train_loader:
|
||||
data = data[0]
|
||||
batch_size = data.size(0)
|
||||
data = data.view(batch_size, data.size(1), -1)
|
||||
mean += data.mean(2).sum(0)
|
||||
std += data.std(2).sum(0)
|
||||
n_samples += batch_size
|
||||
|
||||
mean /= n_samples
|
||||
std /= n_samples
|
||||
print('Mean: {}'.format(mean))
|
||||
print('Std: {}'.format(std))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -100,7 +100,7 @@ class ImageDataManager(DataManager):
|
|||
workers (int, optional): number of workers. Default is 4.
|
||||
num_instances (int, optional): number of instances per identity in a batch.
|
||||
Default is 4.
|
||||
train_sampler (str, optional): sampler. Default is empty (``RandomSampler``).
|
||||
train_sampler (str, optional): sampler. Default is RandomSampler.
|
||||
cuhk03_labeled (bool, optional): use cuhk03 labeled images.
|
||||
Default is False (defaul is to use detected images).
|
||||
cuhk03_classic_split (bool, optional): use the classic split in cuhk03.
|
||||
|
@ -148,7 +148,7 @@ class ImageDataManager(DataManager):
|
|||
batch_size_test=32,
|
||||
workers=4,
|
||||
num_instances=4,
|
||||
train_sampler='',
|
||||
train_sampler='RandomSampler',
|
||||
cuhk03_labeled=False,
|
||||
cuhk03_classic_split=False,
|
||||
market1501_500k=False
|
||||
|
@ -321,7 +321,7 @@ class VideoDataManager(DataManager):
|
|||
workers (int, optional): number of workers. Default is 4.
|
||||
num_instances (int, optional): number of instances per identity in a batch.
|
||||
Default is 4.
|
||||
train_sampler (str, optional): sampler. Default is empty (``RandomSampler``).
|
||||
train_sampler (str, optional): sampler. Default is RandomSampler.
|
||||
seq_len (int, optional): how many images to sample in a tracklet. Default is 15.
|
||||
sample_method (str, optional): how to sample images in a tracklet. Default is "evenly".
|
||||
Choices are ["evenly", "random", "all"]. "evenly" and "random" will sample ``seq_len``
|
||||
|
@ -372,7 +372,7 @@ class VideoDataManager(DataManager):
|
|||
batch_size_test=3,
|
||||
workers=4,
|
||||
num_instances=4,
|
||||
train_sampler=None,
|
||||
train_sampler='RandomSampler',
|
||||
seq_len=15,
|
||||
sample_method='evenly'
|
||||
):
|
||||
|
|
|
@ -95,7 +95,7 @@ def build_train_sampler(data_source, train_sampler, batch_size=32, num_instances
|
|||
elif train_sampler == 'SequentialSampler':
|
||||
sampler = SequentialSampler(data_source)
|
||||
|
||||
else:
|
||||
elif train_sampler == 'RandomSampler':
|
||||
sampler = RandomSampler(data_source)
|
||||
|
||||
return sampler
|
Loading…
Reference in New Issue