mirror of https://github.com/alibaba/EasyCV.git
126 lines
4.2 KiB
Python
126 lines
4.2 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
import unittest
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from tests.ut_config import IMG_NORM_CFG, SMALL_IMAGENET_RAW_LOCAL
|
|
|
|
from easycv.datasets.builder import build_dataset
|
|
|
|
|
|
class MultiViewDatasetTest(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
|
|
|
def test_multiview_dataset(self):
|
|
data_train_list = os.path.join(SMALL_IMAGENET_RAW_LOCAL,
|
|
'meta/train_200.txt')
|
|
data_train_root = SMALL_IMAGENET_RAW_LOCAL
|
|
pipeline1 = [
|
|
dict(
|
|
type='RandomResizedCrop',
|
|
size=224,
|
|
scale=(0.4, 1.),
|
|
interpolation=Image.BICUBIC),
|
|
dict(type='RandomHorizontalFlip', p=0.5),
|
|
dict(
|
|
type='RandomAppliedTrans',
|
|
transforms=[
|
|
dict(
|
|
type='ColorJitter',
|
|
brightness=0.4,
|
|
contrast=0.4,
|
|
saturation=0.2,
|
|
hue=0.1)
|
|
],
|
|
p=0.8),
|
|
dict(type='RandomGrayscale', p=0.2),
|
|
dict(
|
|
type='RandomAppliedTrans',
|
|
transforms=[dict(type='GaussianBlur', kernel_size=23)],
|
|
p=1.0),
|
|
dict(type='ToTensor'),
|
|
dict(type='Normalize', **IMG_NORM_CFG),
|
|
dict(type='Collect', keys=['img'])
|
|
]
|
|
pipeline2 = [
|
|
dict(
|
|
type='RandomResizedCrop',
|
|
size=224,
|
|
scale=(0.4, 1.),
|
|
interpolation=Image.BICUBIC),
|
|
dict(type='RandomHorizontalFlip', p=0.5),
|
|
dict(
|
|
type='RandomAppliedTrans',
|
|
transforms=[
|
|
dict(
|
|
type='ColorJitter',
|
|
brightness=0.4,
|
|
contrast=0.4,
|
|
saturation=0.2,
|
|
hue=0.1)
|
|
],
|
|
p=0.8),
|
|
dict(type='RandomGrayscale', p=0.2),
|
|
dict(
|
|
type='RandomAppliedTrans',
|
|
transforms=[dict(type='GaussianBlur', kernel_size=23)],
|
|
p=0.1),
|
|
dict(
|
|
type='RandomAppliedTrans',
|
|
transforms=[dict(type='Solarization', threshold=130)],
|
|
p=0.2),
|
|
dict(type='ToTensor'),
|
|
dict(type='Normalize', **IMG_NORM_CFG),
|
|
dict(type='Collect', keys=['img'])
|
|
]
|
|
pipeline3 = [
|
|
dict(
|
|
type='RandomResizedCrop',
|
|
size=96,
|
|
scale=(0.05, 0.4),
|
|
interpolation=Image.BICUBIC),
|
|
dict(type='RandomHorizontalFlip', p=0.5),
|
|
dict(
|
|
type='RandomAppliedTrans',
|
|
transforms=[
|
|
dict(
|
|
type='ColorJitter',
|
|
brightness=0.4,
|
|
contrast=0.4,
|
|
saturation=0.2,
|
|
hue=0.1)
|
|
],
|
|
p=0.8),
|
|
dict(type='RandomGrayscale', p=0.2),
|
|
dict(
|
|
type='RandomAppliedTrans',
|
|
transforms=[dict(type='GaussianBlur', kernel_size=23)],
|
|
p=0.5),
|
|
dict(type='ToTensor'),
|
|
dict(type='Normalize', **IMG_NORM_CFG),
|
|
dict(type='Collect', keys=['img'])
|
|
]
|
|
|
|
train_data = dict(
|
|
type='MultiViewDataset',
|
|
data_source=dict(
|
|
type='SSLSourceImageList',
|
|
list_file=data_train_list,
|
|
root=data_train_root),
|
|
num_views=[1, 1, 8],
|
|
pipelines=[pipeline1, pipeline2, pipeline3])
|
|
|
|
dataset = build_dataset(train_data)
|
|
for _, batch in enumerate(dataset):
|
|
self.assertEqual(len(batch['img']), 10)
|
|
self.assertEqual(batch['img'][0].shape, torch.Size([3, 224, 224]))
|
|
self.assertEqual(batch['img'][-1].shape, torch.Size([3, 96, 96]))
|
|
break
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|