[Feature] Support ImageNet21k dataset. (#461)
* add imagenet21k
* Update unit test
* update imagenet21k
* use slots
* use slots
* rename Data_item to ImageInfo
* add unit tests
* Update unit tests
* rm some print
* update unit tests
* fix lint
* remove default value of pipeline

parent 671414becb
commit 2ce5825ef1
configs/_base_/datasets/imagenet21k_bs128.py
@@ -0,0 +1,43 @@
# dataset settings
dataset_type = 'ImageNet21k'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=128,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet21k/train',
        pipeline=train_pipeline,
        recursion_subdir=True),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet21k/val',
        ann_file='data/imagenet21k/meta/val.txt',
        pipeline=test_pipeline,
        recursion_subdir=True),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet21k/val',
        ann_file='data/imagenet21k/meta/val.txt',
        pipeline=test_pipeline,
        recursion_subdir=True))
evaluation = dict(interval=1, metric='accuracy')
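As a hedged usage sketch (not part of the diff), a dataset config like this is typically consumed through mmcv's `Config` loader and mmcls's `build_dataset`; the file path below is an assumption that matches the `_base_` reference used later in this PR:

from mmcv import Config
from mmcls.datasets import build_dataset

cfg = Config.fromfile('configs/_base_/datasets/imagenet21k_bs128.py')
# builds an ImageNet21k instance wired up with the train_pipeline above
train_set = build_dataset(cfg.data.train)
print(len(train_set))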
configs/_base_/schedules/imagenet_bs1024_coslr.py
@@ -0,0 +1,12 @@
# optimizer
optimizer = dict(type='SGD', lr=0.8, momentum=0.9, weight_decay=5e-5)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=5,
    warmup_ratio=0.1,
    warmup_by_epoch=True)
runner = dict(type='EpochBasedRunner', max_epochs=100)
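For intuition, here is a hedged standalone sketch (not from the PR) of the per-epoch learning rate this schedule yields: linear epoch-based warmup into a cosine decay toward `min_lr`. mmcv's hook computes the cosine over the whole run, so treat this as an approximation:

import math

base_lr, min_lr = 0.8, 0.0          # from the optimizer and lr_config above
warmup_epochs, warmup_ratio = 5, 0.1
max_epochs = 100

def lr_at(epoch):
    """Approximate LR at the start of `epoch` (0-indexed)."""
    if epoch < warmup_epochs:
        # linear warmup from warmup_ratio * base_lr up to base_lr
        frac = epoch / warmup_epochs
        return base_lr * (warmup_ratio + (1 - warmup_ratio) * frac)
    # cosine annealing from base_lr down to min_lr over the full run
    progress = epoch / max_epochs
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * progress)) / 2

print([round(lr_at(e), 4) for e in (0, 4, 5, 50, 99)])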
@@ -0,0 +1,11 @@
_base_ = [
    '../_base_/models/resnet50.py', '../_base_/datasets/imagenet21k_bs128.py',
    '../_base_/schedules/imagenet_bs1024_coslr.py',
    '../_base_/default_runtime.py'
]

# model settings
model = dict(head=dict(num_classes=21843))

# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=90)
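A hedged sketch of what the `_base_` inheritance resolves to, assuming mmcv's recursive config merge. The config's on-disk path is not shown in the diff, so the one below is hypothetical:

from mmcv import Config

# hypothetical path for the config defined above
cfg = Config.fromfile('configs/resnet/resnet50_imagenet21k.py')
print(cfg.model.head.num_classes)  # 21843, overriding the base ResNet-50 head
print(cfg.runner.max_epochs)       # 90, overriding the schedule's 100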
mmcls/datasets/__init__.py
@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
 from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                                RepeatDataset)
 from .imagenet import ImageNet
+from .imagenet21k import ImageNet21k
 from .mnist import MNIST, FashionMNIST
 from .multi_label import MultiLabelDataset
 from .samplers import DistributedSampler
@@ -14,5 +15,5 @@ __all__ = [
     'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
     'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
     'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
-    'ClassBalancedDataset', 'DATASETS', 'PIPELINES'
+    'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k'
 ]
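A hedged sketch of what the export and registration give callers: once the class is exported here and registered via `@DATASETS.register_module()` in the new module below, it is reachable both by direct import and through the DATASETS registry that configs use:

from mmcls.datasets import DATASETS, ImageNet21k

# the registry entry is what lets a config instantiate the class
# via type='ImageNet21k'
assert DATASETS.get('ImageNet21k') is ImageNet21k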
mmcls/datasets/base_dataset.py
@@ -33,7 +33,6 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
                  ann_file=None,
                  test_mode=False):
         super(BaseDataset, self).__init__()
 
         self.ann_file = ann_file
         self.data_prefix = data_prefix
         self.test_mode = test_mode
mmcls/datasets/imagenet21k.py
@@ -0,0 +1,141 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings

import numpy as np
from mmcv.utils import scandir

from .base_dataset import BaseDataset
from .builder import DATASETS
from .imagenet import find_folders


class ImageInfo:
    """Class to store image info; using __slots__ saves memory compared
    with a dict."""

    __slots__ = ['path', 'gt_label']

    def __init__(self, path, gt_label):
        self.path = path
        self.gt_label = gt_label


@DATASETS.register_module()
class ImageNet21k(BaseDataset):
    """ImageNet21k Dataset.

    The ImageNet21k dataset is extremely large, containing 21k+ classes
    and 14M+ files. To save memory and loading time, this class improves
    on the ``ImageNet`` class in the following ways:

    - Delete the ``samples`` attribute.
    - Use ``__slots__``-based ``ImageInfo`` objects instead of dicts.
    - Move building the ``info`` dict from ``load_annotations`` to
      ``prepare_data``.
    - Store labels as plain ``int`` instead of ``np.array(..., np.int64)``.

    Args:
        data_prefix (str): The prefix of the data path.
        pipeline (list): A list of dicts, where each element represents
            an operation defined in ``mmcls.datasets.pipelines``.
        ann_file (str | None): The annotation file. When ``ann_file`` is a
            str, the subclass reads from it; when it is None, the subclass
            scans the directories under ``data_prefix`` instead.
        test_mode (bool): Whether in train mode or test mode.
        multi_label (bool): Whether to use multi-label data.
        recursion_subdir (bool): Whether to also collect eligible images
            from sub-directories of each category directory.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                      '.JPEG', '.JPG')
    CLASSES = None

    def __init__(self,
                 data_prefix,
                 pipeline,
                 classes=None,
                 ann_file=None,
                 multi_label=False,
                 recursion_subdir=False,
                 test_mode=False):
        self.recursion_subdir = recursion_subdir
        if multi_label:
            raise NotImplementedError(
                'Multi-label support has not been implemented.')
        self.multi_label = multi_label
        super(ImageNet21k, self).__init__(data_prefix, pipeline, classes,
                                          ann_file, test_mode)

    def prepare_data(self, idx):
        info = self.data_infos[idx]
        results = {
            'img_prefix': self.data_prefix,
            'img_info': dict(filename=info.path),
            'gt_label': np.array(info.gt_label, dtype=np.int64)
        }
        return self.pipeline(results)

    def load_annotations(self):
        """Load dataset annotations."""
        if self.ann_file is None:
            data_infos = self._load_annotations_from_dir()
        elif isinstance(self.ann_file, str):
            data_infos = self._load_annotations_from_file()
        else:
            raise TypeError('ann_file must be a str or None')

        if len(data_infos) == 0:
            msg = 'Found no valid file in '
            msg += f'{self.ann_file}. ' if self.ann_file \
                else f'{self.data_prefix}. '
            msg += 'Supported extensions are: ' + \
                ', '.join(self.IMG_EXTENSIONS)
            raise RuntimeError(msg)

        return data_infos

    def _find_allowed_files(self, root, folder_name):
        """Find all allowed files in a folder, including its sub-folders
        if ``recursion_subdir`` is True."""
        _dir = os.path.join(root, folder_name)
        infos_per_class = []
        for path in scandir(_dir, self.IMG_EXTENSIONS, self.recursion_subdir):
            path = os.path.join(folder_name, path)
            item = ImageInfo(path, self.folder_to_idx[folder_name])
            infos_per_class.append(item)
        return infos_per_class

    def _load_annotations_from_dir(self):
        """Load annotations by scanning the ``self.data_prefix`` directory."""
        data_infos, empty_classes = [], []
        folder_to_idx = find_folders(self.data_prefix)
        self.folder_to_idx = folder_to_idx
        root = os.path.expanduser(self.data_prefix)
        for folder_name in folder_to_idx.keys():
            infos_per_class = self._find_allowed_files(root, folder_name)
            if len(infos_per_class) == 0:
                empty_classes.append(folder_name)
            data_infos.extend(infos_per_class)

        if len(empty_classes) != 0:
            msg = 'Found no valid file for the classes ' + \
                f"{', '.join(sorted(empty_classes))} "
            msg += 'Supported extensions are: ' + \
                f"{', '.join(self.IMG_EXTENSIONS)}."
            warnings.warn(msg)

        return data_infos

    def _load_annotations_from_file(self):
        """Load annotations from ``self.ann_file``."""
        data_infos = []
        with open(self.ann_file) as f:
            for line in f.readlines():
                if line.strip() == '':
                    continue
                filepath, gt_label = line.strip().rsplit(' ', 1)
                info = ImageInfo(filepath, int(gt_label))
                data_infos.append(info)

        return data_infos
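To make the docstring's ``__slots__`` argument concrete, here is a standalone sketch (not from the PR; exact byte counts depend on the CPython version):

import sys

class Slotted:
    __slots__ = ['path', 'gt_label']
    def __init__(self, path, gt_label):
        self.path = path
        self.gt_label = gt_label

class Plain:
    def __init__(self, path, gt_label):
        self.path = path
        self.gt_label = gt_label

s = Slotted('n01440764/x.JPEG', 0)
p = Plain('n01440764/x.JPEG', 0)
# a slotted instance carries no per-instance __dict__, so each of the
# ~14M ImageInfo objects is substantially smaller
print(sys.getsizeof(s))
print(sys.getsizeof(p) + sys.getsizeof(p.__dict__))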
tests/data/dataset/ann.txt
@@ -0,0 +1,3 @@
a/1.JPG 0
b/2.jpeg 1
b/subb/2.jpeg 1
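As a hedged usage sketch tying the pieces together (paths follow the test fixture above; `pipeline=[]` keeps samples as raw dicts):

from mmcls.datasets import ImageNet21k

# scan class folders under data_prefix, including sub-directories
dataset = ImageNet21k(
    data_prefix='tests/data/dataset', pipeline=[], recursion_subdir=True)

# or read "filepath label" pairs from an annotation file instead
dataset = ImageNet21k(
    data_prefix='tests/data/dataset',
    pipeline=[],
    ann_file='tests/data/dataset/ann.txt')
print(len(dataset), dataset[0]['img_info'])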
@@ -6,14 +6,17 @@ import numpy as np
 import pytest
 import torch
 
-from mmcls.datasets import DATASETS, BaseDataset, MultiLabelDataset
+from mmcls.datasets import (DATASETS, BaseDataset, ImageNet21k,
+                            MultiLabelDataset)
 
 
-@pytest.mark.parametrize(
-    'dataset_name',
-    ['MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC'])
+@pytest.mark.parametrize('dataset_name', [
+    'MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC',
+    'ImageNet21k'
+])
 def test_datasets_override_default(dataset_name):
     dataset_class = DATASETS.get(dataset_name)
+    load_annotations_f = dataset_class.load_annotations
     dataset_class.load_annotations = MagicMock()
 
     original_classes = dataset_class.CLASSES
@@ -82,6 +85,8 @@ def test_datasets_override_default(dataset_name):
     assert dataset.ann_file is None
     assert dataset.CLASSES == original_classes
 
+    dataset_class.load_annotations = load_annotations_f
+
 
 @patch.multiple(MultiLabelDataset, __abstractmethods__=set())
 @patch.multiple(BaseDataset, __abstractmethods__=set())
@@ -248,3 +253,43 @@ def test_dataset_evaluation():
     assert 'CR' in eval_results.keys()
     assert 'OF1' in eval_results.keys()
     assert 'CF1' not in eval_results.keys()
+
+
+def test_dataset_imagenet21k():
+    base_dataset_cfg = dict(
+        data_prefix='tests/data/dataset', pipeline=[], recursion_subdir=True)
+
+    with pytest.raises(NotImplementedError):
+        # multi-label support has not been implemented
+        dataset_cfg = base_dataset_cfg.copy()
+        dataset_cfg.update({'multi_label': True})
+        dataset = ImageNet21k(**dataset_cfg)
+
+    with pytest.raises(TypeError):
+        # ann_file must be a string or None
+        dataset_cfg = base_dataset_cfg.copy()
+        ann_file = {'path': 'tests/data/dataset/ann.txt'}
+        dataset_cfg.update({'ann_file': ann_file})
+        dataset = ImageNet21k(**dataset_cfg)
+
+    # test with recursion_subdir=True
+    dataset = ImageNet21k(**base_dataset_cfg)
+    assert len(dataset) == 3
+    assert isinstance(dataset[0], dict)
+    assert 'img_prefix' in dataset[0]
+    assert 'img_info' in dataset[0]
+    assert 'gt_label' in dataset[0]
+
+    # test with recursion_subdir=False
+    dataset_cfg = base_dataset_cfg.copy()
+    dataset_cfg['recursion_subdir'] = False
+    dataset = ImageNet21k(**dataset_cfg)
+    assert len(dataset) == 2
+    assert isinstance(dataset[0], dict)
+
+    # test loading annotations from an ann file
+    dataset_cfg = base_dataset_cfg.copy()
+    dataset_cfg['ann_file'] = 'tests/data/dataset/ann.txt'
+    dataset = ImageNet21k(**dataset_cfg)
+    assert len(dataset) == 3
+    assert isinstance(dataset[0], dict)