Add commonly used datasets

pull/2/head
yanglei 2020-07-01 16:09:06 +08:00 committed by chenkai
parent a48ffaaa0a
commit ba789da5c8
17 changed files with 1405 additions and 16 deletions

View File

@ -1,9 +1,15 @@
from .base_dataset import BaseDataset
from .builder import build_dataloader, build_dataset
from .pipelines import Compose
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cifar import CIFAR10, CIFAR100
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
RepeatDataset)
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST
from .samplers import DistributedSampler
__all__ = [
'BaseDataset', 'build_dataloader', 'build_dataset', 'Compose',
'DistributedSampler'
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'build_dataloader', 'build_dataset', 'Compose', 'DistributedSampler',
'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', 'DATASETS',
'PIPELINES'
]

View File

@ -1,14 +1,26 @@
import copy
from abc import ABCMeta, abstractmethod
import numpy as np
from torch.utils.data import Dataset
from .pipelines import Compose
class BaseDataset(Dataset, metaclass=ABCMeta):
"""Base dataset.
def __init__(self, ann_file, pipeline, data_prefix, test_mode):
Args:
data_prefix (str): the prefix of data path
pipeline (list): a list of dict, where each element represents
a operation defined in `mmcls.datasets.pipelines`
ann_file (str | None): the annotation file. When ann_file is str,
the subclass is expected to read from the ann_file. When ann_file
is None, the subclass is expected to read according to data_prefix
test_mode (bool): in train mode or test mode
"""
def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
super(BaseDataset, self).__init__()
self.ann_file = ann_file
@ -21,11 +33,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
def load_annotations(self):
pass
def prepare_train_data(self, idx):
results = copy.deepcopy(self.data_infos[idx])
return self.pipeline(results)
def prepare_test_data(self, idx):
def prepare_data(self, idx):
results = copy.deepcopy(self.data_infos[idx])
return self.pipeline(results)
@ -33,7 +41,35 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
return len(self.data_infos)
def __getitem__(self, idx):
if self.test_mode:
return self.prepare_train_data(idx)
else:
return self.prepare_test_data(idx)
return self.prepare_data(idx)
def evaluate(self, results, metric='accuracy', logger=None):
    """Evaluate the dataset.

    Args:
        results (list): Testing results of the dataset. Each element is
            indexed with ``'num_samples'`` and ``'accuracy'``; the latter
            maps a top-k name to a value exposing ``.item()``.
        metric (str | list[str]): Metrics to be evaluated. A non-str
            value must hold exactly one entry.
            Default value is `accuracy`.
        logger (logging.Logger | None | str): Logger used for printing
            related information during evaluation. Not used by this
            implementation. Default: None.

    Returns:
        dict: evaluation results, one sample-weighted average per top-k
            key.
    """
    if not isinstance(metric, str):
        assert len(metric) == 1
        metric = metric[0]
    if metric not in ('accuracy', ):
        raise KeyError(f'metric {metric} is not supported')

    eval_results = {}
    if metric == 'accuracy':
        sample_counts = []
        per_topk = {}
        for result in results:
            sample_counts.append(result['num_samples'].item())
            for topk, acc in result['accuracy'].items():
                per_topk.setdefault(topk, []).append(acc.item())
        # Every batch contributes proportionally to its number of samples.
        for topk, accs in per_topk.items():
            eval_results[topk] = np.average(accs, weights=sample_counts)
    return eval_results

View File

@ -23,12 +23,16 @@ PIPELINES = Registry('pipeline')
def build_dataset(cfg, default_args=None):
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .dataset_wrappers import (ConcatDataset, RepeatDataset,
ClassBalancedDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif cfg['type'] == 'ClassBalancedDataset':
dataset = ClassBalancedDataset(
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
@ -85,6 +89,7 @@ def build_dataloader(dataset,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=False,
shuffle=shuffle,
worker_init_fn=init_fn,
**kwargs)

View File

@ -0,0 +1,127 @@
import os
import os.path
import pickle
import numpy as np
from .base_dataset import BaseDataset
from .builder import DATASETS
from .utils import check_integrity, download_and_extract_archive
@DATASETS.register_module()
class CIFAR10(BaseDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    The archive is downloaded into ``data_prefix`` and extracted when any
    of the batch files is missing or fails its md5 check.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py  # noqa: E501
    """

    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def load_annotations(self):
        """Load the train or test batches selected by ``self.test_mode``.

        Returns:
            list[dict]: one dict per image with keys ``'img'`` (HWC array)
                and ``'gt_labels'`` (int64 scalar array).
        """
        if not self._check_integrity():
            download_and_extract_archive(
                self.url,
                self.data_prefix,
                filename=self.filename,
                md5=self.tgz_md5)

        batch_list = self.test_list if self.test_mode else self.train_list

        self.imgs = []
        self.gt_labels = []

        # load the pickled numpy arrays
        for file_name, _ in batch_list:
            file_path = os.path.join(self.data_prefix, self.base_folder,
                                     file_name)
            with open(file_path, 'rb') as f:
                # NOTE(security): unpickling can execute arbitrary code;
                # only point `data_prefix` at files from a trusted source.
                entry = pickle.load(f, encoding='latin1')
            self.imgs.append(entry['data'])
            # CIFAR-10 batches use 'labels', CIFAR-100 uses 'fine_labels'.
            if 'labels' in entry:
                self.gt_labels.extend(entry['labels'])
            else:
                self.gt_labels.extend(entry['fine_labels'])

        self.imgs = np.vstack(self.imgs).reshape(-1, 3, 32, 32)
        self.imgs = self.imgs.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

        return [{
            'img': img,
            'gt_labels': np.array(gt_label, dtype=np.int64)
        } for img, gt_label in zip(self.imgs, self.gt_labels)]

    def _load_meta(self):
        """Read the class names and build ``self.class_to_idx``.

        Raises:
            RuntimeError: if the metadata file is missing or its md5 sum
                does not match.
        """
        path = os.path.join(self.data_prefix, self.base_folder,
                            self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            # This class has no `download` option; a fresh download is
            # triggered only when one of the batch files is missing.
            raise RuntimeError(
                'Dataset metadata file not found or corrupted.'
                f' Remove the `{self.base_folder}` folder under'
                ' `data_prefix` to have the dataset extracted again')
        with open(path, 'rb') as infile:
            # NOTE(security): see the pickle remark in load_annotations.
            data = pickle.load(infile, encoding='latin1')
        self.classes = data[self.meta['key']]
        self.class_to_idx = {
            _class: i
            for i, _class in enumerate(self.classes)
        }

    def _check_integrity(self):
        """Return True when every train and test batch passes its md5."""
        root = self.data_prefix
        for filename, md5 in (self.train_list + self.test_list):
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True
@DATASETS.register_module()
class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Only the archive location, checksums and metadata key differ from
    :class:`CIFAR10`; all loading logic is inherited.
    """

    base_folder = 'cifar-100-python'
    url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
    filename = 'cifar-100-python.tar.gz'
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # Each entry is [file name, md5 sum].
    train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
    test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]
    meta = dict(
        filename='meta',
        key='fine_label_names',
        md5='7973b15100ade9c7d40fb424638fde48')

View File

@ -0,0 +1,159 @@
import bisect
import math
from collections import defaultdict
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
    add `get_cat_ids` function.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
    """

    def __init__(self, datasets):
        super().__init__(datasets)

    def get_cat_ids(self, idx):
        """Return the category ids of sample ``idx`` of the concatenation.

        Negative indices count from the end.

        Raises:
            ValueError: if a negative ``idx`` reaches past the start.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx += total
        # Locate the sub-dataset holding `idx`, then its local offset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx else 0
        return self.datasets[dataset_idx].get_cat_ids(idx - offset)
@DATASETS.register_module()
class RepeatDataset(object):
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        # Cached once; indices are wrapped modulo this length.
        self._ori_len = len(self.dataset)

    def __len__(self):
        return self.times * self._ori_len

    def __getitem__(self, idx):
        wrapped = idx % self._ori_len
        return self.dataset[wrapped]

    def get_cat_ids(self, idx):
        """Return the category ids of the wrapped sample for ``idx``."""
        wrapped = idx % self._ori_len
        return self.dataset.get_cat_ids(wrapped)
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module()
class ClassBalancedDataset(object):
    """A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in [1], in each epoch, an image may appear multiple
    times based on its "repeat factor".
    The repeat factor for an image is a function of the frequency the rarest
    category labeled in that image. The "frequency of category c" in [0, 1]
    is defined by the fraction of images in the training set (without repeats)
    in which category c appears.
    The dataset needs to implement :func:`self.get_cat_ids(idx)` to support
    ClassBalancedDataset.

    The repeat factor is computed as followed.
    1. For each category c, compute the fraction # of images
       that contain it: f(c)
    2. For each category c, compute the category-level repeat factor:
       r(c) = max(1, sqrt(t/f(c)))
    3. For each image I and its labels L(I), compute the image-level repeat
       factor:
       r(I) = max_{c in L(I)} r(c)

    References:
        .. [1] https://arxiv.org/pdf/1908.03195.pdf

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with `f_c` >= `oversample_thr`, there is
            no oversampling. For categories with `f_c` < `oversample_thr`, the
            degree of oversampling following the square-root inverse frequency
            heuristic above.
    """

    def __init__(self, dataset, oversample_thr):
        self.dataset = dataset
        self.oversample_thr = oversample_thr

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        # Fractional factors are rounded up: every image appears >= 1 time.
        repeat_counts = [math.ceil(factor) for factor in repeat_factors]

        repeat_indices = []
        for dataset_index, count in enumerate(repeat_counts):
            repeat_indices.extend([dataset_index] * count)
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, count in zip(self.dataset.flag, repeat_counts):
                flags.extend([flag] * count)
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        """Compute the image-level repeat factor of every sample.

        Args:
            dataset: object providing ``__len__`` and ``get_cat_ids(idx)``.
            repeat_thr (float): the frequency threshold ``t``.

        Returns:
            list[float]: one factor (>= 1.0) per sample of ``dataset``.
        """
        # 1. For each category c, compute the fraction # of images
        #    that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            for cat_id in set(dataset.get_cat_ids(idx)):
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            assert v > 0, f'category {k} does not contain any images'
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I and its labels L(I), compute the image-level
        #    repeat factor:
        #    r(I) = max_{c in L(I)} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(dataset.get_cat_ids(idx))
            # A sample without any category is kept once instead of
            # making max() fail on an empty collection.
            repeat_factor = max(
                (category_repeat[cat_id] for cat_id in cat_ids),
                default=1.0)
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        return len(self.repeat_indices)

View File

@ -0,0 +1,104 @@
import os
import numpy as np
from .base_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register_module()
class ImageNet(BaseDataset):
    """`ImageNet <http://www.image-net.org>`_ Dataset.

    When ``ann_file`` is None the samples are discovered from the class
    sub-folders of ``data_prefix``; otherwise every non-blank line of
    ``ann_file`` is read as ``<filename> <label>``.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py  # noqa: E501
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def load_annotations(self):
        """Build the list of sample infos.

        Returns:
            list[dict]: dicts with keys ``'img_prefix'``, ``'img_info'``
                (holding ``'filename'``) and ``'gt_labels'``.

        Raises:
            RuntimeError: if folder discovery finds no image.
            TypeError: if ``ann_file`` is neither a str nor None.
        """
        if self.ann_file is None:
            classes, class_to_idx = find_classes(self.data_prefix)
            samples = make_dataset(
                self.data_prefix, class_to_idx, extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise (RuntimeError('Found 0 files in subfolders of: '
                                    f'{self.data_prefix}. '
                                    'Supported extensions are: '
                                    f'{",".join(self.IMG_EXTENSIONS)}'))

            self.classes = classes
            self.class_to_idx = class_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                # The label is the last whitespace-separated token, so
                # blank lines, repeated separators and file names that
                # contain spaces no longer break the two-field unpacking.
                samples = [
                    line.strip().rsplit(maxsplit=1) for line in f
                    if line.strip()
                ]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for filename, gt_label in self.samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_labels'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos
def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    The comparison is done on the lower-cased file name.

    Args:
        filename (string): path to a file
        extensions (iterable of string): allowed (lower-case) suffixes

    Returns:
        bool: True if the filename ends with one of ``extensions``
    """
    lowered = filename.lower()
    for ext in extensions:
        if lowered.endswith(ext):
            return True
    return False
def find_classes(root):
    """Find classes by folders under a root.

    Every direct sub-directory of ``root`` is one class; indices follow
    the sorted order of the directory names.

    Args:
        root (string): root directory of folders

    Returns:
        classes (list): a list of class names
        class_to_idx (dict): the map from class name to class idx
    """
    classes = []
    for entry in os.listdir(root):
        if os.path.isdir(os.path.join(root, entry)):
            classes.append(entry)
    classes.sort()
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(root, class_to_idx, extensions):
    """Make dataset by walking all images under a root.

    Args:
        root (string): root directory of folders
        class_to_idx (dict): the map from class name to class idx
        extensions (tuple): allowed extensions

    Returns:
        images (list): a list of tuple where each element is (image, label).
            ``image`` is the file path relative to ``root``.
    """
    images = []
    root = os.path.expanduser(root)
    for class_name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, class_name)
        if not os.path.isdir(class_dir):
            continue

        for dirpath, _, fns in sorted(os.walk(class_dir)):
            for fn in sorted(fns):
                if has_file_allowed_extension(fn, extensions):
                    # os.walk also visits nested folders, so the relative
                    # path must keep the sub-folder part; joining only
                    # class_name and fn would point at a missing file.
                    path = os.path.relpath(os.path.join(dirpath, fn), root)
                    images.append((path, class_to_idx[class_name]))

    return images

View File

@ -0,0 +1,175 @@
import codecs
import os
import os.path as osp
import numpy as np
import torch
from .base_dataset import BaseDataset
from .builder import DATASETS
from .utils import download_and_extract_archive, rm_suffix
@DATASETS.register_module()
class MNIST(BaseDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    The four idx files are expected, decompressed, directly under
    ``data_prefix``; they are downloaded when any of them is missing.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py  # noqa: E501
    """

    resource_prefix = 'http://yann.lecun.com/exdb/mnist/'
    # role -> (archive file name, md5 sum of the archive)
    resources = {
        'train_image_file':
        ('train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
        'train_label_file':
        ('train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
        'test_image_file':
        ('t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
        'test_label_file':
        ('t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')
    }

    classes = [
        '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five',
        '6 - six', '7 - seven', '8 - eight', '9 - nine'
    ]

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.classes)}

    def load_annotations(self):
        """Load the train or test split selected by ``self.test_mode``.

        Returns:
            list[dict]: one dict per image with keys ``'img'`` (numpy
                array) and ``'gt_labels'`` (int64 scalar array).
        """
        roles = ('train_image_file', 'train_label_file', 'test_image_file',
                 'test_label_file')
        paths = {
            role: osp.join(self.data_prefix,
                           rm_suffix(self.resources[role][0]))
            for role in roles
        }

        if not all(osp.exists(path) for path in paths.values()):
            self.download()

        # Only the requested split is parsed; reading the other one would
        # be discarded work.
        split = 'test' if self.test_mode else 'train'
        imgs = read_image_file(paths[f'{split}_image_file'])
        gt_labels = read_label_file(paths[f'{split}_label_file'])

        data_infos = []
        for img, gt_label in zip(imgs, gt_labels):
            gt_label = np.array(gt_label, dtype=np.int64)
            info = {'img': img.numpy(), 'gt_labels': gt_label}
            data_infos.append(info)
        return data_infos

    def download(self):
        """Download and extract every resource into ``data_prefix``."""
        os.makedirs(self.data_prefix, exist_ok=True)

        # download files
        for filename, md5 in self.resources.values():
            # URLs always use '/', so they must not be built with the
            # OS-dependent path join.
            url = f"{self.resource_prefix.rstrip('/')}/{filename}"
            download_and_extract_archive(
                url,
                download_root=self.data_prefix,
                filename=filename,
                md5=md5)
@DATASETS.register_module()
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
    Dataset.

    Overrides only the download location, checksums and class names of
    :class:`MNIST`.
    """

    resource_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'  # noqa: E501
    # role -> (archive file name, md5 sum of the archive)
    resources = dict(
        train_image_file=('train-images-idx3-ubyte.gz',
                          '8d4fb7e6c68d591d4c3dfef9ec88bf0d'),
        train_label_file=('train-labels-idx1-ubyte.gz',
                          '25c81989df183df01b3e8a0aad5dffbe'),
        test_image_file=('t10k-images-idx3-ubyte.gz',
                         'bef4ecab320f06d8554ea6380940ec79'),
        test_label_file=('t10k-labels-idx1-ubyte.gz',
                         'bb300cfdad3c16e7a12a480ee83cd310'))
    classes = [
        'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
        'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
    ]
def get_int(b):
    """Interpret the bytes ``b`` as one big-endian unsigned integer."""
    hex_digits = b.hex()
    return int(hex_digits, 16)
def open_maybe_compressed_file(path):
    """Return a file object that possibly decompresses 'path' on the fly.

    Decompression occurs when argument `path` is a string
    and ends with '.gz' or '.xz'.

    Args:
        path (str | file object): file name, or an already opened file.

    Returns:
        file object: ``path`` itself when it is not a string, otherwise a
            binary-mode reader for the named file.
    """
    # Plain `str` replaces the private `torch._six.string_classes`, which
    # is absent from recent PyTorch releases.
    if not isinstance(path, str):
        return path
    if path.endswith('.gz'):
        import gzip
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        import lzma
        return lzma.open(path, 'rb')
    return open(path, 'rb')
def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" format
    (Lush file 'libidx/idx-io.lsh').

    Argument may be a filename, compressed filename, or file object.

    Args:
        path (str | file object): source of the data.
        strict (bool): when True, assert that the number of parsed
            elements equals the product of the header dimensions.

    Returns:
        torch.Tensor: the data viewed with the header dimensions.
    """
    reader = read_sn3_pascalvincent_tensor
    # The type table is built once and cached on the function object.
    # Entries: type code -> (torch dtype, on-disk dtype, in-memory dtype).
    if not hasattr(reader, 'typemap'):
        reader.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')
        }

    with open_maybe_compressed_file(path) as stream:
        raw = stream.read()

    # Header: a 4-byte magic number (low byte = number of dimensions,
    # remaining value = type code) followed by one 4-byte size per dim.
    magic = get_int(raw[0:4])
    nd, ty = magic % 256, magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    _, disk_dtype, native_dtype = reader.typemap[ty]
    shape = [get_int(raw[4 * (i + 1):4 * (i + 2)]) for i in range(nd)]

    parsed = np.frombuffer(raw, dtype=disk_dtype, offset=4 * (nd + 1))
    assert parsed.shape[0] == np.prod(shape) or not strict
    native = parsed.astype(native_dtype, copy=False)
    return torch.from_numpy(native).view(*shape)
def read_label_file(path):
    """Read an idx label file into a 1-D ``torch.long`` tensor."""
    with open(path, 'rb') as stream:
        labels = read_sn3_pascalvincent_tensor(stream, strict=False)
    # Labels are stored as one unsigned byte each.
    assert labels.dtype == torch.uint8
    assert labels.ndimension() == 1
    return labels.long()
def read_image_file(path):
    """Read an idx image file into a 3-D ``torch.uint8`` tensor."""
    with open(path, 'rb') as stream:
        images = read_sn3_pascalvincent_tensor(stream, strict=False)
    assert images.dtype == torch.uint8
    assert images.ndimension() == 3
    return images

View File

@ -1,3 +1,13 @@
from .compose import Compose
from .formating import (Collect, ImageToTensor, ToDataContainer, ToNumpy,
ToPIL, ToTensor, Transpose, to_tensor)
from .loading import LoadImageFromFile
from .transforms import (CenterCrop, Normalize, RandomHorizontalFlip,
RandomResizedCrop, Resize)
__all__ = ['Compose']
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'LoadImageFromFile',
'RandomResizedCrop', 'RandomHorizontalFlip', 'Resize', 'CenterCrop',
'Normalize'
]

View File

@ -0,0 +1,198 @@
from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from PIL import Image
from ..builder import PIPELINES
def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data: the object to convert. A tensor is returned unchanged; an
            ``int``/``float`` becomes a one-element tensor.

    Returns:
        torch.Tensor: the converted data.

    Raises:
        TypeError: if ``data`` (including any ``str``) is unsupported.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # A str is a Sequence as well but is not convertible.
    elif isinstance(data, Sequence) and not isinstance(data, str):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(
            f'Type {type(data)} cannot be converted to tensor. '
            'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
            '`Sequence`, `int` and `float`')
@PIPELINES.register_module()
class ToTensor(object):
    """Replace ``results[key]`` by ``to_tensor(results[key])`` per key.

    Args:
        keys (sequence of str): entries of ``results`` to convert.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for name in self.keys:
            converted = to_tensor(results[name])
            results[name] = converted
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class ImageToTensor(object):
    """Convert images to tensors with the last axis moved to the front.

    An image with fewer than three dimensions first gets a trailing axis
    of size one, then the axes are permuted with ``transpose(2, 0, 1)``.

    Args:
        keys (sequence of str): entries of ``results`` holding images.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        for name in self.keys:
            image = results[name]
            if len(image.shape) < 3:
                image = np.expand_dims(image, -1)
            channel_first = image.transpose(2, 0, 1)
            results[name] = to_tensor(channel_first)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class Transpose(object):
    """Call ``.transpose(order)`` on the selected entries of ``results``.

    Args:
        keys (sequence of str): entries of ``results`` to transpose.
        order: argument forwarded to ``transpose``.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        for name in self.keys:
            results[name] = results[name].transpose(self.order)
        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(keys={self.keys}, order={self.order})'
@PIPELINES.register_module()
class ToDataContainer(object):
    """Wrap selected entries of ``results`` in a ``DataContainer``.

    Args:
        fields (sequence of dict): each dict names the entry under
            ``'key'``; its remaining items are passed to the
            ``DataContainer`` constructor as keyword arguments.
    """

    def __init__(self,
                 fields=(dict(key='img', stack=True), dict(key='gt_labels'))):
        self.fields = fields

    def __call__(self, results):
        for spec in self.fields:
            # Work on a copy so that popping 'key' leaves the configured
            # field description untouched for the next call.
            options = spec.copy()
            name = options.pop('key')
            results[name] = DC(results[name], **options)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(fields={self.fields})'
@PIPELINES.register_module()
class ToPIL(object):
    """Replace ``results['img']`` by ``Image.fromarray(results['img'])``."""

    def __init__(self):
        pass

    def __call__(self, results):
        pil_img = Image.fromarray(results['img'])
        results['img'] = pil_img
        return results
@PIPELINES.register_module()
class ToNumpy(object):
    """Replace ``results['img']`` by a float32 numpy array of it."""

    def __init__(self):
        pass

    def __call__(self, results):
        as_array = np.array(results['img'], dtype=np.float32)
        results['img'] = as_array
        return results
@PIPELINES.register_module()
class Collect(object):
    """
    Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img" and "gt_labels".

    The returned dict contains exactly the entries listed in ``keys``.
    The entries listed in ``meta_keys`` are looked up as well, so they
    must be present in ``results``, but the gathered meta dict is
    currently NOT attached to the returned data.

    Args:
        keys (sequence of str): entries copied to the output.
        meta_keys (sequence of str): meta entries read from ``results``.
            Defaults to ``('filename', 'ori_shape', 'img_shape')``.
    """

    def __init__(self, keys, meta_keys=('filename', 'ori_shape', 'img_shape')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        # NOTE: the meta dict is built (a missing meta key raises
        # KeyError) but is deliberately left out of the returned data.
        img_meta = {name: results[name] for name in self.meta_keys}
        return {name: results[name] for name in self.keys}

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(keys={self.keys}, meta_keys={self.meta_keys})'
@PIPELINES.register_module()
class WrapFieldsToLists(object):
    """Wrap fields of the data dictionary into lists for evaluation.

    This class can be used as a last step of a test or validation
    pipeline for single image evaluation or inference.

    Example:
        >>> test_pipeline = [
        >>>    dict(type='LoadImageFromFile'),
        >>>    dict(type='Normalize',
                    mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    to_rgb=True),
        >>>    dict(type='ImageToTensor', keys=['img']),
        >>>    dict(type='Collect', keys=['img']),
        >>>    dict(type='WrapFieldsToLists')
        >>> ]
    """

    def __call__(self, results):
        # Every value becomes a one-element list; the dict is updated in
        # place and returned.
        for name in list(results):
            results[name] = [results[name]]
        return results

    def __repr__(self):
        return self.__class__.__name__ + '()'

View File

@ -0,0 +1,68 @@
import os.path as osp
import mmcv
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes()`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk')):
        self.to_float32 = to_float32
        self.color_type = color_type
        # Copied so that the shared default dict is never modified.
        self.file_client_args = file_client_args.copy()
        # The client is created lazily on the first call.
        self.file_client = None

    def __call__(self, results):
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        prefix = results['img_prefix']
        filename = results['img_info']['filename']
        if prefix is not None:
            filename = osp.join(prefix, filename)

        img = mmcv.imfrombytes(
            self.file_client.get(filename), flag=self.color_type)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # An image without a channel axis counts as single-channel.
        num_channels = img.shape[2] if len(img.shape) >= 3 else 1
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        parts = [
            f'to_float32={self.to_float32}',
            f"color_type='{self.color_type}'",
            f'file_client_args={self.file_client_args}',
        ]
        return f'{self.__class__.__name__}(' + ', '.join(parts) + ')'

View File

@ -0,0 +1,115 @@
import mmcv
import numpy as np
from torchvision import transforms
from ..builder import PIPELINES
@PIPELINES.register_module()
class RandomCrop(transforms.RandomCrop):
    """Apply ``torchvision.transforms.RandomCrop`` to ``results['img']``.

    Constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, results):
        cropped = super().__call__(results['img'])
        results['img'] = cropped
        return results
@PIPELINES.register_module()
class RandomResizedCrop(transforms.RandomResizedCrop):
    """Apply ``torchvision.transforms.RandomResizedCrop`` to
    ``results['img']``.

    Constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, results):
        cropped = super().__call__(results['img'])
        results['img'] = cropped
        return results
@PIPELINES.register_module()
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    """Apply ``torchvision.transforms.RandomHorizontalFlip`` to
    ``results['img']``.

    Constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, results):
        flipped = super().__call__(results['img'])
        results['img'] = flipped
        return results
@PIPELINES.register_module()
class Resize(transforms.Resize):
    """Apply ``torchvision.transforms.Resize`` to ``results['img']``.

    Constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, results):
        resized = super().__call__(results['img'])
        results['img'] = resized
        return results
@PIPELINES.register_module()
class CenterCrop(transforms.CenterCrop):
    """Apply ``torchvision.transforms.CenterCrop`` to ``results['img']``.

    Constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, results):
        cropped = super().__call__(results['img'])
        results['img'] = cropped
        return results
@PIPELINES.register_module()
class ColorJitter(transforms.ColorJitter):
    """Apply ``torchvision.transforms.ColorJitter`` to ``results['img']``.

    Constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, results):
        jittered = super().__call__(results['img'])
        results['img'] = jittered
        return results
@PIPELINES.register_module()
class Normalize(object):
    """Normalize the image.

    Every entry named in ``results['img_fields']`` (only ``'img'`` when
    that key is absent) is passed through ``mmcv.imnormalize``; the
    parameters used are recorded under ``results['img_norm_cfg']``.

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        img_fields = results.get('img_fields', ['img'])
        for name in img_fields:
            results[name] = mmcv.imnormalize(results[name], self.mean,
                                             self.std, self.to_rgb)
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        settings = f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
        return self.__class__.__name__ + settings

View File

@ -0,0 +1,163 @@
import gzip
import hashlib
import os
import os.path
import tarfile
import zipfile
from tqdm import tqdm
__all__ = ['rm_suffix', 'check_integrity', 'download_and_extract_archive']
def rm_suffix(s, suffix=None):
    """Cut ``s`` at the last occurrence of ``suffix``.

    Args:
        s (str): the string to shorten, e.g. a file name.
        suffix (str | None): marker at which to cut. When None, the
            string is cut at its last ``'.'`` (the extension is removed).

    Returns:
        str: the part of ``s`` before the marker, or ``s`` unchanged when
            the marker does not occur.
    """
    cut = s.rfind('.' if suffix is None else suffix)
    # rfind() yields -1 when nothing is found; slicing with -1 would
    # silently drop the last character.
    return s if cut == -1 else s[:cut]
def gen_bar_updater():
    """Create a tqdm progress bar and return a callback that advances it.

    Returns:
        callable: ``bar_update(count, block_size, total_size)``; it sets
            the bar total from the first truthy ``total_size`` and moves
            the bar to ``count * block_size``.
    """
    progress = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if progress.total is None and total_size:
            progress.total = total_size
        # update() expects an increment, hence the difference with the
        # current position.
        progress.update(count * block_size - progress.n)

    return bar_update
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Return the hexadecimal md5 digest of the file ``fpath``.

    The file is read in pieces of ``chunk_size`` bytes.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if block == b'':
                break
            digest.update(block)
    return digest.hexdigest()
def check_md5(fpath, md5, **kwargs):
    """Tell whether the md5 digest of ``fpath`` equals ``md5``.

    Extra keyword arguments are forwarded to ``calculate_md5``.
    """
    actual = calculate_md5(fpath, **kwargs)
    return md5 == actual
def check_integrity(fpath, md5=None):
    """Tell whether ``fpath`` is an existing file matching ``md5``.

    Args:
        fpath (str): path of the file to check.
        md5 (str | None): expected digest; None skips the digest check.

    Returns:
        bool: False when ``fpath`` is not a file or its digest differs.
    """
    if os.path.isfile(fpath):
        return True if md5 is None else check_md5(fpath, md5)
    return False
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from.
        root (str): Directory to place downloaded file in.
        filename (str | None): Name to save the file under.
            If filename is None, use the basename of the URL.
        md5 (str | None): MD5 checksum of the download.
            If md5 is None, download without md5 check.

    Raises:
        RuntimeError: if the file is missing or fails the md5 check after
            the download.
    """
    # A bare `import urllib` does not load the submodules, so
    # `urllib.request` / `urllib.error` could raise AttributeError.
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print(f'Using downloaded and verified file: {fpath}')
    else:  # download the file
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath, reporthook=gen_bar_updater())
        except (urllib.error.URLError, IOError):
            if url[:5] != 'https':
                raise
            # NOTE(security): the retry drops TLS; the md5 check below
            # is the only protection for the downloaded content.
            url = url.replace('https:', 'http:')
            print('Failed download. Trying https -> http instead.'
                  f' Downloading {url} to {fpath}')
            urllib.request.urlretrieve(
                url, fpath, reporthook=gen_bar_updater())
        # check integrity of downloaded file
        if not check_integrity(fpath, md5):
            raise RuntimeError('File not found or corrupted.')
def _is_tarxz(filename):
return filename.endswith('.tar.xz')
def _is_tar(filename):
return filename.endswith('.tar')
def _is_targz(filename):
return filename.endswith('.tar.gz')
def _is_tgz(filename):
return filename.endswith('.tgz')
def _is_gzip(filename):
return filename.endswith('.gz') and not filename.endswith('.tar.gz')
def _is_zip(filename):
return filename.endswith('.zip')
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract an archive, picking the method from the file name suffix.

    Supported suffixes are ``.tar``, ``.tar.gz``, ``.tgz``, ``.tar.xz``,
    ``.gz`` and ``.zip``.

    Args:
        from_path (str): Path of the archive to extract.
        to_path (str | None): Directory to extract into. If to_path is None,
            the directory containing ``from_path`` is used.
        remove_finished (bool): If True, delete ``from_path`` after a
            successful extraction.

    Raises:
        ValueError: If the suffix of ``from_path`` is not supported.
    """
    import shutil

    if to_path is None:
        to_path = os.path.dirname(from_path)

    # Map the suffix to the mode string handed to ``tarfile.open``.
    if _is_tar(from_path):
        tar_mode = 'r'
    elif _is_targz(from_path) or _is_tgz(from_path):
        tar_mode = 'r:gz'
    elif _is_tarxz(from_path):
        tar_mode = 'r:xz'
    else:
        tar_mode = None

    # NOTE(review): ``extractall`` writes wherever the member names point,
    # so tar/zip archives should only come from trusted sources.
    if tar_mode is not None:
        with tarfile.open(from_path, tar_mode) as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        # A plain .gz holds a single file, written under the archive's
        # base name without its last extension.
        out_name = os.path.splitext(os.path.basename(from_path))[0]
        out_path = os.path.join(to_path, out_name)
        # tar/zip extraction creates the target directory; do the same here.
        # An empty ``to_path`` means the current directory.
        if to_path:
            os.makedirs(to_path, exist_ok=True)
        # Stream the payload instead of holding it all in memory.
        with gzip.GzipFile(from_path) as zip_f, \
                open(out_path, 'wb') as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)
def download_and_extract_archive(url,
                                 download_root,
                                 extract_root=None,
                                 filename=None,
                                 md5=None,
                                 remove_finished=False):
    """Download an archive and extract it.

    Args:
        url (str): URL to download the archive from.
        download_root (str): Directory to place the downloaded archive in.
        extract_root (str | None): Directory to extract into. If
            extract_root is None, ``download_root`` is used.
        filename (str | None): Name to save the archive under. If filename
            is falsy, the basename of the URL is used.
        md5 (str | None): MD5 checksum passed on to ``download_url``.
        remove_finished (bool): Passed on to ``extract_archive``.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)
    archive = os.path.join(download_root, filename)

    download_url(url, download_root, filename, md5)

    print(f'Extracting {archive} to {extract_root}')
    extract_archive(archive, extract_root, remove_finished)

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

BIN
tests/data/gray.jpg 100644

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

View File

@ -0,0 +1,122 @@
import bisect
import math
import random
import string
import tempfile
from collections import defaultdict
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from mmcls.datasets import (DATASETS, BaseDataset, ClassBalancedDataset,
ConcatDataset, RepeatDataset)
from mmcls.datasets.utils import (check_integrity,
download_and_extract_archive, rm_suffix)
@pytest.mark.parametrize(
    'dataset_name',
    ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet'])
def test_datasets_override_default(dataset_name):
    """Check the constructor defaults of each registered dataset class."""
    dataset_class = DATASETS.get(dataset_name)
    # Replace ``load_annotations`` with a mock only while the dataset is
    # built. A plain attribute assignment on the class would stay in place
    # for every test that runs afterwards.
    with patch.object(dataset_class, 'load_annotations', MagicMock()):
        # Test default behavior
        dataset = dataset_class(data_prefix='', pipeline=[])

    assert dataset.data_prefix == ''
    assert not dataset.test_mode
    assert dataset.ann_file is None
@patch.multiple(
    BaseDataset,
    __abstractmethods__=set(),
    __getitem__=MagicMock(side_effect=lambda idx: idx))
def test_dataset_wrapper():
    """Exercise ConcatDataset, RepeatDataset and ClassBalancedDataset.

    ``BaseDataset`` is patched for the duration of this test only: its
    abstract-method set is emptied so it can be instantiated, and
    ``__getitem__`` is replaced by a mock that echoes the index. Both
    patches are undone by ``patch.multiple`` when the test returns, so the
    mocked ``__getitem__`` does not leak into other tests.
    """
    dataset_a = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    len_a = 10
    # One random list of category ids per sample.
    cat_ids_list_a = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, len_a)
    ]
    dataset_a.data_infos = MagicMock()
    dataset_a.data_infos.__len__.return_value = len_a
    dataset_a.get_cat_ids = MagicMock(
        side_effect=lambda idx: cat_ids_list_a[idx])

    dataset_b = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    len_b = 20
    cat_ids_list_b = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, len_b)
    ]
    dataset_b.data_infos = MagicMock()
    dataset_b.data_infos.__len__.return_value = len_b
    dataset_b.get_cat_ids = MagicMock(
        side_effect=lambda idx: cat_ids_list_b[idx])

    # ConcatDataset: indices past len_a are expected to map into dataset_b.
    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    # RepeatDataset: indices are expected to wrap around modulo len_a.
    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
    assert len(repeat_dataset) == 10 * len(dataset_a)

    # ClassBalancedDataset: compute the expected per-sample repeat factors
    # here and compare them with what the wrapper reports.
    # Fraction of samples in which each category occurs at least once.
    category_freq = defaultdict(int)
    for cat_ids in cat_ids_list_a:
        cat_ids = set(cat_ids)
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    for k, v in category_freq.items():
        category_freq[k] = v / len(cat_ids_list_a)

    mean_freq = np.mean(list(category_freq.values()))
    repeat_thr = mean_freq

    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    # A sample's factor is the largest factor among its categories,
    # rounded up to an integer.
    repeat_factors = []
    for cat_ids in cat_ids_list_a:
        cat_ids = set(cat_ids)
        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
        repeat_factors.append(math.ceil(repeat_factor))
    repeat_factors_cumsum = np.cumsum(repeat_factors)

    repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    # The cumulative sum maps an index in the repeated dataset back to the
    # original sample index.
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)
def test_dataset_utils():
    """Test rm_suffix, check_integrity and download_and_extract_archive."""
    # test rm_suffix
    assert rm_suffix('a.jpg') == 'a'
    assert rm_suffix('a.bak.jpg') == 'a.bak'
    assert rm_suffix('a.bak.jpg', suffix='.jpg') == 'a.bak'
    assert rm_suffix('a.bak.jpg', suffix='.bak.jpg') == 'a'

    # test check_integrity
    rand_file = ''.join(random.sample(string.ascii_letters, 10))
    assert not check_integrity(rand_file, md5=None)
    assert not check_integrity(rand_file, md5=2333)
    # Context managers make sure the temporary file and directory are
    # removed even when an assertion fails.
    with tempfile.NamedTemporaryFile() as tmp_file:
        assert check_integrity(tmp_file.name, md5=None)
        assert not check_integrity(tmp_file.name, md5=2333)

    # test download_and_extract_archive
    url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
    md5 = 'd53e105ee54ea40749a09fcbcd1e9432'
    with tempfile.TemporaryDirectory() as tmp_dir:
        download_and_extract_archive(
            url, download_root=tmp_dir, md5=md5, remove_finished=True)

View File

@ -0,0 +1,57 @@
import copy
import os.path as osp
import numpy as np
from mmcls.datasets.pipelines import LoadImageFromFile
class TestLoading(object):
    """Tests for the ``LoadImageFromFile`` pipeline step."""

    @classmethod
    def setup_class(cls):
        # Resolve ``../data`` relative to the directory holding this file.
        here = osp.dirname(__file__)
        cls.data_prefix = osp.join(here, '../data')

    def test_load_img(self):
        color_shape = (288, 512, 3)

        # color image located through img_prefix
        color_input = {
            'img_prefix': self.data_prefix,
            'img_info': {'filename': 'color.jpg'},
        }
        loader = LoadImageFromFile()
        color_out = loader(copy.deepcopy(color_input))
        assert color_out['filename'] == osp.join(self.data_prefix,
                                                 'color.jpg')
        assert color_out['img'].shape == color_shape
        assert color_out['img'].dtype == np.uint8
        assert color_out['img_shape'] == color_shape
        assert color_out['ori_shape'] == color_shape
        np.testing.assert_equal(color_out['img_norm_cfg']['mean'],
                                np.zeros(3, dtype=np.float32))
        expected_repr = (
            loader.__class__.__name__ +
            "(to_float32=False, color_type='color', " +
            "file_client_args={'backend': 'disk'})")
        assert repr(loader) == expected_repr

        # no img_prefix
        noprefix_input = {
            'img_prefix': None,
            'img_info': {'filename': 'tests/data/color.jpg'},
        }
        loader = LoadImageFromFile()
        noprefix_out = loader(copy.deepcopy(noprefix_input))
        assert noprefix_out['filename'] == 'tests/data/color.jpg'
        assert noprefix_out['img'].shape == color_shape

        # to_float32 (fed with the output of the previous load)
        float_loader = LoadImageFromFile(to_float32=True)
        float_out = float_loader(copy.deepcopy(noprefix_out))
        assert float_out['img'].dtype == np.float32

        # gray image
        gray_input = {
            'img_prefix': self.data_prefix,
            'img_info': {'filename': 'gray.jpg'},
        }
        loader = LoadImageFromFile()
        gray_out = loader(copy.deepcopy(gray_input))
        assert gray_out['img'].shape == color_shape
        assert gray_out['img'].dtype == np.uint8

        # same gray image with color_type='unchanged' (fed with the output
        # of the previous load)
        unchanged_loader = LoadImageFromFile(color_type='unchanged')
        unchanged_out = unchanged_loader(copy.deepcopy(gray_out))
        assert unchanged_out['img'].shape == (288, 512)
        assert unchanged_out['img'].dtype == np.uint8
        np.testing.assert_equal(unchanged_out['img_norm_cfg']['mean'],
                                np.zeros(1, dtype=np.float32))

View File

@ -0,0 +1,44 @@
import copy
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv.utils import build_from_cfg
from torchvision import transforms
from mmcls.datasets.builder import PIPELINES
def test_normalize():
    """Check ``Normalize`` against a manual computation and torchvision."""
    img_norm_cfg = {
        'mean': [123.675, 116.28, 103.53],
        'std': [58.395, 57.12, 57.375],
        'to_rgb': True,
    }
    normalize = build_from_cfg({'type': 'Normalize', **img_norm_cfg},
                               PIPELINES)

    img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
    img = mmcv.imread(img_path, 'color')
    # Keep an untouched copy to build the reference results from.
    original_img = copy.deepcopy(img)
    results = {
        'img': img,
        'img2': copy.deepcopy(img),
        'img_shape': img.shape,
        'ori_shape': img.shape,
        'img_fields': ['img', 'img2'],
    }

    norm_results = normalize(results)
    assert np.equal(norm_results['img'], norm_results['img2']).all()

    # manual reference: reverse the last axis, then standardize
    mean = np.array(img_norm_cfg['mean'])
    std = np.array(img_norm_cfg['std'])
    reversed_img = original_img[..., ::-1]
    expected = (reversed_img - mean) / std
    assert np.allclose(norm_results['img'], expected)

    # compare results with torchvision
    tv_normalize = transforms.Normalize(mean=mean, std=std)
    chw_tensor = torch.Tensor(reversed_img.copy().transpose(2, 0, 1))
    tv_out = np.array(tv_normalize(chw_tensor)).transpose(1, 2, 0)
    assert np.equal(norm_results['img'], tv_out).all()