[Feature] Support ImageNet21k dataset. (#461)
* add imagnet21k * Update unit test * update imaenet21k * use slots * use slots * rename Data_item to ImageInfo * add unit tests * Update unit tests * rm some print * update unit tests * fix lint * remove default value of pipelinepull/503/head
parent
671414becb
commit
2ce5825ef1
|
@ -0,0 +1,43 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet21k'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', size=(256, -1)),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=128,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet21k/train',
|
||||
pipeline=train_pipeline,
|
||||
recursion_subdir=True),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet21k/val',
|
||||
ann_file='data/imagenet21k/meta/val.txt',
|
||||
pipeline=test_pipeline,
|
||||
recursion_subdir=True),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet21k/val',
|
||||
ann_file='data/imagenet21k/meta/val.txt',
|
||||
pipeline=test_pipeline,
|
||||
recursion_subdir=True))
|
||||
evaluation = dict(interval=1, metric='accuracy')
|
|
@ -0,0 +1,12 @@
|
|||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.8, momentum=0.9, weight_decay=5e-5)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=0,
|
||||
warmup='linear',
|
||||
warmup_iters=5,
|
||||
warmup_ratio=0.1,
|
||||
warmup_by_epoch=True)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
|
@ -0,0 +1,11 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnet50.py', '../_base_/datasets/imagenet21k_bs128.py',
|
||||
'../_base_/schedules/imagenet_bs1024_coslr.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(head=dict(num_classes=21843))
|
||||
|
||||
# runtime settings
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=90)
|
|
@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
|
|||
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
|
||||
RepeatDataset)
|
||||
from .imagenet import ImageNet
|
||||
from .imagenet21k import ImageNet21k
|
||||
from .mnist import MNIST, FashionMNIST
|
||||
from .multi_label import MultiLabelDataset
|
||||
from .samplers import DistributedSampler
|
||||
|
@ -14,5 +15,5 @@ __all__ = [
|
|||
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
|
||||
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
|
||||
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
|
||||
'ClassBalancedDataset', 'DATASETS', 'PIPELINES'
|
||||
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k'
|
||||
]
|
||||
|
|
|
@ -33,7 +33,6 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
ann_file=None,
|
||||
test_mode=False):
|
||||
super(BaseDataset, self).__init__()
|
||||
|
||||
self.ann_file = ann_file
|
||||
self.data_prefix = data_prefix
|
||||
self.test_mode = test_mode
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from mmcv.utils import scandir
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
from .imagenet import find_folders
|
||||
|
||||
|
||||
class ImageInfo():
|
||||
"""class to store image info, using slots will save memory than using
|
||||
dict."""
|
||||
__slots__ = ['path', 'gt_label']
|
||||
|
||||
def __init__(self, path, gt_label):
|
||||
self.path = path
|
||||
self.gt_label = gt_label
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ImageNet21k(BaseDataset):
|
||||
"""ImageNet21k Dataset.
|
||||
|
||||
Since the dataset ImageNet21k is extremely big, cantains 21k+ classes
|
||||
and 1.4B files. This class has improved the following points on the
|
||||
basis of the class `ImageNet`, in order to save memory usage and time
|
||||
required :
|
||||
|
||||
- Delete the samples attribute
|
||||
- using 'slots' create a Data_item tp replace dict
|
||||
- Modify setting `info` dict from function `load_annotations` to
|
||||
function `prepare_data`
|
||||
- using int instead of np.array(..., np.int64)
|
||||
|
||||
Args:
|
||||
data_prefix (str): the prefix of data path
|
||||
pipeline (list): a list of dict, where each element represents
|
||||
a operation defined in `mmcls.datasets.pipelines`
|
||||
ann_file (str | None): the annotation file. When ann_file is str,
|
||||
the subclass is expected to read from the ann_file. When ann_file
|
||||
is None, the subclass is expected to read according to data_prefix
|
||||
test_mode (bool): in train mode or test mode
|
||||
multi_label (bool): use multi label or not.
|
||||
recursion_subdir(bool): whether to use sub-directory pictures, which
|
||||
are meet the conditions in the folder under category directory.
|
||||
"""
|
||||
|
||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
|
||||
'.JPEG', '.JPG')
|
||||
CLASSES = None
|
||||
|
||||
def __init__(self,
|
||||
data_prefix,
|
||||
pipeline,
|
||||
classes=None,
|
||||
ann_file=None,
|
||||
multi_label=False,
|
||||
recursion_subdir=False,
|
||||
test_mode=False):
|
||||
self.recursion_subdir = recursion_subdir
|
||||
if multi_label:
|
||||
raise NotImplementedError('Multi_label have not be implemented.')
|
||||
self.multi_lable = multi_label
|
||||
super(ImageNet21k, self).__init__(data_prefix, pipeline, classes,
|
||||
ann_file, test_mode)
|
||||
|
||||
def prepare_data(self, idx):
|
||||
info = self.data_infos[idx]
|
||||
results = {
|
||||
'img_prefix': self.data_prefix,
|
||||
'img_info': dict(filename=info.path),
|
||||
'gt_label': np.array(info.gt_label, dtype=np.int64)
|
||||
}
|
||||
return self.pipeline(results)
|
||||
|
||||
def load_annotations(self):
|
||||
"""load dataset annotations."""
|
||||
if self.ann_file is None:
|
||||
data_infos = self._load_annotations_from_dir()
|
||||
elif isinstance(self.ann_file, str):
|
||||
data_infos = self._load_annotations_from_file()
|
||||
else:
|
||||
raise TypeError('ann_file must be a str or None')
|
||||
|
||||
if len(data_infos) == 0:
|
||||
msg = 'Found no valid file in '
|
||||
msg += f'{self.ann_file}. ' if self.ann_file \
|
||||
else f'{self.data_prefix}. '
|
||||
msg += 'Supported extensions are: ' + \
|
||||
', '.join(self.IMG_EXTENSIONS)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return data_infos
|
||||
|
||||
def _find_allowed_files(self, root, folder_name):
|
||||
"""find all the allowed files in a folder, including sub folder if
|
||||
recursion_subdir is true."""
|
||||
_dir = os.path.join(root, folder_name)
|
||||
infos_pre_class = []
|
||||
for path in scandir(_dir, self.IMG_EXTENSIONS, self.recursion_subdir):
|
||||
path = os.path.join(folder_name, path)
|
||||
item = ImageInfo(path, self.folder_to_idx[folder_name])
|
||||
infos_pre_class.append(item)
|
||||
return infos_pre_class
|
||||
|
||||
def _load_annotations_from_dir(self):
|
||||
"""load annotations from self.data_prefix directory."""
|
||||
data_infos, empty_classes = [], []
|
||||
folder_to_idx = find_folders(self.data_prefix)
|
||||
self.folder_to_idx = folder_to_idx
|
||||
root = os.path.expanduser(self.data_prefix)
|
||||
for folder_name in folder_to_idx.keys():
|
||||
infos_pre_class = self._find_allowed_files(root, folder_name)
|
||||
if len(infos_pre_class) == 0:
|
||||
empty_classes.append(folder_name)
|
||||
data_infos.extend(infos_pre_class)
|
||||
|
||||
if len(empty_classes) != 0:
|
||||
msg = 'Found no valid file for the classes ' + \
|
||||
f"{', '.join(sorted(empty_classes))} "
|
||||
msg += 'Supported extensions are: ' + \
|
||||
f"{', '.join(self.IMG_EXTENSIONS)}."
|
||||
warnings.warn(msg)
|
||||
|
||||
return data_infos
|
||||
|
||||
def _load_annotations_from_file(self):
|
||||
"""load annotations from self.ann_file."""
|
||||
data_infos = []
|
||||
with open(self.ann_file) as f:
|
||||
for line in f.readlines():
|
||||
if line == '':
|
||||
continue
|
||||
filepath, gt_label = line.strip().rsplit(' ', 1)
|
||||
info = ImageInfo(filepath, int(gt_label))
|
||||
data_infos.append(info)
|
||||
|
||||
return data_infos
|
|
@ -0,0 +1,3 @@
|
|||
a/1.JPG 0
|
||||
b/2.jpeg 1
|
||||
b/subb/2.jpeg 1
|
|
@ -6,14 +6,17 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.datasets import DATASETS, BaseDataset, MultiLabelDataset
|
||||
from mmcls.datasets import (DATASETS, BaseDataset, ImageNet21k,
|
||||
MultiLabelDataset)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'dataset_name',
|
||||
['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC'])
|
||||
@pytest.mark.parametrize('dataset_name', [
|
||||
'MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC',
|
||||
'ImageNet21k'
|
||||
])
|
||||
def test_datasets_override_default(dataset_name):
|
||||
dataset_class = DATASETS.get(dataset_name)
|
||||
load_annotations_f = dataset_class.load_annotations
|
||||
dataset_class.load_annotations = MagicMock()
|
||||
|
||||
original_classes = dataset_class.CLASSES
|
||||
|
@ -82,6 +85,8 @@ def test_datasets_override_default(dataset_name):
|
|||
assert dataset.ann_file is None
|
||||
assert dataset.CLASSES == original_classes
|
||||
|
||||
dataset_class.load_annotations = load_annotations_f
|
||||
|
||||
|
||||
@patch.multiple(MultiLabelDataset, __abstractmethods__=set())
|
||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||
|
@ -248,3 +253,43 @@ def test_dataset_evaluation():
|
|||
assert 'CR' in eval_results.keys()
|
||||
assert 'OF1' in eval_results.keys()
|
||||
assert 'CF1' not in eval_results.keys()
|
||||
|
||||
|
||||
def test_dataset_imagenet21k():
|
||||
base_dataset_cfg = dict(
|
||||
data_prefix='tests/data/dataset', pipeline=[], recursion_subdir=True)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
# multi_label have not be implemented
|
||||
dataset_cfg = base_dataset_cfg.copy()
|
||||
dataset_cfg.update({'multi_label': True})
|
||||
dataset = ImageNet21k(**dataset_cfg)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# ann_file must be a string or None
|
||||
dataset_cfg = base_dataset_cfg.copy()
|
||||
ann_file = {'path': 'tests/data/dataset/ann.txt'}
|
||||
dataset_cfg.update({'ann_file': ann_file})
|
||||
dataset = ImageNet21k(**dataset_cfg)
|
||||
|
||||
# test with recursion_subdir is True
|
||||
dataset = ImageNet21k(**base_dataset_cfg)
|
||||
assert len(dataset) == 3
|
||||
assert isinstance(dataset[0], dict)
|
||||
assert 'img_prefix' in dataset[0]
|
||||
assert 'img_info' in dataset[0]
|
||||
assert 'gt_label' in dataset[0]
|
||||
|
||||
# test with recursion_subdir is False
|
||||
dataset_cfg = base_dataset_cfg.copy()
|
||||
dataset_cfg['recursion_subdir'] = False
|
||||
dataset = ImageNet21k(**dataset_cfg)
|
||||
assert len(dataset) == 2
|
||||
assert isinstance(dataset[0], dict)
|
||||
|
||||
# test with load annotation from ann file
|
||||
dataset_cfg = base_dataset_cfg.copy()
|
||||
dataset_cfg['ann_file'] = 'tests/data/dataset/ann.txt'
|
||||
dataset = ImageNet21k(**dataset_cfg)
|
||||
assert len(dataset) == 3
|
||||
assert isinstance(dataset[0], dict)
|
||||
|
|
Loading…
Reference in New Issue