[Refactor] Refactor MNIST and FashionMNIST dataset.

pull/913/head
yingfhu 2022-06-10 07:24:07 +00:00 committed by mzr1996
parent ee6e585e41
commit 125b74d4ca
6 changed files with 267 additions and 118 deletions

View File

@ -15,7 +15,7 @@ preprocess_cfg = dict(mean=[33.46], std=[78.87])
pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')]
common_data_cfg = dict(
type=dataset_type, data_root='data/mnist', pipeline=pipeline)
type=dataset_type, data_prefix='data/mnist', pipeline=pipeline)
train_dataloader = dict(
batch_size=128,
@ -53,6 +53,8 @@ val_cfg = dict(interval=1) # validate every epoch
test_cfg = dict()
# runtime settings
default_scope = 'mmcls'
default_hooks = dict(
# record the time of every iteration.
timer=dict(type='IterTimerHook'),

View File

@ -1027,3 +1027,11 @@ CIFAR100_CATEGORIES = (
'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
'woman', 'worm')
MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight',
'9 - nine')
FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress',
'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
'Ankle boot')

View File

@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import codecs
import os
import os.path as osp
from typing import List, Optional
import mmengine.dist as dist
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info, master_only
from mmengine import FileClient
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
from .utils import download_and_extract_archive, rm_suffix
from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
from .utils import (download_and_extract_archive, open_maybe_compressed_file,
rm_suffix)
@DATASETS.register_module()
@ -19,58 +20,83 @@ class MNIST(BaseDataset):
This implementation is modified from
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
Args:
data_prefix (str): Prefix for data.
test_mode (bool): ``test_mode=True`` means in test phase.
It determines to use the training set or test set.
metainfo (dict, optional): Meta information for dataset, such as
categories information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix``.
Defaults to None.
download (bool): Whether to download the dataset if not exists.
Defaults to True.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
""" # noqa: E501
resource_prefix = 'http://yann.lecun.com/exdb/mnist/'
resources = {
'train_image_file':
('train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
'train_label_file':
('train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
'test_image_file':
('t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
'test_label_file':
('t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')
}
CLASSES = [
'0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five',
'6 - six', '7 - seven', '8 - eight', '9 - nine'
url_prefix = 'http://yann.lecun.com/exdb/mnist/'
# train images and labels
train_list = [
['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
]
# test images and labels
test_list = [
['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
]
METAINFO = {'classes': MNIST_CATEGORITES}
def load_annotations(self):
train_image_file = osp.join(
self.data_prefix, rm_suffix(self.resources['train_image_file'][0]))
train_label_file = osp.join(
self.data_prefix, rm_suffix(self.resources['train_label_file'][0]))
test_image_file = osp.join(
self.data_prefix, rm_suffix(self.resources['test_image_file'][0]))
test_label_file = osp.join(
self.data_prefix, rm_suffix(self.resources['test_label_file'][0]))
def __init__(self,
data_prefix: str,
test_mode: bool,
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
download: bool = True,
**kwargs):
self.download = download
super().__init__(
# The MNIST dataset doesn't need specify annotation file
ann_file='',
metainfo=metainfo,
data_root=data_root,
data_prefix=dict(root=data_prefix),
test_mode=test_mode,
**kwargs)
if not osp.exists(train_image_file) or not osp.exists(
train_label_file) or not osp.exists(
test_image_file) or not osp.exists(test_label_file):
self.download()
def load_data_list(self):
"""Load images and ground truth labels."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
_, world_size = get_dist_info()
if world_size > 1:
dist.barrier()
assert osp.exists(train_image_file) and osp.exists(
train_label_file) and osp.exists(
test_image_file) and osp.exists(test_label_file), \
'Shared storage seems unavailable. Please download dataset ' \
f'manually through {self.resource_prefix}.'
if dist.is_main_process() and not self._check_exists():
if file_client.name != 'HardDiskBackend':
raise RuntimeError(f'The dataset on {root} is not integrated, '
f'please manually handle it.')
train_set = (read_image_file(train_image_file),
read_label_file(train_label_file))
test_set = (read_image_file(test_image_file),
read_label_file(test_label_file))
if self.download:
self._download()
else:
raise RuntimeError(
f'Cannot find {self.__class__.__name__} dataset in '
f"{self.data_prefix['root']}, you can specify "
'`download=True` to download automatically.')
dist.barrier()
assert self._check_exists(), \
'Download failed or shared storage is unavailable. Please ' \
f'download the dataset manually through {self.url_prefix}.'
if not self.test_mode:
imgs, gt_labels = train_set
file_list = self.train_list
else:
imgs, gt_labels = test_set
file_list = self.test_list
# load data from SN3 files
imgs = read_image_file(
file_client.join_path(root, rm_suffix(file_list[0][0])))
gt_labels = read_label_file(
file_client.join_path(root, rm_suffix(file_list[1][0])))
data_infos = []
for img, gt_label in zip(imgs, gt_labels):
@ -79,65 +105,77 @@ class MNIST(BaseDataset):
data_infos.append(info)
return data_infos
@master_only
def download(self):
os.makedirs(self.data_prefix, exist_ok=True)
def _check_exists(self):
"""Check the exists of data files."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
# download files
for url, md5 in self.resources.values():
url = osp.join(self.resource_prefix, url)
filename = url.rpartition('/')[2]
for filename, _ in (self.train_list + self.test_list):
# get extracted filename of data
extract_filename = rm_suffix(filename)
fpath = file_client.join_path(root, extract_filename)
if not file_client.exists(fpath):
return False
return True
def _download(self):
"""Download and extract data files."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
for filename, md5 in (self.train_list + self.test_list):
url = file_client.join_path(self.url_prefix, filename)
download_and_extract_archive(
url,
download_root=self.data_prefix,
download_root=self.data_prefix['root'],
filename=filename,
md5=md5)
def extra_repr(self) -> List[str]:
"""The extra repr information of the dataset."""
body = [f"Prefix of data: \t{self.data_prefix['root']}"]
return body
@DATASETS.register_module()
class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
Dataset."""
Dataset.
resource_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' # noqa: E501
resources = {
'train_image_file':
('train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'),
'train_label_file':
('train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'),
'test_image_file':
('t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'),
'test_label_file':
('t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310')
}
CLASSES = [
'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
Args:
data_prefix (str): Prefix for data.
test_mode (bool): ``test_mode=True`` means in test phase.
It determines to use the training set or test set.
metainfo (dict, optional): Meta information for dataset, such as
categories information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix``.
Defaults to None.
download (bool): Whether to download the dataset if not exists.
Defaults to True.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
# train images and labels
train_list = [
['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
]
# test images and labels
test_list = [
['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
]
METAINFO = {'classes': FASHIONMNIST_CATEGORITES}
def get_int(b):
def get_int(b: bytes) -> int:
"""Convert bytes to int."""
return int(codecs.encode(b, 'hex'), 16)
def open_maybe_compressed_file(path):
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz'
or '.xz'.
"""
if not isinstance(path, str):
return path
if path.endswith('.gz'):
import gzip
return gzip.open(path, 'rb')
if path.endswith('.xz'):
import lzma
return lzma.open(path, 'rb')
return open(path, 'rb')
def read_sn3_pascalvincent_tensor(path, strict=True):
def read_sn3_pascalvincent_tensor(path: str,
strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
io.lsh').
@ -169,7 +207,8 @@ def read_sn3_pascalvincent_tensor(path, strict=True):
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
def read_label_file(path):
def read_label_file(path: str) -> torch.Tensor:
"""Read labels from SN3 label file."""
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
assert (x.dtype == torch.uint8)
@ -177,7 +216,8 @@ def read_label_file(path):
return x.long()
def read_image_file(path):
def read_image_file(path: str) -> torch.Tensor:
"""Read images from SN3 image file."""
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
assert (x.dtype == torch.uint8)

View File

@ -12,7 +12,10 @@ import zipfile
from mmengine.fileio.file_client import FileClient
__all__ = ['rm_suffix', 'check_integrity', 'download_and_extract_archive']
__all__ = [
'rm_suffix', 'check_integrity', 'download_and_extract_archive',
'open_maybe_compressed_file'
]
def rm_suffix(s, suffix=None):
@ -221,3 +224,20 @@ def download_and_extract_archive(url,
archive = os.path.join(download_root, filename)
print(f'Extracting {archive} to {extract_root}')
extract_archive(archive, extract_root, remove_finished)
def open_maybe_compressed_file(path: str):
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz'
or '.xz'.
"""
if not isinstance(path, str):
return path
if path.endswith('.gz'):
import gzip
return gzip.open(path, 'rb')
if path.endswith('.xz'):
import lzma
return lzma.open(path, 'rb')
return open(path, 'rb')

View File

@ -4,7 +4,7 @@ import os.path as osp
import pickle
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch
import numpy as np
from mmengine.registry import TRANSFORMS
@ -655,8 +655,6 @@ class TestVOC(TestBaseDataset):
cls.tmpdir.cleanup()
"""Temporarily disabled.
class TestMNIST(TestBaseDataset):
DATASET_TYPE = 'MNIST'
@ -667,25 +665,22 @@ class TestMNIST(TestBaseDataset):
tmpdir = tempfile.TemporaryDirectory()
cls.tmpdir = tmpdir
data_prefix = tmpdir.name
cls.DEFAULT_ARGS = dict(data_prefix=data_prefix, pipeline=[])
cls.DEFAULT_ARGS = dict(
data_prefix=data_prefix, pipeline=[], test_mode=False)
dataset_class = DATASETS.get(cls.DATASET_TYPE)
def rm_suffix(s):
return s[:s.rfind('.')]
train_image_file = osp.join(
data_prefix,
rm_suffix(dataset_class.resources['train_image_file'][0]))
train_label_file = osp.join(
data_prefix,
rm_suffix(dataset_class.resources['train_label_file'][0]))
test_image_file = osp.join(
data_prefix,
rm_suffix(dataset_class.resources['test_image_file'][0]))
test_label_file = osp.join(
data_prefix,
rm_suffix(dataset_class.resources['test_label_file'][0]))
train_image_file = osp.join(data_prefix,
rm_suffix(dataset_class.train_list[0][0]))
train_label_file = osp.join(data_prefix,
rm_suffix(dataset_class.train_list[1][0]))
test_image_file = osp.join(data_prefix,
rm_suffix(dataset_class.test_list[0][0]))
test_label_file = osp.join(data_prefix,
rm_suffix(dataset_class.test_list[1][0]))
cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8)
cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8)
@ -703,22 +698,89 @@ class TestMNIST(TestBaseDataset):
with open(file, 'wb') as f:
f.write(data)
def test_load_annotations(self):
def test_load_data_list(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
with patch.object(dataset_class, 'download'):
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 1)
# Test default behavior
dataset = dataset_class(**self.DEFAULT_ARGS)
self.assertEqual(len(dataset), 1)
self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
# Test with test_mode=True
cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 1)
data_info = dataset[0]
np.testing.assert_equal(data_info['img'], self.fake_img)
np.testing.assert_equal(data_info['gt_label'], self.fake_label)
# Test automatically download
with patch(
'mmcls.datasets.mnist.download_and_extract_archive') as mock:
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'test_mode': True}
dataset = dataset_class(**cfg)
dataset.train_list = [['invalid_train_file', None]]
dataset.test_list = [['invalid_test_file', None]]
with self.assertRaisesRegex(AssertionError, 'Download failed'):
dataset.full_init()
calls = [
call(
osp.join(dataset.url_prefix, dataset.train_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.train_list[0][0],
md5=None),
call(
osp.join(dataset.url_prefix, dataset.test_list[0][0]),
download_root=dataset.data_prefix['root'],
filename=dataset.test_list[0][0],
md5=None)
]
mock.assert_has_calls(calls)
with self.assertRaisesRegex(RuntimeError, '`download=True`'):
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'test_mode': True,
'download': False
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
dataset.full_init()
# Test different backend
cfg = {
**self.DEFAULT_ARGS, 'lazy_init': True,
'data_prefix': 'http://openmmlab/mnist'
}
dataset = dataset_class(**cfg)
dataset._check_exists = MagicMock(return_value=False)
with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/mnist'):
dataset.full_init()
def test_extra_repr(self):
dataset_class = DATASETS.get(self.DATASET_TYPE)
cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
dataset = dataset_class(**cfg)
self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
repr(dataset))
@classmethod
def tearDownClass(cls):
cls.tmpdir.cleanup()
class FashionMNIST(TestMNIST):
DATASET_TYPE = 'FashionMNIST'
"""Temporarily disabled.
class TestCUB(TestBaseDataset):
DATASET_TYPE = 'CUB'

View File

@ -2,8 +2,12 @@
import os.path as osp
import random
import string
from unittest.mock import patch
from mmcls.datasets.utils import check_integrity, rm_suffix
import pytest
from mmcls.datasets.utils import (check_integrity, open_maybe_compressed_file,
rm_suffix)
def test_dataset_utils():
@ -20,3 +24,16 @@ def test_dataset_utils():
test_file = osp.join(osp.dirname(__file__), '../../data/color.jpg')
assert check_integrity(test_file, md5='08252e5100cb321fe74e0e12a724ce14')
assert not check_integrity(test_file, md5=2333)
@pytest.mark.parametrize('method,path', [('gzip.open', 'abc.gz'),
('lzma.open', 'abc.xz'),
('builtins.open', 'abc.txt'),
(None, 1)])
def test_open_maybe_compressed_file(method, path):
if method:
with patch(method) as mock:
open_maybe_compressed_file(path)
mock.assert_called()
else:
assert open_maybe_compressed_file(path) == path