# Copyright (c) OpenMMLab. All rights reserved.
import codecs
from typing import List, Optional
from urllib.parse import urljoin

import mmengine.dist as dist
import numpy as np
import torch
from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path

from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
from .utils import (download_and_extract_archive, open_maybe_compressed_file,
                    rm_suffix)


@DATASETS.register_module()
class MNIST(BaseDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py

    Args:
        data_prefix (str): Prefix for data.
        test_mode (bool): ``test_mode=True`` means in test phase.
            It determines whether to use the training set or the test set.
        metainfo (dict, optional): Meta information for the dataset, such as
            categories information. Defaults to None.
        data_root (str): The root directory for ``data_prefix``.
            Defaults to ''.
        download (bool): Whether to download the dataset if it does not
            exist. Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
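
    Examples:
        A minimal usage sketch (the ``data/mnist`` prefix is an assumption,
        any local directory works; the files are downloaded there on first
        use):

        >>> from mmcls.datasets import MNIST
        >>> dataset = MNIST(data_prefix='data/mnist', test_mode=False)
        >>> len(dataset)
        60000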
    """  # noqa: E501

    url_prefix = 'http://yann.lecun.com/exdb/mnist/'
    # train images and labels
    train_list = [
        ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
        ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
    ]
    # test images and labels
    test_list = [
        ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
        ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
    ]
    METAINFO = {'classes': MNIST_CATEGORITES}

    def __init__(self,
                 data_prefix: str,
                 test_mode: bool,
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 download: bool = True,
                 **kwargs):
        self.download = download
        super().__init__(
            # The MNIST dataset doesn't need an annotation file.
            ann_file='',
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=dict(root=data_prefix),
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self):
        """Load images and ground truth labels."""
        root = self.data_prefix['root']
        backend = get_file_backend(root, enable_singleton=True)
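
        # Only the main process performs the (optional) download; the other
        # ranks wait at the ``dist.barrier()`` below until the extracted
        # files are in place.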
        if dist.is_main_process() and not self._check_exists():
            if not isinstance(backend, LocalBackend):
                raise RuntimeError(f'The dataset on {root} is not integrated, '
                                   'please manually handle it.')

            if self.download:
                self._download()
            else:
                raise RuntimeError(
                    f'Cannot find {self.__class__.__name__} dataset in '
                    f"{self.data_prefix['root']}, you can specify "
                    '`download=True` to download automatically.')

        dist.barrier()
        assert self._check_exists(), \
            'Download failed or shared storage is unavailable. Please ' \
            f'download the dataset manually through {self.url_prefix}.'

        if not self.test_mode:
            file_list = self.train_list
        else:
            file_list = self.test_list

        # load data from SN3 files
        imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
        gt_labels = read_label_file(
            join_path(root, rm_suffix(file_list[1][0])))

        data_infos = []
        for img, gt_label in zip(imgs, gt_labels):
            gt_label = np.array(gt_label, dtype=np.int64)
            info = {'img': img.numpy(), 'gt_label': gt_label}
            data_infos.append(info)
        return data_infos

    def _check_exists(self):
        """Check the existence of data files."""
        root = self.data_prefix['root']

        for filename, _ in (self.train_list + self.test_list):
            # get extracted filename of data
            extract_filename = rm_suffix(filename)
            fpath = join_path(root, extract_filename)
            if not exists(fpath):
                return False
        return True

    def _download(self):
        """Download and extract data files."""
        root = self.data_prefix['root']

        # Download each archive, verify its md5 and extract it in place.
        for filename, md5 in (self.train_list + self.test_list):
            url = urljoin(self.url_prefix, filename)
            download_and_extract_archive(
                url, download_root=root, filename=filename, md5=md5)

    def extra_repr(self) -> List[str]:
        """The extra repr information of the dataset."""
        body = [f"Prefix of data: \t{self.data_prefix['root']}"]
        return body


@DATASETS.register_module()
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
    Dataset.

    Args:
        data_prefix (str): Prefix for data.
        test_mode (bool): ``test_mode=True`` means in test phase.
            It determines whether to use the training set or the test set.
        metainfo (dict, optional): Meta information for the dataset, such as
            categories information. Defaults to None.
        data_root (str): The root directory for ``data_prefix``.
            Defaults to ''.
        download (bool): Whether to download the dataset if it does not
            exist. Defaults to True.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
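
    Examples:
        A minimal usage sketch (the ``data/fashion-mnist`` prefix is an
        assumption, any local directory works):

        >>> from mmcls.datasets import FashionMNIST
        >>> dataset = FashionMNIST(data_prefix='data/fashion-mnist',
        ...                        test_mode=False)
        >>> len(dataset)
        60000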
    """

    url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    # train images and labels
    train_list = [
        ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
        ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
    ]
    # test images and labels
    test_list = [
        ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
        ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
    ]
    METAINFO = {'classes': FASHIONMNIST_CATEGORITES}


def get_int(b: bytes) -> int:
    """Convert bytes to int."""
    return int(codecs.encode(b, 'hex'), 16)


def read_sn3_pascalvincent_tensor(path: str,
                                  strict: bool = True) -> torch.Tensor:
    """Read an SN3 file in "Pascal Vincent" format (Lush file
    'libidx/idx-io.lsh').

    Argument may be a filename, compressed filename, or file object.
    """
    # typemap
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')
        }
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    m = read_sn3_pascalvincent_tensor.typemap[ty]
    s = [get_int(data[4 * (i + 1):4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


def read_label_file(path: str) -> torch.Tensor:
    """Read labels from SN3 label file."""
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 1
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    """Read images from SN3 image file."""
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 3
    return x
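

if __name__ == '__main__':
    # A minimal smoke test, not part of the upstream module. The
    # ``data/mnist`` prefix is an assumption; any writable directory works.
    dataset = MNIST(data_prefix='data/mnist', test_mode=False, download=True)
    print(dataset)
    print('number of samples:', len(dataset))
    print('first image shape:', dataset[0]['img'].shape)  # (28, 28)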