[Refactor] Refactor the MNIST and FashionMNIST datasets.
parent ee6e585e41
commit 125b74d4ca
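
This refactor ports MNIST and FashionMNIST onto the new `BaseDataset` interface: class names move into `categories.py` as `METAINFO`, the download resources split into `train_list`/`test_list`, and file access goes through `FileClient` so non-local backends fail with a clear error. As a rough usage sketch (not part of this diff, and assuming the archives can be downloaded into `data/mnist`):

```python
# Hedged sketch; the constructor arguments follow the docstring added below.
from mmcls.datasets import MNIST

train_set = MNIST(data_prefix='data/mnist', test_mode=False, download=True)
test_set = MNIST(data_prefix='data/mnist', test_mode=True)

info = train_set[0]  # with an empty pipeline: a dict with 'img' and 'gt_label'
print(len(train_set), info['gt_label'])
```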
@@ -15,7 +15,7 @@ preprocess_cfg = dict(mean=[33.46], std=[78.87])
 pipeline = [dict(type='Resize', scale=32), dict(type='PackClsInputs')]

 common_data_cfg = dict(
-    type=dataset_type, data_root='data/mnist', pipeline=pipeline)
+    type=dataset_type, data_prefix='data/mnist', pipeline=pipeline)

 train_dataloader = dict(
     batch_size=128,
@@ -53,6 +53,8 @@ val_cfg = dict(interval=1)  # validate every epoch
 test_cfg = dict()

 # runtime settings
+default_scope = 'mmcls'
+
 default_hooks = dict(
     # record the time of every iteration.
     timer=dict(type='IterTimerHook'),
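
The `common_data_cfg` dict above is shared between the splits; only `batch_size=128` is visible in the hunk, so the dataloader lines below are an illustrative sketch of the usual expansion pattern, not the file's actual contents:

```python
common_data_cfg = dict(
    type='MNIST', data_prefix='data/mnist',
    pipeline=[dict(type='Resize', scale=32), dict(type='PackClsInputs')])

# Hypothetical expansion: test_mode selects the train or the test split.
train_dataloader = dict(
    batch_size=128, dataset=dict(**common_data_cfg, test_mode=False))
val_dataloader = dict(
    batch_size=128, dataset=dict(**common_data_cfg, test_mode=True))
```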
@@ -1027,3 +1027,11 @@ CIFAR100_CATEGORIES = (
     'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
     'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
     'woman', 'worm')
+
+MNIST_CATEGORITES = ('0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
+                     '5 - five', '6 - six', '7 - seven', '8 - eight',
+                     '9 - nine')
+
+FASHIONMNIST_CATEGORITES = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress',
+                            'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag',
+                            'Ankle boot')

@@ -1,16 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import codecs
-import os
 import os.path as osp
+from typing import List, Optional

+import mmengine.dist as dist
 import numpy as np
 import torch
-import torch.distributed as dist
-from mmcv.runner import get_dist_info, master_only
+from mmengine import FileClient

 from mmcls.registry import DATASETS
 from .base_dataset import BaseDataset
-from .utils import download_and_extract_archive, rm_suffix
+from .categories import FASHIONMNIST_CATEGORITES, MNIST_CATEGORITES
+from .utils import (download_and_extract_archive, open_maybe_compressed_file,
+                    rm_suffix)


 @DATASETS.register_module()
@@ -19,58 +20,83 @@ class MNIST(BaseDataset):

     This implementation is modified from
     https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
+
+    Args:
+        data_prefix (str): Prefix for data.
+        test_mode (bool): ``test_mode=True`` means in test phase.
+            It determines to use the training set or test set.
+        metainfo (dict, optional): Meta information for dataset, such as
+            categories information. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix``.
+            Defaults to None.
+        download (bool): Whether to download the dataset if not exists.
+            Defaults to True.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
     """  # noqa: E501

-    resource_prefix = 'http://yann.lecun.com/exdb/mnist/'
-    resources = {
-        'train_image_file':
-        ('train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
-        'train_label_file':
-        ('train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
-        'test_image_file':
-        ('t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
-        'test_label_file':
-        ('t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')
-    }
-
-    CLASSES = [
-        '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five',
-        '6 - six', '7 - seven', '8 - eight', '9 - nine'
-    ]
+    url_prefix = 'http://yann.lecun.com/exdb/mnist/'
+    # train images and labels
+    train_list = [
+        ['train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'],
+        ['train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'],
+    ]
+    # test images and labels
+    test_list = [
+        ['t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'],
+        ['t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'],
+    ]
+    METAINFO = {'classes': MNIST_CATEGORITES}

-    def load_annotations(self):
-        train_image_file = osp.join(
-            self.data_prefix, rm_suffix(self.resources['train_image_file'][0]))
-        train_label_file = osp.join(
-            self.data_prefix, rm_suffix(self.resources['train_label_file'][0]))
-        test_image_file = osp.join(
-            self.data_prefix, rm_suffix(self.resources['test_image_file'][0]))
-        test_label_file = osp.join(
-            self.data_prefix, rm_suffix(self.resources['test_label_file'][0]))
+    def __init__(self,
+                 data_prefix: str,
+                 test_mode: bool,
+                 metainfo: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 download: bool = True,
+                 **kwargs):
+        self.download = download
+        super().__init__(
+            # The MNIST dataset doesn't need specify annotation file
+            ann_file='',
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=dict(root=data_prefix),
+            test_mode=test_mode,
+            **kwargs)

-        if not osp.exists(train_image_file) or not osp.exists(
-                train_label_file) or not osp.exists(
-                    test_image_file) or not osp.exists(test_label_file):
-            self.download()
+    def load_data_list(self):
+        """Load images and ground truth labels."""
+        root = self.data_prefix['root']
+        file_client = FileClient.infer_client(uri=root)

-        _, world_size = get_dist_info()
-        if world_size > 1:
-            dist.barrier()
-            assert osp.exists(train_image_file) and osp.exists(
-                train_label_file) and osp.exists(
-                    test_image_file) and osp.exists(test_label_file), \
-                'Shared storage seems unavailable. Please download dataset ' \
-                f'manually through {self.resource_prefix}.'
+        if dist.is_main_process() and not self._check_exists():
+            if file_client.name != 'HardDiskBackend':
+                raise RuntimeError(f'The dataset on {root} is not integrated, '
+                                   f'please manually handle it.')

-        train_set = (read_image_file(train_image_file),
-                     read_label_file(train_label_file))
-        test_set = (read_image_file(test_image_file),
-                    read_label_file(test_label_file))
+            if self.download:
+                self._download()
+            else:
+                raise RuntimeError(
+                    f'Cannot find {self.__class__.__name__} dataset in '
+                    f"{self.data_prefix['root']}, you can specify "
+                    '`download=True` to download automatically.')
+
+        dist.barrier()
+        assert self._check_exists(), \
+            'Download failed or shared storage is unavailable. Please ' \
+            f'download the dataset manually through {self.url_prefix}.'

         if not self.test_mode:
-            imgs, gt_labels = train_set
+            file_list = self.train_list
         else:
-            imgs, gt_labels = test_set
+            file_list = self.test_list
+
+        # load data from SN3 files
+        imgs = read_image_file(
+            file_client.join_path(root, rm_suffix(file_list[0][0])))
+        gt_labels = read_label_file(
+            file_client.join_path(root, rm_suffix(file_list[1][0])))

         data_infos = []
         for img, gt_label in zip(imgs, gt_labels):
@@ -79,65 +105,77 @@ class MNIST(BaseDataset):
             data_infos.append(info)
         return data_infos

-    @master_only
-    def download(self):
-        os.makedirs(self.data_prefix, exist_ok=True)
+    def _check_exists(self):
+        """Check the exists of data files."""
+        root = self.data_prefix['root']
+        file_client = FileClient.infer_client(uri=root)

-        # download files
-        for url, md5 in self.resources.values():
-            url = osp.join(self.resource_prefix, url)
-            filename = url.rpartition('/')[2]
+        for filename, _ in (self.train_list + self.test_list):
+            # get extracted filename of data
+            extract_filename = rm_suffix(filename)
+            fpath = file_client.join_path(root, extract_filename)
+            if not file_client.exists(fpath):
+                return False
+        return True
+
+    def _download(self):
+        """Download and extract data files."""
+        root = self.data_prefix['root']
+        file_client = FileClient.infer_client(uri=root)
+
+        for filename, md5 in (self.train_list + self.test_list):
+            url = file_client.join_path(self.url_prefix, filename)
             download_and_extract_archive(
                 url,
-                download_root=self.data_prefix,
+                download_root=self.data_prefix['root'],
                 filename=filename,
                 md5=md5)

+    def extra_repr(self) -> List[str]:
+        """The extra repr information of the dataset."""
+        body = [f"Prefix of data: \t{self.data_prefix['root']}"]
+        return body
+

 @DATASETS.register_module()
 class FashionMNIST(MNIST):
     """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_
-    Dataset."""
+    Dataset.

-    resource_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'  # noqa: E501
-    resources = {
-        'train_image_file':
-        ('train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'),
-        'train_label_file':
-        ('train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'),
-        'test_image_file':
-        ('t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'),
-        'test_label_file':
-        ('t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310')
-    }
-    CLASSES = [
-        'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
-        'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
-    ]
+    Args:
+        data_prefix (str): Prefix for data.
+        test_mode (bool): ``test_mode=True`` means in test phase.
+            It determines to use the training set or test set.
+        metainfo (dict, optional): Meta information for dataset, such as
+            categories information. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix``.
+            Defaults to None.
+        download (bool): Whether to download the dataset if not exists.
+            Defaults to True.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
+
+    url_prefix = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
+    # train images and labels
+    train_list = [
+        ['train-images-idx3-ubyte.gz', '8d4fb7e6c68d591d4c3dfef9ec88bf0d'],
+        ['train-labels-idx1-ubyte.gz', '25c81989df183df01b3e8a0aad5dffbe'],
+    ]
+    # test images and labels
+    test_list = [
+        ['t10k-images-idx3-ubyte.gz', 'bef4ecab320f06d8554ea6380940ec79'],
+        ['t10k-labels-idx1-ubyte.gz', 'bb300cfdad3c16e7a12a480ee83cd310'],
+    ]
+    METAINFO = {'classes': FASHIONMNIST_CATEGORITES}


-def get_int(b):
+def get_int(b: bytes) -> int:
+    """Convert bytes to int."""
     return int(codecs.encode(b, 'hex'), 16)


-def open_maybe_compressed_file(path):
-    """Return a file object that possibly decompresses 'path' on the fly.
-
-    Decompression occurs when argument `path` is a string and ends with '.gz'
-    or '.xz'.
-    """
-    if not isinstance(path, str):
-        return path
-    if path.endswith('.gz'):
-        import gzip
-        return gzip.open(path, 'rb')
-    if path.endswith('.xz'):
-        import lzma
-        return lzma.open(path, 'rb')
-    return open(path, 'rb')
-
-
-def read_sn3_pascalvincent_tensor(path, strict=True):
+def read_sn3_pascalvincent_tensor(path: str,
+                                  strict: bool = True) -> torch.Tensor:
     """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-
     io.lsh').

@@ -169,7 +207,8 @@ def read_sn3_pascalvincent_tensor(path, strict=True):
     return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


-def read_label_file(path):
+def read_label_file(path: str) -> torch.Tensor:
+    """Read labels from SN3 label file."""
     with open(path, 'rb') as f:
         x = read_sn3_pascalvincent_tensor(f, strict=False)
     assert (x.dtype == torch.uint8)
@@ -177,7 +216,8 @@ def read_label_file(path):
     return x.long()


-def read_image_file(path):
+def read_image_file(path: str) -> torch.Tensor:
+    """Read images from SN3 image file."""
     with open(path, 'rb') as f:
         x = read_sn3_pascalvincent_tensor(f, strict=False)
     assert (x.dtype == torch.uint8)
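
The readers above parse the IDX ("SN3") binary layout: a 4-byte magic whose low byte is the number of dimensions and whose second byte is the dtype code, then one big-endian int32 per dimension, then the raw data. A standalone sketch of that header arithmetic (standard library plus numpy only, no mmcls imports):

```python
import struct

import numpy as np

# Build a fake 3-label IDX file in memory: magic 0x00000801 means
# dtype code 8 (uint8) and 1 dimension.
payload = struct.pack('>BBBB', 0, 0, 8, 1)  # magic bytes
payload += struct.pack('>i', 3)             # size of dimension 0
payload += bytes([7, 2, 1])                 # three uint8 labels

magic = int.from_bytes(payload[:4], 'big')
ndim = magic % 256              # same trick as read_sn3_pascalvincent_tensor
dtype_code = (magic >> 8) % 256
count = int.from_bytes(payload[4:8], 'big')
labels = np.frombuffer(payload, dtype=np.uint8, offset=8)
assert (ndim, dtype_code, count) == (1, 8, 3) and labels.tolist() == [7, 2, 1]
```
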
@@ -12,7 +12,10 @@ import zipfile

 from mmengine.fileio.file_client import FileClient

-__all__ = ['rm_suffix', 'check_integrity', 'download_and_extract_archive']
+__all__ = [
+    'rm_suffix', 'check_integrity', 'download_and_extract_archive',
+    'open_maybe_compressed_file'
+]


 def rm_suffix(s, suffix=None):
@@ -221,3 +224,20 @@ def download_and_extract_archive(url,
     archive = os.path.join(download_root, filename)
     print(f'Extracting {archive} to {extract_root}')
     extract_archive(archive, extract_root, remove_finished)
+
+
+def open_maybe_compressed_file(path: str):
+    """Return a file object that possibly decompresses 'path' on the fly.
+
+    Decompression occurs when argument `path` is a string and ends with '.gz'
+    or '.xz'.
+    """
+    if not isinstance(path, str):
+        return path
+    if path.endswith('.gz'):
+        import gzip
+        return gzip.open(path, 'rb')
+    if path.endswith('.xz'):
+        import lzma
+        return lzma.open(path, 'rb')
+    return open(path, 'rb')
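
With `open_maybe_compressed_file` now living in `datasets/utils.py` and exported via `__all__`, callers can treat raw and compressed copies of a file uniformly. A small hedged usage sketch:

```python
import gzip
import os.path as osp
import tempfile

from mmcls.datasets.utils import open_maybe_compressed_file

with tempfile.TemporaryDirectory() as tmpdir:
    raw, gz = osp.join(tmpdir, 'a.txt'), osp.join(tmpdir, 'a.txt.gz')
    with open(raw, 'wb') as f:
        f.write(b'hello')
    with gzip.open(gz, 'wb') as f:
        f.write(b'hello')
    for path in (raw, gz):
        with open_maybe_compressed_file(path) as f:
            assert f.read() == b'hello'  # the suffix decides decompression
```
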
@@ -4,7 +4,7 @@ import os.path as osp
 import pickle
 import tempfile
 from unittest import TestCase
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, call, patch

 import numpy as np
 from mmengine.registry import TRANSFORMS
@@ -655,8 +655,6 @@ class TestVOC(TestBaseDataset):
         cls.tmpdir.cleanup()


-"""Temporarily disabled.
-
 class TestMNIST(TestBaseDataset):
     DATASET_TYPE = 'MNIST'

@@ -667,25 +665,22 @@ class TestMNIST(TestBaseDataset):
         tmpdir = tempfile.TemporaryDirectory()
         cls.tmpdir = tmpdir
         data_prefix = tmpdir.name
-        cls.DEFAULT_ARGS = dict(data_prefix=data_prefix, pipeline=[])
+        cls.DEFAULT_ARGS = dict(
+            data_prefix=data_prefix, pipeline=[], test_mode=False)

         dataset_class = DATASETS.get(cls.DATASET_TYPE)

         def rm_suffix(s):
             return s[:s.rfind('.')]

-        train_image_file = osp.join(
-            data_prefix,
-            rm_suffix(dataset_class.resources['train_image_file'][0]))
-        train_label_file = osp.join(
-            data_prefix,
-            rm_suffix(dataset_class.resources['train_label_file'][0]))
-        test_image_file = osp.join(
-            data_prefix,
-            rm_suffix(dataset_class.resources['test_image_file'][0]))
-        test_label_file = osp.join(
-            data_prefix,
-            rm_suffix(dataset_class.resources['test_label_file'][0]))
+        train_image_file = osp.join(data_prefix,
+                                    rm_suffix(dataset_class.train_list[0][0]))
+        train_label_file = osp.join(data_prefix,
+                                    rm_suffix(dataset_class.train_list[1][0]))
+        test_image_file = osp.join(data_prefix,
+                                   rm_suffix(dataset_class.test_list[0][0]))
+        test_label_file = osp.join(data_prefix,
+                                   rm_suffix(dataset_class.test_list[1][0]))

         cls.fake_img = np.random.randint(0, 255, size=(28, 28), dtype=np.uint8)
         cls.fake_label = np.random.randint(0, 10, size=(1, ), dtype=np.uint8)
@@ -703,22 +698,89 @@ class TestMNIST(TestBaseDataset):
             with open(file, 'wb') as f:
                 f.write(data)

-    def test_load_annotations(self):
+    def test_load_data_list(self):
         dataset_class = DATASETS.get(self.DATASET_TYPE)

-        with patch.object(dataset_class, 'download'):
-            # Test default behavior
-            dataset = dataset_class(**self.DEFAULT_ARGS)
-            self.assertEqual(len(dataset), 1)
+        # Test default behavior
+        dataset = dataset_class(**self.DEFAULT_ARGS)
+        self.assertEqual(len(dataset), 1)
+        self.assertEqual(dataset.CLASSES, dataset_class.METAINFO['classes'])

-            data_info = dataset[0]
-            np.testing.assert_equal(data_info['img'], self.fake_img)
-            np.testing.assert_equal(data_info['gt_label'], self.fake_label)
+        data_info = dataset[0]
+        np.testing.assert_equal(data_info['img'], self.fake_img)
+        np.testing.assert_equal(data_info['gt_label'], self.fake_label)
+
+        # Test with test_mode=True
+        cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
+        dataset = dataset_class(**cfg)
+        self.assertEqual(len(dataset), 1)
+
+        data_info = dataset[0]
+        np.testing.assert_equal(data_info['img'], self.fake_img)
+        np.testing.assert_equal(data_info['gt_label'], self.fake_label)
+
+        # Test automatically download
+        with patch(
+                'mmcls.datasets.mnist.download_and_extract_archive') as mock:
+            cfg = {**self.DEFAULT_ARGS, 'lazy_init': True, 'test_mode': True}
+            dataset = dataset_class(**cfg)
+            dataset.train_list = [['invalid_train_file', None]]
+            dataset.test_list = [['invalid_test_file', None]]
+            with self.assertRaisesRegex(AssertionError, 'Download failed'):
+                dataset.full_init()
+            calls = [
+                call(
+                    osp.join(dataset.url_prefix, dataset.train_list[0][0]),
+                    download_root=dataset.data_prefix['root'],
+                    filename=dataset.train_list[0][0],
+                    md5=None),
+                call(
+                    osp.join(dataset.url_prefix, dataset.test_list[0][0]),
+                    download_root=dataset.data_prefix['root'],
+                    filename=dataset.test_list[0][0],
+                    md5=None)
+            ]
+            mock.assert_has_calls(calls)
+
+        with self.assertRaisesRegex(RuntimeError, '`download=True`'):
+            cfg = {
+                **self.DEFAULT_ARGS, 'lazy_init': True,
+                'test_mode': True,
+                'download': False
+            }
+            dataset = dataset_class(**cfg)
+            dataset._check_exists = MagicMock(return_value=False)
+            dataset.full_init()
+
+        # Test different backend
+        cfg = {
+            **self.DEFAULT_ARGS, 'lazy_init': True,
+            'data_prefix': 'http://openmmlab/mnist'
+        }
+        dataset = dataset_class(**cfg)
+        dataset._check_exists = MagicMock(return_value=False)
+        with self.assertRaisesRegex(RuntimeError, 'http://openmmlab/mnist'):
+            dataset.full_init()
+
+    def test_extra_repr(self):
+        dataset_class = DATASETS.get(self.DATASET_TYPE)
+        cfg = {**self.DEFAULT_ARGS, 'lazy_init': True}
+        dataset = dataset_class(**cfg)
+
+        self.assertIn(f'Prefix of data: \t{dataset.data_prefix["root"]}',
+                      repr(dataset))

     @classmethod
     def tearDownClass(cls):
         cls.tmpdir.cleanup()


 class FashionMNIST(TestMNIST):
     DATASET_TYPE = 'FashionMNIST'


 """Temporarily disabled.

 class TestCUB(TestBaseDataset):
     DATASET_TYPE = 'CUB'

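The download tests above lean on the lazy-init behavior of `BaseDataset`: with `lazy_init=True`, construction defers `load_data_list` until `full_init()`, which leaves a window to patch attributes before any file I/O happens. A hedged sketch of that pattern (behavior inferred from the tests in this diff, not documented here):

```python
# Hypothetical use of the lazy-init window; _check_exists is an internal
# method of the refactored class, mocked only for illustration.
from unittest.mock import MagicMock

from mmcls.datasets import MNIST

dataset = MNIST(data_prefix='data/mnist', test_mode=True, lazy_init=True)
dataset._check_exists = MagicMock(return_value=True)  # skip the file check
# dataset.full_init()  # would now run load_data_list with the patched check
```
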
@@ -2,8 +2,12 @@
 import os.path as osp
 import random
 import string
+from unittest.mock import patch
+
+import pytest

-from mmcls.datasets.utils import check_integrity, rm_suffix
+from mmcls.datasets.utils import (check_integrity, open_maybe_compressed_file,
+                                  rm_suffix)


 def test_dataset_utils():
@@ -20,3 +24,16 @@ def test_dataset_utils():
     test_file = osp.join(osp.dirname(__file__), '../../data/color.jpg')
     assert check_integrity(test_file, md5='08252e5100cb321fe74e0e12a724ce14')
     assert not check_integrity(test_file, md5=2333)
+
+
+@pytest.mark.parametrize('method,path', [('gzip.open', 'abc.gz'),
+                                         ('lzma.open', 'abc.xz'),
+                                         ('builtins.open', 'abc.txt'),
+                                         (None, 1)])
+def test_open_maybe_compressed_file(method, path):
+    if method:
+        with patch(method) as mock:
+            open_maybe_compressed_file(path)
+            mock.assert_called()
+    else:
+        assert open_maybe_compressed_file(path) == path