Add commonly used datasets
parent
a48ffaaa0a
commit
ba789da5c8
|
@ -1,9 +1,15 @@
|
|||
from .base_dataset import BaseDataset
|
||||
from .builder import build_dataloader, build_dataset
|
||||
from .pipelines import Compose
|
||||
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
|
||||
from .cifar import CIFAR10, CIFAR100
|
||||
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
|
||||
RepeatDataset)
|
||||
from .imagenet import ImageNet
|
||||
from .mnist import MNIST, FashionMNIST
|
||||
from .samplers import DistributedSampler
|
||||
|
||||
__all__ = [
|
||||
'BaseDataset', 'build_dataloader', 'build_dataset', 'Compose',
|
||||
'DistributedSampler'
|
||||
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
|
||||
'build_dataloader', 'build_dataset', 'Compose', 'DistributedSampler',
|
||||
'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', 'DATASETS',
|
||||
'PIPELINES'
|
||||
]
|
||||
|
|
|
@ -1,14 +1,26 @@
|
|||
import copy
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
class BaseDataset(Dataset, metaclass=ABCMeta):
|
||||
"""Base dataset.
|
||||
|
||||
def __init__(self, ann_file, pipeline, data_prefix, test_mode):
|
||||
Args:
|
||||
data_prefix (str): the prefix of data path
|
||||
pipeline (list): a list of dict, where each element represents
|
||||
a operation defined in `mmcls.datasets.pipelines`
|
||||
ann_file (str | None): the annotation file. When ann_file is str,
|
||||
the subclass is expected to read from the ann_file. When ann_file
|
||||
is None, the subclass is expected to read according to data_prefix
|
||||
test_mode (bool): in train mode or test mode
|
||||
"""
|
||||
|
||||
def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
|
||||
super(BaseDataset, self).__init__()
|
||||
|
||||
self.ann_file = ann_file
|
||||
|
@ -21,11 +33,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
def load_annotations(self):
|
||||
pass
|
||||
|
||||
def prepare_train_data(self, idx):
|
||||
results = copy.deepcopy(self.data_infos[idx])
|
||||
return self.pipeline(results)
|
||||
|
||||
def prepare_test_data(self, idx):
|
||||
def prepare_data(self, idx):
|
||||
results = copy.deepcopy(self.data_infos[idx])
|
||||
return self.pipeline(results)
|
||||
|
||||
|
@ -33,7 +41,35 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
return len(self.data_infos)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.test_mode:
|
||||
return self.prepare_train_data(idx)
|
||||
else:
|
||||
return self.prepare_test_data(idx)
|
||||
return self.prepare_data(idx)
|
||||
|
||||
def evaluate(self, results, metric='accuracy', logger=None):
|
||||
"""Evaluate the dataset.
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
metric (str | list[str]): Metrics to be evaluated.
|
||||
Default value is `accuracy`.
|
||||
logger (logging.Logger | None | str): Logger used for printing
|
||||
related information during evaluation. Default: None.
|
||||
Returns:
|
||||
dict: evaluation results
|
||||
"""
|
||||
if not isinstance(metric, str):
|
||||
assert len(metric) == 1
|
||||
metric = metric[0]
|
||||
allowed_metrics = ['accuracy']
|
||||
if metric not in allowed_metrics:
|
||||
raise KeyError(f'metric {metric} is not supported')
|
||||
eval_results = {}
|
||||
if metric == 'accuracy':
|
||||
nums = []
|
||||
for result in results:
|
||||
nums.append(result['num_samples'].item())
|
||||
for topk, v in result['accuracy'].items():
|
||||
if topk not in eval_results:
|
||||
eval_results[topk] = []
|
||||
eval_results[topk].append(v.item())
|
||||
for topk, accs in eval_results.items():
|
||||
eval_results[topk] = np.average(accs, weights=nums)
|
||||
return eval_results
|
||||
|
|
|
@ -23,12 +23,16 @@ PIPELINES = Registry('pipeline')
|
|||
|
||||
|
||||
def build_dataset(cfg, default_args=None):
|
||||
from .dataset_wrappers import ConcatDataset, RepeatDataset
|
||||
from .dataset_wrappers import (ConcatDataset, RepeatDataset,
|
||||
ClassBalancedDataset)
|
||||
if isinstance(cfg, (list, tuple)):
|
||||
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
|
||||
elif cfg['type'] == 'RepeatDataset':
|
||||
dataset = RepeatDataset(
|
||||
build_dataset(cfg['dataset'], default_args), cfg['times'])
|
||||
elif cfg['type'] == 'ClassBalancedDataset':
|
||||
dataset = ClassBalancedDataset(
|
||||
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
|
||||
else:
|
||||
dataset = build_from_cfg(cfg, DATASETS, default_args)
|
||||
|
||||
|
@ -85,6 +89,7 @@ def build_dataloader(dataset,
|
|||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
||||
pin_memory=False,
|
||||
shuffle=shuffle,
|
||||
worker_init_fn=init_fn,
|
||||
**kwargs)
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
import os
|
||||
import os.path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
from .utils import check_integrity, download_and_extract_archive
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class CIFAR10(BaseDataset):
|
||||
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
||||
This implementation is modified from
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py # noqa: E501
|
||||
"""
|
||||
|
||||
base_folder = 'cifar-10-batches-py'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
||||
filename = "cifar-10-python.tar.gz"
|
||||
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
|
||||
train_list = [
|
||||
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
|
||||
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
|
||||
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
|
||||
['data_batch_4', '634d18415352ddfa80567beed471001a'],
|
||||
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
|
||||
]
|
||||
|
||||
test_list = [
|
||||
['test_batch', '40351d587109b95175f43aff81a1287e'],
|
||||
]
|
||||
meta = {
|
||||
'filename': 'batches.meta',
|
||||
'key': 'label_names',
|
||||
'md5': '5ff9c542aee3614f3951f8cda6e48888',
|
||||
}
|
||||
|
||||
def load_annotations(self):
|
||||
|
||||
if not self._check_integrity():
|
||||
download_and_extract_archive(
|
||||
self.url,
|
||||
self.data_prefix,
|
||||
filename=self.filename,
|
||||
md5=self.tgz_md5)
|
||||
|
||||
if not self.test_mode:
|
||||
downloaded_list = self.train_list
|
||||
else:
|
||||
downloaded_list = self.test_list
|
||||
|
||||
self.imgs = []
|
||||
self.gt_labels = []
|
||||
|
||||
# load the picked numpy arrays
|
||||
for file_name, checksum in downloaded_list:
|
||||
file_path = os.path.join(self.data_prefix, self.base_folder,
|
||||
file_name)
|
||||
with open(file_path, 'rb') as f:
|
||||
entry = pickle.load(f, encoding='latin1')
|
||||
self.imgs.append(entry['data'])
|
||||
if 'labels' in entry:
|
||||
self.gt_labels.extend(entry['labels'])
|
||||
else:
|
||||
self.gt_labels.extend(entry['fine_labels'])
|
||||
|
||||
self.imgs = np.vstack(self.imgs).reshape(-1, 3, 32, 32)
|
||||
self.imgs = self.imgs.transpose((0, 2, 3, 1)) # convert to HWC
|
||||
|
||||
self._load_meta()
|
||||
|
||||
data_infos = []
|
||||
for img, gt_label in zip(self.imgs, self.gt_labels):
|
||||
gt_label = np.array(gt_label, dtype=np.int64)
|
||||
info = {'img': img, 'gt_labels': gt_label}
|
||||
data_infos.append(info)
|
||||
return data_infos
|
||||
|
||||
def _load_meta(self):
|
||||
path = os.path.join(self.data_prefix, self.base_folder,
|
||||
self.meta['filename'])
|
||||
if not check_integrity(path, self.meta['md5']):
|
||||
raise RuntimeError(
|
||||
'Dataset metadata file not found or corrupted.' +
|
||||
' You can use download=True to download it')
|
||||
with open(path, 'rb') as infile:
|
||||
data = pickle.load(infile, encoding='latin1')
|
||||
self.classes = data[self.meta['key']]
|
||||
self.class_to_idx = {
|
||||
_class: i
|
||||
for i, _class in enumerate(self.classes)
|
||||
}
|
||||
|
||||
def _check_integrity(self):
|
||||
root = self.data_prefix
|
||||
for fentry in (self.train_list + self.test_list):
|
||||
filename, md5 = fentry[0], fentry[1]
|
||||
fpath = os.path.join(root, self.base_folder, filename)
|
||||
if not check_integrity(fpath, md5):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class CIFAR100(CIFAR10):
|
||||
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
"""
|
||||
|
||||
base_folder = 'cifar-100-python'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
||||
filename = "cifar-100-python.tar.gz"
|
||||
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
||||
train_list = [
|
||||
['train', '16019d7e3df5f24257cddd939b257f8d'],
|
||||
]
|
||||
|
||||
test_list = [
|
||||
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
|
||||
]
|
||||
meta = {
|
||||
'filename': 'meta',
|
||||
'key': 'fine_label_names',
|
||||
'md5': '7973b15100ade9c7d40fb424638fde48',
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
import bisect
|
||||
import math
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
||||
|
||||
from .builder import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ConcatDataset(_ConcatDataset):
|
||||
"""A wrapper of concatenated dataset.
|
||||
|
||||
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
|
||||
add `get_cat_ids` function.
|
||||
|
||||
Args:
|
||||
datasets (list[:obj:`Dataset`]): A list of datasets.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets):
|
||||
super(ConcatDataset, self).__init__(datasets)
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
if idx < 0:
|
||||
if -idx > len(self):
|
||||
raise ValueError(
|
||||
'absolute value of index should not exceed dataset length')
|
||||
idx = len(self) + idx
|
||||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
||||
if dataset_idx == 0:
|
||||
sample_idx = idx
|
||||
else:
|
||||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
||||
return self.datasets[dataset_idx].get_cat_ids(sample_idx)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class RepeatDataset(object):
|
||||
"""A wrapper of repeated dataset.
|
||||
|
||||
The length of repeated dataset will be `times` larger than the original
|
||||
dataset. This is useful when the data loading time is long but the dataset
|
||||
is small. Using RepeatDataset can reduce the data loading time between
|
||||
epochs.
|
||||
|
||||
Args:
|
||||
dataset (:obj:`Dataset`): The dataset to be repeated.
|
||||
times (int): Repeat times.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, times):
|
||||
self.dataset = dataset
|
||||
self.times = times
|
||||
|
||||
self._ori_len = len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.dataset[idx % self._ori_len]
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
return self.dataset.get_cat_ids(idx % self._ori_len)
|
||||
|
||||
def __len__(self):
|
||||
return self.times * self._ori_len
|
||||
|
||||
|
||||
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
|
||||
@DATASETS.register_module()
|
||||
class ClassBalancedDataset(object):
|
||||
"""A wrapper of repeated dataset with repeat factor.
|
||||
|
||||
Suitable for training on class imbalanced datasets like LVIS. Following
|
||||
the sampling strategy in [1], in each epoch, an image may appear multiple
|
||||
times based on its "repeat factor".
|
||||
The repeat factor for an image is a function of the frequency the rarest
|
||||
category labeled in that image. The "frequency of category c" in [0, 1]
|
||||
is defined by the fraction of images in the training set (without repeats)
|
||||
in which category c appears.
|
||||
The dataset needs to instantiate :func:`self.get_cat_ids(idx)` to support
|
||||
ClassBalancedDataset.
|
||||
The repeat factor is computed as followed.
|
||||
1. For each category c, compute the fraction # of images
|
||||
that contain it: f(c)
|
||||
2. For each category c, compute the category-level repeat factor:
|
||||
r(c) = max(1, sqrt(t/f(c)))
|
||||
3. For each image I and its labels L(I), compute the image-level repeat
|
||||
factor:
|
||||
r(I) = max_{c in L(I)} r(c)
|
||||
|
||||
References:
|
||||
.. [1] https://arxiv.org/pdf/1908.03195.pdf
|
||||
|
||||
Args:
|
||||
dataset (:obj:`CustomDataset`): The dataset to be repeated.
|
||||
oversample_thr (float): frequency threshold below which data is
|
||||
repeated. For categories with `f_c` >= `oversample_thr`, there is
|
||||
no oversampling. For categories with `f_c` < `oversample_thr`, the
|
||||
degree of oversampling following the square-root inverse frequency
|
||||
heuristic above.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, oversample_thr):
|
||||
self.dataset = dataset
|
||||
self.oversample_thr = oversample_thr
|
||||
|
||||
repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
|
||||
repeat_indices = []
|
||||
for dataset_index, repeat_factor in enumerate(repeat_factors):
|
||||
repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
|
||||
self.repeat_indices = repeat_indices
|
||||
|
||||
flags = []
|
||||
if hasattr(self.dataset, 'flag'):
|
||||
for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
|
||||
flags.extend([flag] * int(math.ceil(repeat_factor)))
|
||||
assert len(flags) == len(repeat_indices)
|
||||
self.flag = np.asarray(flags, dtype=np.uint8)
|
||||
|
||||
def _get_repeat_factors(self, dataset, repeat_thr):
|
||||
# 1. For each category c, compute the fraction # of images
|
||||
# that contain it: f(c)
|
||||
category_freq = defaultdict(int)
|
||||
num_images = len(dataset)
|
||||
for idx in range(num_images):
|
||||
cat_ids = set(self.dataset.get_cat_ids(idx))
|
||||
for cat_id in cat_ids:
|
||||
category_freq[cat_id] += 1
|
||||
for k, v in category_freq.items():
|
||||
assert v > 0, f'caterogy {k} does not contain any images'
|
||||
category_freq[k] = v / num_images
|
||||
|
||||
# 2. For each category c, compute the category-level repeat factor:
|
||||
# r(c) = max(1, sqrt(t/f(c)))
|
||||
category_repeat = {
|
||||
cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
|
||||
for cat_id, cat_freq in category_freq.items()
|
||||
}
|
||||
|
||||
# 3. For each image I and its labels L(I), compute the image-level
|
||||
# repeat factor:
|
||||
# r(I) = max_{c in L(I)} r(c)
|
||||
repeat_factors = []
|
||||
for idx in range(num_images):
|
||||
cat_ids = set(self.dataset.get_cat_ids(idx))
|
||||
repeat_factor = max(
|
||||
{category_repeat[cat_id]
|
||||
for cat_id in cat_ids})
|
||||
repeat_factors.append(repeat_factor)
|
||||
|
||||
return repeat_factors
|
||||
|
||||
def __getitem__(self, idx):
|
||||
ori_index = self.repeat_indices[idx]
|
||||
return self.dataset[ori_index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.repeat_indices)
|
|
@ -0,0 +1,104 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ImageNet(BaseDataset):
|
||||
"""`ImageNet <http://www.image-net.org>`_ Dataset.
|
||||
|
||||
This implementation is modified from
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py # noqa: E501
|
||||
"""
|
||||
|
||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
|
||||
|
||||
def load_annotations(self):
|
||||
if self.ann_file is None:
|
||||
classes, class_to_idx = find_classes(self.data_prefix)
|
||||
samples = make_dataset(
|
||||
self.data_prefix, class_to_idx, extensions=self.IMG_EXTENSIONS)
|
||||
if len(samples) == 0:
|
||||
raise (RuntimeError('Found 0 files in subfolders of: '
|
||||
f'{self.data_prefix}. '
|
||||
'Supported extensions are: '
|
||||
f'{",".join(self.IMG_EXTENSIONS)}'))
|
||||
|
||||
self.classes = classes
|
||||
self.class_to_idx = class_to_idx
|
||||
elif isinstance(self.ann_file, str):
|
||||
with open(self.ann_file) as f:
|
||||
samples = [x.strip().split(' ') for x in f.readlines()]
|
||||
else:
|
||||
raise TypeError('ann_file must be a str or None')
|
||||
self.samples = samples
|
||||
|
||||
data_infos = []
|
||||
for filename, gt_label in self.samples:
|
||||
info = {'img_prefix': self.data_prefix}
|
||||
info['img_info'] = {'filename': filename}
|
||||
info['gt_labels'] = np.array(gt_label, dtype=np.int64)
|
||||
data_infos.append(info)
|
||||
return data_infos
|
||||
|
||||
|
||||
def has_file_allowed_extension(filename, extensions):
|
||||
"""Checks if a file is an allowed extension.
|
||||
|
||||
Args:
|
||||
filename (string): path to a file
|
||||
|
||||
Returns:
|
||||
bool: True if the filename ends with a known image extension
|
||||
"""
|
||||
filename_lower = filename.lower()
|
||||
return any(filename_lower.endswith(ext) for ext in extensions)
|
||||
|
||||
|
||||
def find_classes(root):
|
||||
"""Find classes by folders under a root.
|
||||
|
||||
Args:
|
||||
root (string): root directory of folders
|
||||
|
||||
Returns:
|
||||
classes (list): a list of class names
|
||||
class_to_idx (dict): the map from class name to class idx
|
||||
"""
|
||||
classes = [
|
||||
d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
|
||||
]
|
||||
classes.sort()
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
return classes, class_to_idx
|
||||
|
||||
|
||||
def make_dataset(root, class_to_idx, extensions):
|
||||
"""Make dataset by walking all images under a root.
|
||||
|
||||
Args:
|
||||
root (string): root directory of folders
|
||||
class_to_idx (dict): the map from class name to class idx
|
||||
extensions (tuple): allowed extensions
|
||||
|
||||
Returns:
|
||||
images (list): a list of tuple where each element is (image, label)
|
||||
"""
|
||||
images = []
|
||||
root = os.path.expanduser(root)
|
||||
for class_name in sorted(os.listdir(root)):
|
||||
_dir = os.path.join(root, class_name)
|
||||
if not os.path.isdir(_dir):
|
||||
continue
|
||||
|
||||
for _, _, fns in sorted(os.walk(_dir)):
|
||||
for fn in sorted(fns):
|
||||
if has_file_allowed_extension(fn, extensions):
|
||||
path = os.path.join(class_name, fn)
|
||||
item = (path, class_to_idx[class_name])
|
||||
images.append(item)
|
||||
|
||||
return images
|
|
@ -0,0 +1,175 @@
|
|||
import codecs
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
from .utils import download_and_extract_archive, rm_suffix
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MNIST(BaseDataset):
|
||||
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
|
||||
|
||||
This implementation is modified from
|
||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py # noqa: E501
|
||||
"""
|
||||
|
||||
resource_prefix = 'http://yann.lecun.com/exdb/mnist/'
|
||||
resources = {
|
||||
'train_image_file':
|
||||
('train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
|
||||
'train_label_file':
|
||||
('train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
|
||||
'test_image_file':
|
||||
('t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
|
||||
'test_label_file':
|
||||
('t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')
|
||||
}
|
||||
|
||||
classes = [
|
||||
'0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five',
|
||||
'6 - six', '7 - seven', '8 - eight', '9 - nine'
|
||||
]
|
||||
|
||||
@property
|
||||
def class_to_idx(self):
|
||||
return {_class: i for i, _class in enumerate(self.classes)}
|
||||
|
||||
def load_annotations(self):
|
||||
train_image_file = osp.join(
|
||||
self.data_prefix, rm_suffix(self.resources['train_image_file'][0]))
|
||||
train_label_file = osp.join(
|
||||
self.data_prefix, rm_suffix(self.resources['train_label_file'][0]))
|
||||
test_image_file = osp.join(
|
||||
self.data_prefix, rm_suffix(self.resources['test_image_file'][0]))
|
||||
test_label_file = osp.join(
|
||||
self.data_prefix, rm_suffix(self.resources['test_label_file'][0]))
|
||||
|
||||
if not osp.exists(train_image_file) or not osp.exists(
|
||||
train_label_file) or not osp.exists(
|
||||
test_image_file) or not osp.exists(test_label_file):
|
||||
self.download()
|
||||
|
||||
train_set = (read_image_file(train_image_file),
|
||||
read_label_file(train_label_file))
|
||||
test_set = (read_image_file(test_image_file),
|
||||
read_label_file(test_label_file))
|
||||
|
||||
if not self.test_mode:
|
||||
imgs, gt_labels = train_set
|
||||
else:
|
||||
imgs, gt_labels = test_set
|
||||
|
||||
data_infos = []
|
||||
for img, gt_label in zip(imgs, gt_labels):
|
||||
gt_label = np.array(gt_label, dtype=np.int64)
|
||||
info = {'img': img.numpy(), 'gt_labels': gt_label}
|
||||
data_infos.append(info)
|
||||
return data_infos
|
||||
|
||||
def download(self):
|
||||
os.makedirs(self.data_prefix, exist_ok=True)
|
||||
|
||||
# download files
|
||||
for url, md5 in self.resources.values():
|
||||
url = osp.join(self.resource_prefix, url)
|
||||
filename = url.rpartition('/')[2]
|
||||
download_and_extract_archive(
|
||||
url,
|
||||
download_root=self.data_prefix,
|
||||
filename=filename,
|
||||
md5=md5)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class FashionMNIST(MNIST):
|
||||
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
|
||||
Dataset.
|
||||
"""
|
||||
|
||||
resource_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' # noqa: E501
|
||||
resources = {
|
||||
'train_image_file':
|
||||
('train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'),
|
||||
'train_label_file':
|
||||
('train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'),
|
||||
'test_image_file':
|
||||
('t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'),
|
||||
'test_label_file':
|
||||
('t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310')
|
||||
}
|
||||
classes = [
|
||||
'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
|
||||
'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
|
||||
]
|
||||
|
||||
|
||||
def get_int(b):
|
||||
return int(codecs.encode(b, 'hex'), 16)
|
||||
|
||||
|
||||
def open_maybe_compressed_file(path):
|
||||
"""Return a file object that possibly decompresses 'path' on the fly.
|
||||
Decompression occurs when argument `path` is a string
|
||||
and ends with '.gz' or '.xz'.
|
||||
"""
|
||||
if not isinstance(path, torch._six.string_classes):
|
||||
return path
|
||||
if path.endswith('.gz'):
|
||||
import gzip
|
||||
return gzip.open(path, 'rb')
|
||||
if path.endswith('.xz'):
|
||||
import lzma
|
||||
return lzma.open(path, 'rb')
|
||||
return open(path, 'rb')
|
||||
|
||||
|
||||
def read_sn3_pascalvincent_tensor(path, strict=True):
|
||||
"""Read a SN3 file in "Pascal Vincent" format
|
||||
(Lush file 'libidx/idx-io.lsh').
|
||||
Argument may be a filename, compressed filename, or file object.
|
||||
"""
|
||||
# typemap
|
||||
if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
|
||||
read_sn3_pascalvincent_tensor.typemap = {
|
||||
8: (torch.uint8, np.uint8, np.uint8),
|
||||
9: (torch.int8, np.int8, np.int8),
|
||||
11: (torch.int16, np.dtype('>i2'), 'i2'),
|
||||
12: (torch.int32, np.dtype('>i4'), 'i4'),
|
||||
13: (torch.float32, np.dtype('>f4'), 'f4'),
|
||||
14: (torch.float64, np.dtype('>f8'), 'f8')
|
||||
}
|
||||
# read
|
||||
with open_maybe_compressed_file(path) as f:
|
||||
data = f.read()
|
||||
# parse
|
||||
magic = get_int(data[0:4])
|
||||
nd = magic % 256
|
||||
ty = magic // 256
|
||||
assert nd >= 1 and nd <= 3
|
||||
assert ty >= 8 and ty <= 14
|
||||
m = read_sn3_pascalvincent_tensor.typemap[ty]
|
||||
s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)]
|
||||
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
|
||||
assert parsed.shape[0] == np.prod(s) or not strict
|
||||
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
|
||||
|
||||
|
||||
def read_label_file(path):
|
||||
with open(path, 'rb') as f:
|
||||
x = read_sn3_pascalvincent_tensor(f, strict=False)
|
||||
assert (x.dtype == torch.uint8)
|
||||
assert (x.ndimension() == 1)
|
||||
return x.long()
|
||||
|
||||
|
||||
def read_image_file(path):
|
||||
with open(path, 'rb') as f:
|
||||
x = read_sn3_pascalvincent_tensor(f, strict=False)
|
||||
assert (x.dtype == torch.uint8)
|
||||
assert (x.ndimension() == 3)
|
||||
return x
|
|
@ -1,3 +1,13 @@
|
|||
from .compose import Compose
|
||||
from .formating import (Collect, ImageToTensor, ToDataContainer, ToNumpy,
|
||||
ToPIL, ToTensor, Transpose, to_tensor)
|
||||
from .loading import LoadImageFromFile
|
||||
from .transforms import (CenterCrop, Normalize, RandomHorizontalFlip,
|
||||
RandomResizedCrop, Resize)
|
||||
|
||||
__all__ = ['Compose']
|
||||
__all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
|
||||
'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'LoadImageFromFile',
|
||||
'RandomResizedCrop', 'RandomHorizontalFlip', 'Resize', 'CenterCrop',
|
||||
'Normalize'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
from collections.abc import Sequence
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
from PIL import Image
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
def to_tensor(data):
|
||||
"""Convert objects of various python types to :obj:`torch.Tensor`.
|
||||
|
||||
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
||||
:class:`Sequence`, :class:`int` and :class:`float`.
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data
|
||||
elif isinstance(data, np.ndarray):
|
||||
return torch.from_numpy(data)
|
||||
elif isinstance(data, Sequence) and not mmcv.is_str(data):
|
||||
return torch.tensor(data)
|
||||
elif isinstance(data, int):
|
||||
return torch.LongTensor([data])
|
||||
elif isinstance(data, float):
|
||||
return torch.FloatTensor([data])
|
||||
else:
|
||||
raise TypeError(
|
||||
f'Type {type(data)} cannot be converted to tensor.'
|
||||
'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
|
||||
'`Sequence`, `int` and `float`')
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ToTensor(object):
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
|
||||
def __call__(self, results):
|
||||
for key in self.keys:
|
||||
results[key] = to_tensor(results[key])
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ImageToTensor(object):
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
|
||||
def __call__(self, results):
|
||||
for key in self.keys:
|
||||
img = results[key]
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
results[key] = to_tensor(img.transpose(2, 0, 1))
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Transpose(object):
|
||||
|
||||
def __init__(self, keys, order):
|
||||
self.keys = keys
|
||||
self.order = order
|
||||
|
||||
def __call__(self, results):
|
||||
for key in self.keys:
|
||||
results[key] = results[key].transpose(self.order)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
f'(keys={self.keys}, order={self.order})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ToDataContainer(object):
|
||||
|
||||
def __init__(self,
|
||||
fields=(dict(key='img', stack=True), dict(key='gt_labels'))):
|
||||
self.fields = fields
|
||||
|
||||
def __call__(self, results):
|
||||
for field in self.fields:
|
||||
field = field.copy()
|
||||
key = field.pop('key')
|
||||
results[key] = DC(results[key], **field)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(fields={self.fields})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ToPIL(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = Image.fromarray(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ToNumpy(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = np.array(results['img'], dtype=np.float32)
|
||||
# results['img'] = np.array(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Collect(object):
|
||||
"""
|
||||
Collect data from the loader relevant to the specific task.
|
||||
|
||||
This is usually the last stage of the data loader pipeline. Typically keys
|
||||
is set to some subset of "img" and "gt_labels".
|
||||
|
||||
The "img_meta" item is always populated. The contents of the "img_meta"
|
||||
dictionary depends on "meta_keys". By default this includes:
|
||||
|
||||
- "img_shape": shape of the image input to the network as a tuple
|
||||
(h, w, c). Note that images may be zero padded on the bottom/right
|
||||
if the batch tensor is larger than this shape.
|
||||
|
||||
- "filename": path to the image file
|
||||
|
||||
- "ori_shape": original shape of the image as a tuple (h, w, c)
|
||||
|
||||
- "img_norm_cfg": a dict of normalization information:
|
||||
- mean - per channel mean subtraction
|
||||
- std - per channel std divisor
|
||||
- to_rgb - bool indicating if bgr was converted to rgb
|
||||
"""
|
||||
|
||||
def __init__(self, keys, meta_keys=('filename', 'ori_shape', 'img_shape')):
|
||||
self.keys = keys
|
||||
self.meta_keys = meta_keys
|
||||
|
||||
def __call__(self, results):
|
||||
data = {}
|
||||
img_meta = {}
|
||||
for key in self.meta_keys:
|
||||
img_meta[key] = results[key]
|
||||
# data['img_metas'] = DC(img_meta, cpu_only=True)
|
||||
# data['img_metas'] = img_meta
|
||||
for key in self.keys:
|
||||
data[key] = results[key]
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
f'(keys={self.keys}, meta_keys={self.meta_keys})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class WrapFieldsToLists(object):
|
||||
"""Wrap fields of the data dictionary into lists for evaluation.
|
||||
|
||||
This class can be used as a last step of a test or validation
|
||||
pipeline for single image evaluation or inference.
|
||||
|
||||
Example:
|
||||
>>> test_pipeline = [
|
||||
>>> dict(type='LoadImageFromFile'),
|
||||
>>> dict(type='Normalize',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True),
|
||||
>>> dict(type='ImageToTensor', keys=['img']),
|
||||
>>> dict(type='Collect', keys=['img']),
|
||||
>>> dict(type='WrapIntoLists')
|
||||
>>> ]
|
||||
"""
|
||||
|
||||
def __call__(self, results):
|
||||
# Wrap dict fields into lists
|
||||
for key, val in results.items():
|
||||
results[key] = [val]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}()'
|
|
@ -0,0 +1,68 @@
|
|||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class LoadImageFromFile(object):
|
||||
"""Load an image from file.
|
||||
|
||||
Required keys are "img_prefix" and "img_info" (a dict that must contain the
|
||||
key "filename"). Added or updated keys are "filename", "img", "img_shape",
|
||||
"ori_shape" (same as `img_shape`) and "img_norm_cfg" (means=0 and stds=1).
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
color_type (str): The flag argument for :func:`mmcv.imfrombytes()`.
|
||||
Defaults to 'color'.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmcv.fileio.FileClient` for details.
|
||||
Defaults to ``dict(backend='disk')``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
to_float32=False,
|
||||
color_type='color',
|
||||
file_client_args=dict(backend='disk')):
|
||||
self.to_float32 = to_float32
|
||||
self.color_type = color_type
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = None
|
||||
|
||||
def __call__(self, results):
|
||||
if self.file_client is None:
|
||||
self.file_client = mmcv.FileClient(**self.file_client_args)
|
||||
|
||||
if results['img_prefix'] is not None:
|
||||
filename = osp.join(results['img_prefix'],
|
||||
results['img_info']['filename'])
|
||||
else:
|
||||
filename = results['img_info']['filename']
|
||||
|
||||
img_bytes = self.file_client.get(filename)
|
||||
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['filename'] = filename
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
|
||||
results['img_norm_cfg'] = dict(
|
||||
mean=np.zeros(num_channels, dtype=np.float32),
|
||||
std=np.ones(num_channels, dtype=np.float32),
|
||||
to_rgb=False)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32}, '
|
||||
f"color_type='{self.color_type}', "
|
||||
f'file_client_args={self.file_client_args})')
|
||||
return repr_str
|
|
@ -0,0 +1,115 @@
|
|||
import mmcv
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomCrop(transforms.RandomCrop):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RandomCrop, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = super(RandomCrop, self).__call__(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomResizedCrop(transforms.RandomResizedCrop):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RandomResizedCrop, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = super(RandomResizedCrop,
|
||||
self).__call__(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RandomHorizontalFlip, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = super(RandomHorizontalFlip,
|
||||
self).__call__(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Resize(transforms.Resize):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Resize, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = super(Resize, self).__call__(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class CenterCrop(transforms.CenterCrop):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CenterCrop, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = super(CenterCrop, self).__call__(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ColorJitter(transforms.ColorJitter):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ColorJitter, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
results['img'] = super(ColorJitter, self).__call__(results['img'])
|
||||
return results
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Normalize(object):
|
||||
"""Normalize the image.
|
||||
|
||||
Args:
|
||||
mean (sequence): Mean values of 3 channels.
|
||||
std (sequence): Std values of 3 channels.
|
||||
to_rgb (bool): Whether to convert the image from BGR to RGB,
|
||||
default is true.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, to_rgb=True):
|
||||
self.mean = np.array(mean, dtype=np.float32)
|
||||
self.std = np.array(std, dtype=np.float32)
|
||||
self.to_rgb = to_rgb
|
||||
|
||||
def __call__(self, results):
|
||||
for key in results.get('img_fields', ['img']):
|
||||
results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
|
||||
self.to_rgb)
|
||||
results['img_norm_cfg'] = dict(
|
||||
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
|
||||
return repr_str
|
|
@ -0,0 +1,163 @@
|
|||
import gzip
|
||||
import hashlib
|
||||
import os
|
||||
import os.path
|
||||
import tarfile
|
||||
import zipfile
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
__all__ = ['rm_suffix', 'check_integrity', 'download_and_extract_archive']
|
||||
|
||||
|
||||
def rm_suffix(s, suffix=None):
|
||||
if suffix is None:
|
||||
return s[:s.rfind('.')]
|
||||
else:
|
||||
return s[:s.rfind(suffix)]
|
||||
|
||||
|
||||
def gen_bar_updater():
|
||||
pbar = tqdm(total=None)
|
||||
|
||||
def bar_update(count, block_size, total_size):
|
||||
if pbar.total is None and total_size:
|
||||
pbar.total = total_size
|
||||
progress_bytes = count * block_size
|
||||
pbar.update(progress_bytes - pbar.n)
|
||||
|
||||
return bar_update
|
||||
|
||||
|
||||
def calculate_md5(fpath, chunk_size=1024 * 1024):
|
||||
md5 = hashlib.md5()
|
||||
with open(fpath, 'rb') as f:
|
||||
for chunk in iter(lambda: f.read(chunk_size), b''):
|
||||
md5.update(chunk)
|
||||
return md5.hexdigest()
|
||||
|
||||
|
||||
def check_md5(fpath, md5, **kwargs):
|
||||
return md5 == calculate_md5(fpath, **kwargs)
|
||||
|
||||
|
||||
def check_integrity(fpath, md5=None):
|
||||
if not os.path.isfile(fpath):
|
||||
return False
|
||||
if md5 is None:
|
||||
return True
|
||||
return check_md5(fpath, md5)
|
||||
|
||||
|
||||
def download_url(url, root, filename=None, md5=None):
|
||||
"""Download a file from a url and place it in root.
|
||||
|
||||
Args:
|
||||
url (str): URL to download file from.
|
||||
root (str): Directory to place downloaded file in.
|
||||
filename (str | None): Name to save the file under.
|
||||
If filename is None, use the basename of the URL.
|
||||
md5 (str | None): MD5 checksum of the download.
|
||||
If md5 is None, download without md5 check.
|
||||
"""
|
||||
import urllib
|
||||
|
||||
root = os.path.expanduser(root)
|
||||
if not filename:
|
||||
filename = os.path.basename(url)
|
||||
fpath = os.path.join(root, filename)
|
||||
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
# check if file is already present locally
|
||||
if check_integrity(fpath, md5):
|
||||
print(f'Using downloaded and verified file: {fpath}')
|
||||
else: # download the file
|
||||
try:
|
||||
print('Downloading ' + url + ' to ' + fpath)
|
||||
urllib.request.urlretrieve(
|
||||
url, fpath, reporthook=gen_bar_updater())
|
||||
except (urllib.error.URLError, IOError) as e:
|
||||
if url[:5] == 'https':
|
||||
url = url.replace('https:', 'http:')
|
||||
print('Failed download. Trying https -> http instead.'
|
||||
f' Downloading {url} to {fpath}')
|
||||
urllib.request.urlretrieve(
|
||||
url, fpath, reporthook=gen_bar_updater())
|
||||
else:
|
||||
raise e
|
||||
# check integrity of downloaded file
|
||||
if not check_integrity(fpath, md5):
|
||||
raise RuntimeError('File not found or corrupted.')
|
||||
|
||||
|
||||
def _is_tarxz(filename):
|
||||
return filename.endswith('.tar.xz')
|
||||
|
||||
|
||||
def _is_tar(filename):
|
||||
return filename.endswith('.tar')
|
||||
|
||||
|
||||
def _is_targz(filename):
|
||||
return filename.endswith('.tar.gz')
|
||||
|
||||
|
||||
def _is_tgz(filename):
|
||||
return filename.endswith('.tgz')
|
||||
|
||||
|
||||
def _is_gzip(filename):
|
||||
return filename.endswith('.gz') and not filename.endswith('.tar.gz')
|
||||
|
||||
|
||||
def _is_zip(filename):
|
||||
return filename.endswith('.zip')
|
||||
|
||||
|
||||
def extract_archive(from_path, to_path=None, remove_finished=False):
|
||||
if to_path is None:
|
||||
to_path = os.path.dirname(from_path)
|
||||
|
||||
if _is_tar(from_path):
|
||||
with tarfile.open(from_path, 'r') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_targz(from_path) or _is_tgz(from_path):
|
||||
with tarfile.open(from_path, 'r:gz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_tarxz(from_path):
|
||||
with tarfile.open(from_path, 'r:xz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_gzip(from_path):
|
||||
to_path = os.path.join(
|
||||
to_path,
|
||||
os.path.splitext(os.path.basename(from_path))[0])
|
||||
with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
|
||||
out_f.write(zip_f.read())
|
||||
elif _is_zip(from_path):
|
||||
with zipfile.ZipFile(from_path, 'r') as z:
|
||||
z.extractall(to_path)
|
||||
else:
|
||||
raise ValueError(f'Extraction of {from_path} not supported')
|
||||
|
||||
if remove_finished:
|
||||
os.remove(from_path)
|
||||
|
||||
|
||||
def download_and_extract_archive(url,
|
||||
download_root,
|
||||
extract_root=None,
|
||||
filename=None,
|
||||
md5=None,
|
||||
remove_finished=False):
|
||||
download_root = os.path.expanduser(download_root)
|
||||
if extract_root is None:
|
||||
extract_root = download_root
|
||||
if not filename:
|
||||
filename = os.path.basename(url)
|
||||
|
||||
download_url(url, download_root, filename, md5)
|
||||
|
||||
archive = os.path.join(download_root, filename)
|
||||
print(f'Extracting {archive} to {extract_root}')
|
||||
extract_archive(archive, extract_root, remove_finished)
|
Binary file not shown.
After Width: | Height: | Size: 35 KiB |
Binary file not shown.
After Width: | Height: | Size: 38 KiB |
|
@ -0,0 +1,122 @@
|
|||
import bisect
|
||||
import math
|
||||
import random
|
||||
import string
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mmcls.datasets import (DATASETS, BaseDataset, ClassBalancedDataset,
|
||||
ConcatDataset, RepeatDataset)
|
||||
from mmcls.datasets.utils import (check_integrity,
|
||||
download_and_extract_archive, rm_suffix)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dataset_name',
|
||||
['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet'])
|
||||
def test_datasets_override_default(dataset_name):
|
||||
dataset_class = DATASETS.get(dataset_name)
|
||||
dataset_class.load_annotations = MagicMock()
|
||||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(data_prefix='', pipeline=[])
|
||||
|
||||
assert dataset.data_prefix == ''
|
||||
assert not dataset.test_mode
|
||||
assert dataset.ann_file is None
|
||||
|
||||
|
||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||
def test_dataset_wrapper():
|
||||
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
|
||||
dataset_a = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
||||
len_a = 10
|
||||
cat_ids_list_a = [
|
||||
np.random.randint(0, 80, num).tolist()
|
||||
for num in np.random.randint(1, 20, len_a)
|
||||
]
|
||||
dataset_a.data_infos = MagicMock()
|
||||
dataset_a.data_infos.__len__.return_value = len_a
|
||||
dataset_a.get_cat_ids = MagicMock(
|
||||
side_effect=lambda idx: cat_ids_list_a[idx])
|
||||
dataset_b = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
||||
len_b = 20
|
||||
cat_ids_list_b = [
|
||||
np.random.randint(0, 80, num).tolist()
|
||||
for num in np.random.randint(1, 20, len_b)
|
||||
]
|
||||
dataset_b.data_infos = MagicMock()
|
||||
dataset_b.data_infos.__len__.return_value = len_b
|
||||
dataset_b.get_cat_ids = MagicMock(
|
||||
side_effect=lambda idx: cat_ids_list_b[idx])
|
||||
|
||||
concat_dataset = ConcatDataset([dataset_a, dataset_b])
|
||||
assert concat_dataset[5] == 5
|
||||
assert concat_dataset[25] == 15
|
||||
assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
|
||||
assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
|
||||
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
|
||||
|
||||
repeat_dataset = RepeatDataset(dataset_a, 10)
|
||||
assert repeat_dataset[5] == 5
|
||||
assert repeat_dataset[15] == 5
|
||||
assert repeat_dataset[27] == 7
|
||||
assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
|
||||
assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
|
||||
assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
|
||||
assert len(repeat_dataset) == 10 * len(dataset_a)
|
||||
|
||||
category_freq = defaultdict(int)
|
||||
for cat_ids in cat_ids_list_a:
|
||||
cat_ids = set(cat_ids)
|
||||
for cat_id in cat_ids:
|
||||
category_freq[cat_id] += 1
|
||||
for k, v in category_freq.items():
|
||||
category_freq[k] = v / len(cat_ids_list_a)
|
||||
|
||||
mean_freq = np.mean(list(category_freq.values()))
|
||||
repeat_thr = mean_freq
|
||||
|
||||
category_repeat = {
|
||||
cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
|
||||
for cat_id, cat_freq in category_freq.items()
|
||||
}
|
||||
|
||||
repeat_factors = []
|
||||
for cat_ids in cat_ids_list_a:
|
||||
cat_ids = set(cat_ids)
|
||||
repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
|
||||
repeat_factors.append(math.ceil(repeat_factor))
|
||||
repeat_factors_cumsum = np.cumsum(repeat_factors)
|
||||
repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
|
||||
assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
|
||||
for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
|
||||
assert repeat_factor_dataset[idx] == bisect.bisect_right(
|
||||
repeat_factors_cumsum, idx)
|
||||
|
||||
|
||||
def test_dataset_utils():
|
||||
# test rm_suffix
|
||||
assert rm_suffix('a.jpg') == 'a'
|
||||
assert rm_suffix('a.bak.jpg') == 'a.bak'
|
||||
assert rm_suffix('a.bak.jpg', suffix='.jpg') == 'a.bak'
|
||||
assert rm_suffix('a.bak.jpg', suffix='.bak.jpg') == 'a'
|
||||
|
||||
# test check_integrity
|
||||
rand_file = ''.join(random.sample(string.ascii_letters, 10))
|
||||
assert not check_integrity(rand_file, md5=None)
|
||||
assert not check_integrity(rand_file, md5=2333)
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
assert check_integrity(tmp_file.name, md5=None)
|
||||
assert not check_integrity(tmp_file.name, md5=2333)
|
||||
|
||||
# test download_and_extract_archive
|
||||
url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
|
||||
md5 = 'd53e105ee54ea40749a09fcbcd1e9432'
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
download_and_extract_archive(
|
||||
url, download_root=tmp_dir.name, md5=md5, remove_finished=True)
|
|
@ -0,0 +1,57 @@
|
|||
import copy
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmcls.datasets.pipelines import LoadImageFromFile
|
||||
|
||||
|
||||
class TestLoading(object):
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
cls.data_prefix = osp.join(osp.dirname(__file__), '../data')
|
||||
|
||||
def test_load_img(self):
|
||||
results = dict(
|
||||
img_prefix=self.data_prefix, img_info=dict(filename='color.jpg'))
|
||||
transform = LoadImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['filename'] == osp.join(self.data_prefix, 'color.jpg')
|
||||
assert results['img'].shape == (288, 512, 3)
|
||||
assert results['img'].dtype == np.uint8
|
||||
assert results['img_shape'] == (288, 512, 3)
|
||||
assert results['ori_shape'] == (288, 512, 3)
|
||||
np.testing.assert_equal(results['img_norm_cfg']['mean'],
|
||||
np.zeros(3, dtype=np.float32))
|
||||
assert repr(transform) == transform.__class__.__name__ + \
|
||||
"(to_float32=False, color_type='color', " + \
|
||||
"file_client_args={'backend': 'disk'})"
|
||||
|
||||
# no img_prefix
|
||||
results = dict(
|
||||
img_prefix=None, img_info=dict(filename='tests/data/color.jpg'))
|
||||
transform = LoadImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['filename'] == 'tests/data/color.jpg'
|
||||
assert results['img'].shape == (288, 512, 3)
|
||||
|
||||
# to_float32
|
||||
transform = LoadImageFromFile(to_float32=True)
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img'].dtype == np.float32
|
||||
|
||||
# gray image
|
||||
results = dict(
|
||||
img_prefix=self.data_prefix, img_info=dict(filename='gray.jpg'))
|
||||
transform = LoadImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img'].shape == (288, 512, 3)
|
||||
assert results['img'].dtype == np.uint8
|
||||
|
||||
transform = LoadImageFromFile(color_type='unchanged')
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img'].shape == (288, 512)
|
||||
assert results['img'].dtype == np.uint8
|
||||
np.testing.assert_equal(results['img_norm_cfg']['mean'],
|
||||
np.zeros(1, dtype=np.float32))
|
|
@ -0,0 +1,44 @@
|
|||
import copy
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.utils import build_from_cfg
|
||||
from torchvision import transforms
|
||||
|
||||
from mmcls.datasets.builder import PIPELINES
|
||||
|
||||
|
||||
def test_normalize():
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
to_rgb=True)
|
||||
transform = dict(type='Normalize', **img_norm_cfg)
|
||||
transform = build_from_cfg(transform, PIPELINES)
|
||||
results = dict()
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
original_img = copy.deepcopy(img)
|
||||
results['img'] = img
|
||||
results['img2'] = copy.deepcopy(img)
|
||||
results['img_shape'] = img.shape
|
||||
results['ori_shape'] = img.shape
|
||||
results['img_fields'] = ['img', 'img2']
|
||||
|
||||
norm_results = transform(results)
|
||||
assert np.equal(norm_results['img'], norm_results['img2']).all()
|
||||
|
||||
mean = np.array(img_norm_cfg['mean'])
|
||||
std = np.array(img_norm_cfg['std'])
|
||||
normalized_img = (original_img[..., ::-1] - mean) / std
|
||||
assert np.allclose(norm_results['img'], normalized_img)
|
||||
|
||||
# compare results with torchvision
|
||||
normalize_module = transforms.Normalize(mean=mean, std=std)
|
||||
tensor_img = original_img[..., ::-1].copy()
|
||||
tensor_img = torch.Tensor(tensor_img.transpose(2, 0, 1))
|
||||
normalized_img = normalize_module(tensor_img)
|
||||
normalized_img = np.array(normalized_img).transpose(1, 2, 0)
|
||||
assert np.equal(norm_results['img'], normalized_img).all()
|
Loading…
Reference in New Issue