[Feature] Support ImageNet21k dataset. (#461)

* add imagenet21k

* Update unit test

* update imagenet21k

* use slots

* use slots

* rename Data_item to ImageInfo

* add unit tests

* Update unit tests

* rm some print

* update unit tests

* fix lint

* remove default value of pipeline
pull/503/head
Ezra-Yu 2021-10-28 15:22:08 +08:00 committed by GitHub
parent 671414becb
commit 2ce5825ef1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 261 additions and 6 deletions

View File

@ -0,0 +1,43 @@
# dataset settings
dataset_type = 'ImageNet21k'

# Three-value mean/std forwarded to every `Normalize` step below.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True)

# Training pipeline: random crop + horizontal flip, then tensor conversion.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label']),
]

# Evaluation pipeline: resize + center crop; only `img` is collected.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img']),
]

data = dict(
    samples_per_gpu=128,
    workers_per_gpu=2,
    # No `ann_file` here: the training split is described by `data_prefix`
    # alone.
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet21k/train',
        pipeline=train_pipeline,
        recursion_subdir=True),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet21k/val',
        ann_file='data/imagenet21k/meta/val.txt',
        pipeline=test_pipeline,
        recursion_subdir=True),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet21k/val',
        ann_file='data/imagenet21k/meta/val.txt',
        pipeline=test_pipeline,
        recursion_subdir=True))

# Evaluate after every epoch/interval unit with the accuracy metric.
evaluation = dict(interval=1, metric='accuracy')

View File

@ -0,0 +1,12 @@
# optimizer
optimizer = dict(
    type='SGD',
    lr=0.8,
    momentum=0.9,
    weight_decay=5e-5)
# Gradient clipping is disabled.
optimizer_config = dict(grad_clip=None)

# learning policy: cosine annealing down to `min_lr`, preceded by a linear
# warmup whose length (`warmup_iters`) is counted in epochs because
# `warmup_by_epoch` is set.
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=5,
    warmup_ratio=0.1,
    warmup_by_epoch=True)

runner = dict(type='EpochBasedRunner', max_epochs=100)

View File

@ -0,0 +1,11 @@
# Base configs this config composes: model, dataset, schedule and runtime.
_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/imagenet21k_bs128.py',
    '../_base_/schedules/imagenet_bs1024_coslr.py',
    '../_base_/default_runtime.py',
]

# model settings: the classification head is sized for 21843 classes.
model = dict(head=dict(num_classes=21843))

# runtime settings: this config sets its own `max_epochs` for the runner.
runner = dict(type='EpochBasedRunner', max_epochs=90)

View File

@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler
@ -14,5 +15,5 @@ __all__ = [
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'DATASETS', 'PIPELINES'
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k'
]

View File

@ -33,7 +33,6 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
ann_file=None,
test_mode=False):
super(BaseDataset, self).__init__()
self.ann_file = ann_file
self.data_prefix = data_prefix
self.test_mode = test_mode

View File

@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
import numpy as np
from mmcv.utils import scandir
from .base_dataset import BaseDataset
from .builder import DATASETS
from .imagenet import find_folders
class ImageInfo:
    """Lightweight record holding one image's path and label.

    ``__slots__`` is declared so that instances carry no per-instance
    ``__dict__``, which uses less memory than storing each sample as a dict.

    Args:
        path: value stored as the ``path`` attribute.
        gt_label: value stored as the ``gt_label`` attribute.
    """

    __slots__ = ['path', 'gt_label']

    def __init__(self, path, gt_label):
        self.path, self.gt_label = path, gt_label
@DATASETS.register_module()
class ImageNet21k(BaseDataset):
    """ImageNet21k Dataset.

    Since the dataset ImageNet21k is extremely big, containing 21k+ classes
    and roughly 14M files, this class improves on the class `ImageNet` in the
    following points, in order to save memory usage and time required:

        - Delete the samples attribute
        - Use the slots-based `ImageInfo` class to replace dict
        - Move building the `info` dict from function `load_annotations` to
          function `prepare_data`
        - Store labels as int instead of np.array(..., np.int64)

    Args:
        data_prefix (str): the prefix of data path
        pipeline (list): a list of dict, where each element represents
            a operation defined in `mmcls.datasets.pipelines`
        classes: forwarded unchanged to ``BaseDataset.__init__``.
            Defaults to None.
        ann_file (str | None): the annotation file. When ann_file is str,
            annotations are read from the ann_file. When ann_file is None,
            annotations are collected by scanning data_prefix.
        multi_label (bool): use multi label or not. Only False is supported;
            a truthy value raises ``NotImplementedError``.
        recursion_subdir (bool): whether to also collect pictures found in
            sub-directories under each category directory.
        test_mode (bool): in train mode or test mode
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                      '.JPEG', '.JPG')
    CLASSES = None

    def __init__(self,
                 data_prefix,
                 pipeline,
                 classes=None,
                 ann_file=None,
                 multi_label=False,
                 recursion_subdir=False,
                 test_mode=False):
        # Set before calling the base constructor so that the attribute is
        # already available if the base class loads annotations there.
        self.recursion_subdir = recursion_subdir
        if multi_label:
            raise NotImplementedError('Multi_label have not be implemented.')
        self.multi_label = multi_label
        # Backward-compatible alias for the historically misspelled name.
        self.multi_lable = multi_label
        super(ImageNet21k, self).__init__(data_prefix, pipeline, classes,
                                          ann_file, test_mode)

    def prepare_data(self, idx):
        """Build the result dict for sample ``idx`` and run the pipeline.

        The dict is created lazily here (instead of being stored per sample)
        and the integer label is converted to an ``np.int64`` array only at
        access time.

        Args:
            idx (int): index into ``self.data_infos``.

        Returns:
            The output of ``self.pipeline`` applied to the result dict.
        """
        info = self.data_infos[idx]
        results = {
            'img_prefix': self.data_prefix,
            'img_info': dict(filename=info.path),
            'gt_label': np.array(info.gt_label, dtype=np.int64)
        }
        return self.pipeline(results)

    def load_annotations(self):
        """Load dataset annotations.

        Returns:
            list[ImageInfo]: one entry per image, read from ``self.ann_file``
            when it is a str, or collected from ``self.data_prefix`` when it
            is None.

        Raises:
            TypeError: if ``self.ann_file`` is neither a str nor None.
            RuntimeError: if no valid file was found.
        """
        if self.ann_file is None:
            data_infos = self._load_annotations_from_dir()
        elif isinstance(self.ann_file, str):
            data_infos = self._load_annotations_from_file()
        else:
            raise TypeError('ann_file must be a str or None')

        if not data_infos:
            msg = 'Found no valid file in '
            msg += f'{self.ann_file}. ' if self.ann_file \
                else f'{self.data_prefix}. '
            msg += 'Supported extensions are: ' + \
                ', '.join(self.IMG_EXTENSIONS)
            raise RuntimeError(msg)

        return data_infos

    def _find_allowed_files(self, root, folder_name):
        """Find all the allowed files in a folder, including sub folder if
        recursion_subdir is true.

        Args:
            root (str): directory containing the category folders.
            folder_name (str): name of the category folder under ``root``;
                must be a key of ``self.folder_to_idx``.

        Returns:
            list[ImageInfo]: entries whose path is relative to ``root`` and
            whose label is ``self.folder_to_idx[folder_name]``.
        """
        _dir = os.path.join(root, folder_name)
        # The label is the same for every file of this folder; look it up
        # once instead of once per file.
        gt_label = self.folder_to_idx[folder_name]
        infos_pre_class = []
        for path in scandir(_dir, self.IMG_EXTENSIONS, self.recursion_subdir):
            path = os.path.join(folder_name, path)
            infos_pre_class.append(ImageInfo(path, gt_label))
        return infos_pre_class

    def _load_annotations_from_dir(self):
        """Load annotations from self.data_prefix directory.

        Category folders without any valid file do not raise; they are
        reported together in a single warning.

        Returns:
            list[ImageInfo]: entries for every allowed file found.
        """
        data_infos, empty_classes = [], []
        # Expand `~` once and use the same root for both folder discovery
        # and file scanning.
        root = os.path.expanduser(self.data_prefix)
        folder_to_idx = find_folders(root)
        self.folder_to_idx = folder_to_idx
        for folder_name in folder_to_idx:
            infos_pre_class = self._find_allowed_files(root, folder_name)
            if not infos_pre_class:
                empty_classes.append(folder_name)
            data_infos.extend(infos_pre_class)

        if empty_classes:
            msg = 'Found no valid file for the classes ' + \
                f"{', '.join(sorted(empty_classes))} "
            msg += 'Supported extensions are: ' + \
                f"{', '.join(self.IMG_EXTENSIONS)}."
            warnings.warn(msg)

        return data_infos

    def _load_annotations_from_file(self):
        """Load annotations from self.ann_file.

        Each non-blank line is expected to be ``<filepath> <int label>``; the
        line is split on its last space, so ``filepath`` may itself contain
        spaces. Blank (whitespace-only) lines are skipped.

        Returns:
            list[ImageInfo]: one entry per non-blank line.
        """
        data_infos = []
        with open(self.ann_file) as f:
            # Iterate the file lazily rather than materialising every line
            # with `readlines()`; annotation files can be very large.
            for line in f:
                line = line.strip()
                # Lines read from a file keep their newline, so emptiness
                # must be checked after stripping.
                if not line:
                    continue
                filepath, gt_label = line.rsplit(' ', 1)
                data_infos.append(ImageInfo(filepath, int(gt_label)))

        return data_infos

View File

View File

@ -0,0 +1,3 @@
a/1.JPG 0
b/2.jpeg 1
b/subb/2.jpeg 1

View File

View File

View File

@ -6,14 +6,17 @@ import numpy as np
import pytest
import torch
from mmcls.datasets import DATASETS, BaseDataset, MultiLabelDataset
from mmcls.datasets import (DATASETS, BaseDataset, ImageNet21k,
MultiLabelDataset)
@pytest.mark.parametrize(
'dataset_name',
['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC'])
@pytest.mark.parametrize('dataset_name', [
'MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC',
'ImageNet21k'
])
def test_datasets_override_default(dataset_name):
dataset_class = DATASETS.get(dataset_name)
load_annotations_f = dataset_class.load_annotations
dataset_class.load_annotations = MagicMock()
original_classes = dataset_class.CLASSES
@ -82,6 +85,8 @@ def test_datasets_override_default(dataset_name):
assert dataset.ann_file is None
assert dataset.CLASSES == original_classes
dataset_class.load_annotations = load_annotations_f
@patch.multiple(MultiLabelDataset, __abstractmethods__=set())
@patch.multiple(BaseDataset, __abstractmethods__=set())
@ -248,3 +253,43 @@ def test_dataset_evaluation():
assert 'CR' in eval_results.keys()
assert 'OF1' in eval_results.keys()
assert 'CF1' not in eval_results.keys()
def test_dataset_imagenet21k():
    """Check ImageNet21k argument validation and its annotation sources."""
    base_dataset_cfg = dict(
        data_prefix='tests/data/dataset', pipeline=[], recursion_subdir=True)

    # multi_label have not be implemented
    multi_label_cfg = dict(base_dataset_cfg, multi_label=True)
    with pytest.raises(NotImplementedError):
        ImageNet21k(**multi_label_cfg)

    # ann_file must be a string or None
    bad_ann_cfg = dict(
        base_dataset_cfg, ann_file={'path': 'tests/data/dataset/ann.txt'})
    with pytest.raises(TypeError):
        ImageNet21k(**bad_ann_cfg)

    # test with recursion_subdir is True
    dataset = ImageNet21k(**base_dataset_cfg)
    assert len(dataset) == 3
    assert isinstance(dataset[0], dict)
    for key in ('img_prefix', 'img_info', 'gt_label'):
        assert key in dataset[0]

    # test with recursion_subdir is False
    flat_cfg = dict(base_dataset_cfg, recursion_subdir=False)
    dataset = ImageNet21k(**flat_cfg)
    assert len(dataset) == 2
    assert isinstance(dataset[0], dict)

    # test with load annotation from ann file
    ann_cfg = dict(base_dataset_cfg, ann_file='tests/data/dataset/ann.txt')
    dataset = ImageNet21k(**ann_cfg)
    assert len(dataset) == 3
    assert isinstance(dataset[0], dict)