[Refactor] Refactor MNIST and FashionMNIST dataset.

pull/913/head
yingfhu 2022-06-10 07:24:07 +00:00 committed by mzr1996
parent ee6e585e41
commit 125b74d4ca
6 changed files with 267 additions and 118 deletions

View File

@ -15,7 +15,7 @@ preprocess_cfg = dict(mean=[33.46], std=[78.87])
pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')] pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')]
common_data_cfg = dict( common_data_cfg = dict(
type=dataset_type, data_root='data/mnist', pipeline=pipeline) type=dataset_type, data_prefix='data/mnist', pipeline=pipeline)
train_dataloader = dict( train_dataloader = dict(
batch_size=128, batch_size=128,
@ -53,6 +53,8 @@ val_cfg = dict(interval=1) # validate every epoch
test_cfg = dict() test_cfg = dict()
# runtime settings # runtime settings
default_scope = 'mmcls'
default_hooks = dict( default_hooks = dict(
# record the time of every iteration. # record the time of every iteration.
timer=dict(type='IterTimerHook'), timer=dict(type='IterTimerHook'),

View File

@ -1027,3 +1027,11 @@ CIFAR100_CATEGORIES = (
'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
'woman', 'worm') 'woman', 'worm')
MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight',
'9 - nine')
FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress',
'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
'Ankle boot')

View File

@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import codecs import codecs
import os from typing import List, Optional
import os.path as osp
import mmengine.dist as dist
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist from mmengine import FileClient
from mmcv.runner import get_dist_info, master_only
from mmcls.registry import DATASETS from mmcls.registry import DATASETS
from .base_dataset import BaseDataset from .base_dataset import BaseDataset
from .utils import download_and_extract_archive, rm_suffix from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
from .utils import (download_and_extract_archive, open_maybe_compressed_file,
rm_suffix)
@DATASETS.register_module() @DATASETS.register_module()
@ -19,58 +20,83 @@ class MNIST(BaseDataset):
This implementation is modified from This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
Args:
data_prefix (str): Prefix for data.
test_mode (bool): ``test_mode=True`` means in test phase.
It determines to use the training set or test set.
metainfo (dict, optional): Meta information for dataset, such as
categories information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix``.
Defaults to None.
download (bool): Whether to download the dataset if not exists.
Defaults to True.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
""" # noqa: E501 """ # noqa: E501
resource_prefix = 'http://yann.lecun.com/exdb/mnist/' url_prefix = 'http://yann.lecun.com/exdb/mnist/'
resources = { # train images and labels
'train_image_file': train_list = [
('train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'), ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
'train_label_file': ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
('train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
'test_image_file':
('t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
'test_label_file':
('t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')
}
CLASSES = [
'0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five',
'6 - six', '7 - seven', '8 - eight', '9 - nine'
] ]
# test images and labels
test_list = [
['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
]
METAINFO = {'classes': MNIST_CATEGORITES}
def load_annotations(self): def __init__(self,
train_image_file = osp.join( data_prefix: str,
self.data_prefix, rm_suffix(self.resources['train_image_file'][0])) test_mode: bool,
train_label_file = osp.join( metainfo: Optional[dict] = None,
self.data_prefix, rm_suffix(self.resources['train_label_file'][0])) data_root: Optional[str] = None,
test_image_file = osp.join( download: bool = True,
self.data_prefix, rm_suffix(self.resources['test_image_file'][0])) **kwargs):
test_label_file = osp.join( self.download = download
self.data_prefix, rm_suffix(self.resources['test_label_file'][0])) super().__init__(
# The MNIST dataset doesn't need specify annotation file
ann_file='',
metainfo=metainfo,
data_root=data_root,
data_prefix=dict(root=data_prefix),
test_mode=test_mode,
**kwargs)
if not osp.exists(train_image_file) or not osp.exists( def load_data_list(self):
train_label_file) or not osp.exists( """Load images and ground truth labels."""
test_image_file) or not osp.exists(test_label_file): root = self.data_prefix['root']
self.download() file_client = FileClient.infer_client(uri=root)
_, world_size = get_dist_info() if dist.is_main_process() and not self._check_exists():
if world_size > 1: if file_client.name != 'HardDiskBackend':
dist.barrier() raise RuntimeError(f'The dataset on {root} is not integrated, '
assert osp.exists(train_image_file) and osp.exists( f'please manually handle it.')
train_label_file) and osp.exists(
test_image_file) and osp.exists(test_label_file), \
'Shared storage seems unavailable. Please download dataset ' \
f'manually through {self.resource_prefix}.'
train_set = (read_image_file(train_image_file), if self.download:
read_label_file(train_label_file)) self._download()
test_set = (read_image_file(test_image_file), else:
read_label_file(test_label_file)) raise RuntimeError(
f'Cannot find {self.__class__.__name__} dataset in '
f"{self.data_prefix['root']}, you can specify "
'`download=True` to download automatically.')
dist.barrier()
assert self._check_exists(), \
'Download failed or shared storage is unavailable. Please ' \
f'download the dataset manually through {self.url_prefix}.'
if not self.test_mode: if not self.test_mode:
imgs, gt_labels = train_set file_list = self.train_list
else: else:
imgs, gt_labels = test_set file_list = self.test_list
# load data from SN3 files
imgs = read_image_file(
file_client.join_path(root, rm_suffix(file_list[0][0])))
gt_labels = read_label_file(
file_client.join_path(root, rm_suffix(file_list[1][0])))
data_infos = [] data_infos = []
for img, gt_label in zip(imgs, gt_labels): for img, gt_label in zip(imgs, gt_labels):
@ -79,65 +105,77 @@ class MNIST(BaseDataset):
data_infos.append(info) data_infos.append(info)
return data_infos return data_infos
@master_only def _check_exists(self):
def download(self): """Check the exists of data files."""
os.makedirs(self.data_prefix, exist_ok=True) root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
# download files for filename, _ in (self.train_list + self.test_list):
for url, md5 in self.resources.values(): # get extracted filename of data
url = osp.join(self.resource_prefix, url) extract_filename = rm_suffix(filename)
filename = url.rpartition('/')[2] fpath = file_client.join_path(root, extract_filename)
if not file_client.exists(fpath):
return False
return True
def _download(self):
"""Download and extract data files."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
for filename, md5 in (self.train_list + self.test_list):
url = file_client.join_path(self.url_prefix, filename)
download_and_extract_archive( download_and_extract_archive(
url, url,
download_root=self.data_prefix, download_root=self.data_prefix['root'],
filename=filename, filename=filename,
md5=md5) md5=md5)
def extra_repr(self) -> List[str]:
"""The extra repr information of the dataset."""
body = [f"Prefix of data: \t{self.data_prefix['root']}"]
return body
@DATASETS.register_module() @DATASETS.register_module()
class FashionMNIST(MNIST): class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
Dataset.""" Dataset.
resource_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' # noqa: E501 Args:
resources = { data_prefix (str): Prefix for data.
'train_image_file': test_mode (bool): ``test_mode=True`` means in test phase.
('train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'), It determines to use the training set or test set.
'train_label_file': metainfo (dict, optional): Meta information for dataset, such as
('train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'), categories information. Defaults to None.
'test_image_file': data_root (str, optional): The root directory for ``data_prefix``.
('t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'), Defaults to None.
'test_label_file': download (bool): Whether to download the dataset if not exists.
('t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310') Defaults to True.
} **kwargs: Other keyword arguments in :class:`BaseDataset`.
CLASSES = [ """
'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot' url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
# train images and labels
train_list = [
['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
] ]
# test images and labels
test_list = [
['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
]
METAINFO = {'classes': FASHIONMNIST_CATEGORITES}
def get_int(b): def get_int(b: bytes) -> int:
"""Convert bytes to int."""
return int(codecs.encode(b, 'hex'), 16) return int(codecs.encode(b, 'hex'), 16)
def open_maybe_compressed_file(path): def read_sn3_pascalvincent_tensor(path: str,
"""Return a file object that possibly decompresses 'path' on the fly. strict: bool = True) -> torch.Tensor:
Decompression occurs when argument `path` is a string and ends with '.gz'
or '.xz'.
"""
if not isinstance(path, str):
return path
if path.endswith('.gz'):
import gzip
return gzip.open(path, 'rb')
if path.endswith('.xz'):
import lzma
return lzma.open(path, 'rb')
return open(path, 'rb')
def read_sn3_pascalvincent_tensor(path, strict=True):
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx- """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
io.lsh'). io.lsh').
@ -169,7 +207,8 @@ def read_sn3_pascalvincent_tensor(path, strict=True):
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
def read_label_file(path): def read_label_file(path: str) -> torch.Tensor:
"""Read labels from SN3 label file."""
with open(path, 'rb') as f: with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False) x = read_sn3_pascalvincent_tensor(f, strict=False)
assert (x.dtype == torch.uint8) assert (x.dtype == torch.uint8)
@ -177,7 +216,8 @@ def read_label_file(path):
return x.long() return x.long()
def read_image_file(path): def read_image_file(path: str) -> torch.Tensor:
"""Read images from SN3 image file."""
with open(path, 'rb') as f: with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False) x = read_sn3_pascalvincent_tensor(f, strict=False)
assert (x.dtype == torch.uint8) assert (x.dtype == torch.uint8)

View File

@ -12,7 +12,10 @@ import zipfile
from mmengine.fileio.file_client import FileClient from mmengine.fileio.file_client import FileClient
__all__ = ['rm_suffix', 'check_integrity', 'download_and_extract_archive'] __all__ = [
'rm_suffix', 'check_integrity', 'download_and_extract_archive',
'open_maybe_compressed_file'
]
def rm_suffix(s, suffix=None): def rm_suffix(s, suffix=None):
@ -221,3 +224,20 @@ def download_and_extract_archive(url,
archive = os.path.join(download_root, filename) archive = os.path.join(download_root, filename)
print(f'Extracting {archive} to {extract_root}') print(f'Extracting {archive} to {extract_root}')
extract_archive(archive, extract_root, remove_finished) extract_archive(archive, extract_root, remove_finished)
def open_maybe_compressed_file(path: str):
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz'
or '.xz'.
"""
if not isinstance(path, str):
return path
if path.endswith('.gz'):
import gzip
return gzip.open(path, 'rb')
if path.endswith('.xz'):
import lzma
return lzma.open(path, 'rb')
return open(path, 'rb')

View File

@ -4,7 +4,7 @@ import os.path as osp
import pickle import pickle
import tempfile import tempfile
from unittest import TestCase from unittest import TestCase
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, call, patch
import numpy as np import numpy as np
from mmengine.registry import TRANSFORMS from mmengine.registry import TRANSFORMS
@ -655,8 +655,6 @@ class TestVOC(TestBaseDataset):
cls.tmpdir.cleanup() cls.tmpdir.cleanup()
"""Temporarily disabled.
class TestMNIST(TestBaseDataset): class TestMNIST(TestBaseDataset):
DATASET_TYPE = 'MNIST' DATASET_TYPE = 'MNIST'
@ -667,25 +665,22 @@ class TestMNIST(TestBaseDataset):
tmpdir = tempfile.TemporaryDirectory() tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir cls.tmpdir = tmpdir
data_prefix = tmpdir.name data_prefix = tmpdir.name
cls.DEFAULT_ARGS = dict(data_prefix=data_prefix, pipeline=[]) cls.DEFAULT_ARGS = dict(
data_prefix=data_prefix, pipeline=[], test_mode=False)
dataset_class = DATASETS.get(cls.DATASET_TYPE) dataset_class = DATASETS.get(cls.DATASET_TYPE)
def rm_suffix(s): def rm_suffix(s):
return s[:s.rfind('.')] return s[:s.rfind('.')]
train_image_file = osp.join( train_image_file = osp.join(data_prefix,
data_prefix, rm_suffix(dataset_class.train_list[0][0]))
rm_suffix(dataset_class.resources['train_image_file'][0])) train_label_file = osp.join(data_prefix,
train_label_file = osp.join( rm_suffix(dataset_class.train_list[1][0]))
data_prefix, test_image_file = osp.join(data_prefix,
rm_suffix(dataset_class.resources['train_label_file'][0])) rm_suffix(dataset_class.test_list[0][0]))
test_image_file = osp.join( test_label_file = osp.join(data_prefix,
data_prefix, rm_suffix(dataset_class.test_list[1][0]))
rm_suffix(dataset_class.resources['test_image_file'][0]))
test_label_file = osp.join(
data_prefix,
rm_suffix(dataset_class.resources['test_label_file'][0]))
cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8) cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8)
cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8) cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8)
@ -703,22 +698,89 @@ class TestMNIST(TestBaseDataset):
with open(file, 'wb') as f: with open(file, 'wb') as f:
f.write(data) f.write(data)
def test_load_annotations(self): def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE) dataset_class = DATASETS.get(self.DATASET_TYPE)
with patch.object(dataset_class, 'download'): # Test default behavior
# Test default behavior dataset = dataset_class(**self.DEFAULT_ARGS)
dataset = dataset_class(**self.DEFAULT_ARGS) self.assertEqual(len(dataset), 1)
self.assertEqual(len(dataset), 1) self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])
data_info = dataset[0] data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img) np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label) np.testing.assert_equal(data_info['gt_label'], self.fake_label)
# Test with test_mode=True
cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
# Test automatically download
with patch(
'mmcls.datasets.mnist.download_and_extract_archive') as mock:
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'test_mode': True}
dataset = dataset_class(**cfg)
dataset.train_list = [['invalid_train_file', None]]
dataset.test_list = [['invalid_test_file', None]]
with self.assertRaisesRegex(AssertionError, 'Download failed'):
dataset.full_init()
calls = [
call(
osp.join(dataset.url_prefix, dataset.train_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.train_list[0][0],
md5=None),
call(
osp.join(dataset.url_prefix, dataset.test_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.test_list[0][0],
md5=None)
]
mock.assert_has_calls(calls)
with self.assertRaisesRegex(RuntimeError, '`download=True`'):
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'test_mode': True,
'download': False
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
dataset.full_init()
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_prefix': 'http://openmmlab/mnist'
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/mnist'):
dataset.full_init()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
repr(dataset))
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
cls.tmpdir.cleanup() cls.tmpdir.cleanup()
class FashionMNIST(TestMNIST):
DATASET_TYPE = 'FashionMNIST'
"""Temporarily disabled.
class TestCUB(TestBaseDataset): class TestCUB(TestBaseDataset):
DATASET_TYPE = 'CUB' DATASET_TYPE = 'CUB'

View File

@ -2,8 +2,12 @@
import os.path as osp import os.path as osp
import random import random
import string import string
from unittest.mock import patch
from mmcls.datasets.utils import check_integrity, rm_suffix import pytest
from mmcls.datasets.utils import (check_integrity, open_maybe_compressed_file,
rm_suffix)
def test_dataset_utils(): def test_dataset_utils():
@ -20,3 +24,16 @@ def test_dataset_utils():
test_file = osp.join(osp.dirname(__file__), '../../data/color.jpg') test_file = osp.join(osp.dirname(__file__), '../../data/color.jpg')
assert check_integrity(test_file, md5='08252e5100cb321fe74e0e12a724ce14') assert check_integrity(test_file, md5='08252e5100cb321fe74e0e12a724ce14')
assert not check_integrity(test_file, md5=2333) assert not check_integrity(test_file, md5=2333)
@pytest.mark.parametrize('method,path', [('gzip.open', 'abc.gz'),
('lzma.open', 'abc.xz'),
('builtins.open', 'abc.txt'),
(None, 1)])
def test_open_maybe_compressed_file(method, path):
if method:
with patch(method) as mock:
open_maybe_compressed_file(path)
mock.assert_called()
else:
assert open_maybe_compressed_file(path) == path