[Feature] Support CUB dataset. (#703)

* support cub dataset

* support cub dataset

* fix train lint error

* add docs

* fix class label

Co-authored-by: Ezra-Yu <1105212286@qq.com>

* del debug code

* skip docformatter problem

* add unit tests

* add CUB baseline configs and checkpoints

* fix some typos

* fix name style

* update flops

Co-authored-by: Ezra-Yu <1105212286@qq.com>
takuoko 2022-03-16 17:22:28 +09:00 committed by GitHub
parent c1534f9126
commit aa522f4309
12 changed files with 441 additions and 18 deletions
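As a quick smoke test of the additions below, here is a minimal sketch (not part of this commit) of loading the new ResNet-50 CUB config through the usual mmcv/mmcls APIs. It assumes mmcls is installed from this branch and that the CUB_200_2011 archive is unpacked under `data/`, as the new configs expect.

```python
# Minimal sketch, not part of this commit. Assumes mmcls (this branch) and mmcv
# are installed and that data/CUB_200_2011/ contains the extracted dataset.
from mmcv import Config

from mmcls.datasets import build_dataset
from mmcls.models import build_classifier

cfg = Config.fromfile('configs/resnet/resnet50_8xb8_cub.py')
model = build_classifier(cfg.model)        # ResNet-50 with a 200-class head
train_set = build_dataset(cfg.data.train)  # CUB training split
print(len(train_set), model.head.num_classes)
```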

configs/_base_/datasets/cub_bs8_384.py

@@ -0,0 +1,54 @@
# dataset settings
dataset_type = 'CUB'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=510),
dict(type='RandomCrop', size=384),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=510),
dict(type='CenterCrop', crop_size=384),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data_root = 'data/CUB_200_2011/'
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'images.txt',
image_class_labels_file=data_root + 'image_class_labels.txt',
train_test_split_file=data_root + 'train_test_split.txt',
data_prefix=data_root + 'images',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'images.txt',
image_class_labels_file=data_root + 'image_class_labels.txt',
train_test_split_file=data_root + 'train_test_split.txt',
data_prefix=data_root + 'images',
test_mode=True,
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'images.txt',
image_class_labels_file=data_root + 'image_class_labels.txt',
train_test_split_file=data_root + 'train_test_split.txt',
data_prefix=data_root + 'images',
test_mode=True,
pipeline=test_pipeline))
evaluation = dict(
interval=1, metric='accuracy',
save_best='auto')  # save the checkpoint with the highest accuracy
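The `img_norm_cfg` above is the standard ImageNet mean/std expressed on the 0-255 RGB scale. A rough illustration in plain NumPy (not the mmcls transform itself) of what the `Normalize` step computes per pixel:

```python
# Illustration only; in the pipeline this is done by the Normalize transform.
import numpy as np

img = np.random.randint(0, 256, (384, 384, 3)).astype(np.float32)  # RGB, 0-255
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
normalized = (img - mean) / std  # roughly zero-mean, unit-variance per channel
```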

configs/_base_/datasets/cub_bs8_448.py

@@ -0,0 +1,54 @@
# dataset settings
dataset_type = 'CUB'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=600),
dict(type='RandomCrop', size=448),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=600),
dict(type='CenterCrop', crop_size=448),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data_root = 'data/CUB_200_2011/'
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'images.txt',
image_class_labels_file=data_root + 'image_class_labels.txt',
train_test_split_file=data_root + 'train_test_split.txt',
data_prefix=data_root + 'images',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'images.txt',
image_class_labels_file=data_root + 'image_class_labels.txt',
train_test_split_file=data_root + 'train_test_split.txt',
data_prefix=data_root + 'images',
test_mode=True,
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'images.txt',
image_class_labels_file=data_root + 'image_class_labels.txt',
train_test_split_file=data_root + 'train_test_split.txt',
data_prefix=data_root + 'images',
test_mode=True,
pipeline=test_pipeline))
evaluation = dict(
interval=1, metric='accuracy',
save_best='auto')  # save the checkpoint with the highest accuracy

configs/_base_/schedules/cub_bs64.py

@@ -0,0 +1,13 @@
# optimizer
optimizer = dict(
type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=0,
warmup='linear',
warmup_iters=5,
warmup_ratio=0.01,
warmup_by_epoch=True)
runner = dict(type='EpochBasedRunner', max_epochs=100)
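For reference, a rough epoch-granularity sketch of the learning-rate curve this schedule describes. The real values are computed per iteration by the mmcv LR hooks, so treat this as an approximation rather than the implementation:

```python
import math

base_lr, min_lr, max_epochs = 0.01, 0.0, 100
warmup_epochs, warmup_ratio = 5, 0.01

def approx_lr(epoch):
    # cosine annealing from base_lr towards min_lr over the whole run
    cosine = min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * epoch / max_epochs)) / 2
    if epoch < warmup_epochs:
        # linear warmup: starts near base_lr * warmup_ratio, reaches the cosine value
        k = (1 - epoch / warmup_epochs) * (1 - warmup_ratio)
        return cosine * (1 - k)
    return cosine

print([round(approx_lr(e), 5) for e in (0, 1, 5, 50, 99)])
```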

configs/resnet/README.md

@@ -15,21 +15,29 @@ The depth of representations is of central importance for many visual recognitio
## Results and models
The models pre-trained on ImageNet-21k are only used for fine-tuning, and therefore do not have evaluation results.
| Model | resolution | Params(M) | Flops(G) | Download |
|:---------------:|:-----------:|:---------:|:---------:|:--------:|
| ResNet-50-mill | 224x224 | 86.74 | 15.14 | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth)|
*The "mill" means using the mutil-label pretrain weight from [ImageNet-21K Pretraining for the Masses](https://github.com/Alibaba-MIIL/ImageNet21K).*
### Cifar10
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:---------:|:--------:|
| ResNet-18-b16x8 | 11.17 | 0.56 | 94.82 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.log.json) |
| ResNet-34-b16x8 | 21.28 | 1.16 | 95.34 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.log.json) |
| ResNet-50-b16x8 | 23.52 | 1.31 | 95.55 | 99.91 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.log.json) |
| ResNet-101-b16x8 | 42.51 | 2.52 | 95.58 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.log.json) |
| ResNet-152-b16x8 | 58.16 | 3.74 | 95.76 | 99.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.log.json) |
| ResNet-18 | 11.17 | 0.56 | 94.82 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.log.json) |
| ResNet-34 | 21.28 | 1.16 | 95.34 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.log.json) |
| ResNet-50 | 23.52 | 1.31 | 95.55 | 99.91 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.log.json) |
| ResNet-101 | 42.51 | 2.52 | 95.58 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.log.json) |
| ResNet-152 | 58.16 | 3.74 | 95.76 | 99.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.log.json) |
### Cifar100
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:---------:|:--------:|
| ResNet-50-b16x8 | 23.71 | 1.31 | 79.90 | 95.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar100.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.log.json) |
| ResNet-50 | 23.71 | 1.31 | 79.90 | 95.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar100.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.log.json) |
### ImageNet-1k
@@ -57,6 +65,13 @@ The depth of representations is of central importance for many visual recognitio
*Models with \* are converted from the [official repo](https://github.com/pytorch/vision). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
### CUB-200-2011
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Config | Download |
|:---------------------:|:------------:|:---------:|:---------:|:--------:|:---------:|:---------:|:---------:|
| ResNet-50 | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth) | 448x448 | 23.92 | 16.48 | 88.45 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb8_cub.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.log.json) |
## Citation
```

configs/resnet/metafile.yml

@@ -375,3 +375,16 @@ Models:
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth
Config: configs/resnet/resnetv1c152_8xb32_in1k.py
- Name: resnet50_8xb8_cub
Metadata:
FLOPs: 16480000000
Parameters: 23920000
In Collection: ResNet
Results:
- Dataset: CUB-200-2011
Metrics:
Top 1 Accuracy: 88.45
Task: Image Classification
Pretrain: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth
Config: configs/resnet/resnet50_8xb8_cub.py

configs/resnet/resnet50_8xb8_cub.py

@@ -0,0 +1,19 @@
_base_ = [
'../_base_/models/resnet50.py', '../_base_/datasets/cub_bs8_448.py',
'../_base_/schedules/cub_bs64.py', '../_base_/default_runtime.py'
]
# use pre-trained weights converted from https://github.com/Alibaba-MIIL/ImageNet21K # noqa
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth' # noqa
model = dict(
type='ImageClassifier',
backbone=dict(
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
head=dict(num_classes=200, ))
log_config = dict(interval=20)  # log every 20 iterations
checkpoint_config = dict(
interval=1, max_keep_ckpts=3) # save last three checkpoints
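The `prefix='backbone'` in `init_cfg` tells the pretrained-weight loader to keep only the backbone part of the full-classifier checkpoint. A hedged sketch of what that amounts to, using a hypothetical local copy of the checkpoint file (mmcv does the equivalent internally):

```python
# Hedged sketch of what prefix='backbone' boils down to.
import torch

ckpt = torch.load('resnet50_mill_3rdparty_in21k.pth', map_location='cpu')  # hypothetical local file
state = ckpt.get('state_dict', ckpt)
backbone_state = {k[len('backbone.'):]: v
                  for k, v in state.items()
                  if k.startswith('backbone.')}
# backbone_state now maps e.g. 'conv1.weight' -> tensor, ready to load into the
# bare ResNet-50 backbone; the 21k classification head is discarded.
```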

configs/swin_transformer/README.md

@@ -41,6 +41,13 @@ The pre-trained models on ImageNet-21k are used to fine-tune, and therefore don'
*Models with \* are converted from the [official repo](https://github.com/microsoft/Swin-Transformer#main-results-on-imagenet-with-pretrained-models). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
### CUB-200-2011
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Config | Download |
|:----------------:|:------------:|:---------:|:---------:|:--------:|:---------:|:---------:|:---------:|
| Swin-L | [ImageNet-21k](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth) | 384x384 | 195.51 | 100.04 | 91.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin-large_8xb8_cub_384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin-large_8xb8_cub_384px_20220307-1bbaee6a.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin-large_8xb8_cub_384px_20220307-1bbaee6a.log.json) |
## Citation
```

configs/swin_transformer/metafile.yml

@@ -186,3 +186,16 @@ Models:
Weights: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth
Code: https://github.com/microsoft/Swin-Transformer/blob/777f6c66604bb5579086c4447efe3620344d95a9/models/swin_transformer.py#L458
Config: configs/swin_transformer/swin-large_16xb64_in1k-384px.py
- Name: swin-large_8xb8_cub_384px
Metadata:
FLOPs: 100040000000
Parameters: 195510000
In Collection: Swin-Transformer
Results:
- Dataset: CUB-200-2011
Metrics:
Top 1 Accuracy: 91.87
Task: Image Classification
Pretrain: https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth
Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin-large_8xb8_cub_384px_20220307-1bbaee6a.pth
Config: configs/swin_transformer/swin-large_8xb8_cub_384px.py

configs/swin_transformer/swin-large_8xb8_cub_384px.py

@@ -0,0 +1,37 @@
_base_ = [
'../_base_/models/swin_transformer/large_384.py',
'../_base_/datasets/cub_bs8_384.py', '../_base_/schedules/cub_bs64.py',
'../_base_/default_runtime.py'
]
# model settings
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa
model = dict(
type='ImageClassifier',
backbone=dict(
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
head=dict(num_classes=200, ))
paramwise_cfg = dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.absolute_pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0)
})
optimizer = dict(
_delete_=True,
type='AdamW',
lr=5e-6,
weight_decay=0.0005,
eps=1e-8,
betas=(0.9, 0.999),
paramwise_cfg=paramwise_cfg)
optimizer_config = dict(grad_clip=dict(max_norm=5.0), _delete_=True)
log_config = dict(interval=20)  # log every 20 iterations
checkpoint_config = dict(
interval=1, max_keep_ckpts=3) # save last three checkpoints
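The `paramwise_cfg` above turns off weight decay for norm layers, biases, and the Swin position tables. A rough name-based approximation of that rule follows; the parameter names are made up for illustration, and the real `DefaultOptimizerConstructor` inspects module types rather than name substrings:

```python
# Rough approximation only; not the mmcv optimizer constructor.
CUSTOM_KEYS = ('.absolute_pos_embed', '.relative_position_bias_table')

def approx_decay_mult(name):
    if any(key in name for key in CUSTOM_KEYS):
        return 0.0  # custom_keys: decay_mult=0.0
    if name.endswith('.bias'):
        return 0.0  # bias_decay_mult=0.0
    if '.norm' in name:
        return 0.0  # norm_decay_mult=0.0 (name-based stand-in for the type check)
    return 1.0

for n in ('backbone.stages.0.blocks.0.attn.w_msa.relative_position_bias_table',
          'backbone.norm3.weight',
          'head.fc.weight',
          'head.fc.bias'):
    print(n, approx_decay_mult(n))
```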

mmcls/datasets/__init__.py

@@ -3,6 +3,7 @@ from .base_dataset import BaseDataset
from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader,
build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .cub import CUB
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
@@ -17,5 +18,5 @@ __all__ = [
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
'build_sampler', 'RepeatAugSampler', 'KFoldDataset'
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB'
]
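With `CUB` imported and registered here, the dataset can be looked up by its string name through the registry, which is how the `type='CUB'` entries in the `_base_` dataset configs above resolve to the class. A small check:

```python
# Quick registry check; CUB is now discoverable by its string name.
from mmcls.datasets import DATASETS

cub_cls = DATASETS.get('CUB')
print(cub_cls)  # <class 'mmcls.datasets.cub.CUB'>
```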

mmcls/datasets/cub.py

@@ -0,0 +1,129 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from .base_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register_module()
class CUB(BaseDataset):
"""The CUB-200-2011 Dataset.
Support the `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
Comparing with the `CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ Dataset,
there are much more pictures in `CUB-200-2011`.
Args:
ann_file (str): the annotation file.
images.txt in CUB.
image_class_labels_file (str): the label file.
image_class_labels.txt in CUB.
train_test_split_file (str): the split file.
train_test_split_file.txt in CUB.
""" # noqa: E501
CLASSES = [
'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross',
'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet',
'Parakeet_Auklet', 'Rhinoceros_Auklet', 'Brewer_Blackbird',
'Red_winged_Blackbird', 'Rusty_Blackbird', 'Yellow_headed_Blackbird',
'Bobolink', 'Indigo_Bunting', 'Lazuli_Bunting', 'Painted_Bunting',
'Cardinal', 'Spotted_Catbird', 'Gray_Catbird', 'Yellow_breasted_Chat',
'Eastern_Towhee', 'Chuck_will_Widow', 'Brandt_Cormorant',
'Red_faced_Cormorant', 'Pelagic_Cormorant', 'Bronzed_Cowbird',
'Shiny_Cowbird', 'Brown_Creeper', 'American_Crow', 'Fish_Crow',
'Black_billed_Cuckoo', 'Mangrove_Cuckoo', 'Yellow_billed_Cuckoo',
'Gray_crowned_Rosy_Finch', 'Purple_Finch', 'Northern_Flicker',
'Acadian_Flycatcher', 'Great_Crested_Flycatcher', 'Least_Flycatcher',
'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher',
'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird',
'Northern_Fulmar', 'Gadwall', 'American_Goldfinch',
'European_Goldfinch', 'Boat_tailed_Grackle', 'Eared_Grebe',
'Horned_Grebe', 'Pied_billed_Grebe', 'Western_Grebe', 'Blue_Grosbeak',
'Evening_Grosbeak', 'Pine_Grosbeak', 'Rose_breasted_Grosbeak',
'Pigeon_Guillemot', 'California_Gull', 'Glaucous_winged_Gull',
'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull', 'Ring_billed_Gull',
'Slaty_backed_Gull', 'Western_Gull', 'Anna_Hummingbird',
'Ruby_throated_Hummingbird', 'Rufous_Hummingbird', 'Green_Violetear',
'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay', 'Florida_Jay',
'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird', 'Gray_Kingbird',
'Belted_Kingfisher', 'Green_Kingfisher', 'Pied_Kingfisher',
'Ringed_Kingfisher', 'White_breasted_Kingfisher',
'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard',
'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser',
'Mockingbird', 'Nighthawk', 'Clark_Nutcracker',
'White_breasted_Nuthatch', 'Baltimore_Oriole', 'Hooded_Oriole',
'Orchard_Oriole', 'Scott_Oriole', 'Ovenbird', 'Brown_Pelican',
'White_Pelican', 'Western_Wood_Pewee', 'Sayornis', 'American_Pipit',
'Whip_poor_Will', 'Horned_Puffin', 'Common_Raven',
'White_necked_Raven', 'American_Redstart', 'Geococcyx',
'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow',
'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow',
'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow',
'Fox_Sparrow', 'Grasshopper_Sparrow', 'Harris_Sparrow',
'Henslow_Sparrow', 'Le_Conte_Sparrow', 'Lincoln_Sparrow',
'Nelson_Sharp_tailed_Sparrow', 'Savannah_Sparrow', 'Seaside_Sparrow',
'Song_Sparrow', 'Tree_Sparrow', 'Vesper_Sparrow',
'White_crowned_Sparrow', 'White_throated_Sparrow',
'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow',
'Cliff_Swallow', 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager',
'Artic_Tern', 'Black_Tern', 'Caspian_Tern', 'Common_Tern',
'Elegant_Tern', 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee',
'Brown_Thrasher', 'Sage_Thrasher', 'Black_capped_Vireo',
'Blue_headed_Vireo', 'Philadelphia_Vireo', 'Red_eyed_Vireo',
'Warbling_Vireo', 'White_eyed_Vireo', 'Yellow_throated_Vireo',
'Bay_breasted_Warbler', 'Black_and_white_Warbler',
'Black_throated_Blue_Warbler', 'Blue_winged_Warbler', 'Canada_Warbler',
'Cape_May_Warbler', 'Cerulean_Warbler', 'Chestnut_sided_Warbler',
'Golden_winged_Warbler', 'Hooded_Warbler', 'Kentucky_Warbler',
'Magnolia_Warbler', 'Mourning_Warbler', 'Myrtle_Warbler',
'Nashville_Warbler', 'Orange_crowned_Warbler', 'Palm_Warbler',
'Pine_Warbler', 'Prairie_Warbler', 'Prothonotary_Warbler',
'Swainson_Warbler', 'Tennessee_Warbler', 'Wilson_Warbler',
'Worm_eating_Warbler', 'Yellow_Warbler', 'Northern_Waterthrush',
'Louisiana_Waterthrush', 'Bohemian_Waxwing', 'Cedar_Waxwing',
'American_Three_toed_Woodpecker', 'Pileated_Woodpecker',
'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker',
'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren',
'Cactus_Wren', 'Carolina_Wren', 'House_Wren', 'Marsh_Wren',
'Rock_Wren', 'Winter_Wren', 'Common_Yellowthroat'
]
def __init__(self, *args, ann_file, image_class_labels_file,
train_test_split_file, **kwargs):
self.image_class_labels_file = image_class_labels_file
self.train_test_split_file = train_test_split_file
super(CUB, self).__init__(*args, ann_file=ann_file, **kwargs)
def load_annotations(self):
with open(self.ann_file) as f:
samples = [x.strip().split(' ')[1] for x in f.readlines()]
with open(self.image_class_labels_file) as f:
gt_labels = [
# in the official CUB-200-2011 dataset, class labels in
# image_class_labels.txt start from 1, so subtract 1
# here to make them start from 0.
int(x.strip().split(' ')[1]) - 1 for x in f.readlines()
]
with open(self.train_test_split_file) as f:
splits = [int(x.strip().split(' ')[1]) for x in f.readlines()]
assert len(samples) == len(gt_labels) == len(splits),\
f'samples({len(samples)}), gt_labels({len(gt_labels)}) and ' \
f'splits({len(splits)}) should have same length.'
data_infos = []
for filename, gt_label, split in zip(samples, gt_labels, splits):
if split and self.test_mode:
# skip train samples when test_mode=True
continue
elif not split and not self.test_mode:
# skip test samples when test_mode=False
continue
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
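A hedged usage sketch of the class above. It assumes the official CUB_200_2011 archive is unpacked under `data/CUB_200_2011/` so that the three annotation files exist; the split sizes in the comment are the ones published with the dataset:

```python
from mmcls.datasets import CUB

data_root = 'data/CUB_200_2011/'
common = dict(
    data_prefix=data_root + 'images',
    pipeline=[],
    ann_file=data_root + 'images.txt',
    image_class_labels_file=data_root + 'image_class_labels.txt',
    train_test_split_file=data_root + 'train_test_split.txt')

train_set = CUB(test_mode=False, **common)  # split flag 1 -> training images
test_set = CUB(test_mode=True, **common)    # split flag 0 -> test images
print(len(train_set), len(test_set))        # expected 5994 and 5794
```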

tests/test_data/test_datasets/test_common.py

@@ -1,18 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
from mmcls.datasets import (DATASETS, BaseDataset, ImageNet21k,
from mmcls.datasets import (CUB, DATASETS, BaseDataset, ImageNet21k,
MultiLabelDataset)
@pytest.mark.parametrize('dataset_name', [
'MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC',
'ImageNet21k'
'ImageNet21k', 'CUB'
])
def test_datasets_override_default(dataset_name):
dataset_class = DATASETS.get(dataset_name)
@@ -27,6 +27,15 @@ def test_datasets_override_default(dataset_name):
original_classes = dataset_class.CLASSES
# some datasets need extra argument to init
extra_kwargs_settings = {
'CUB':
dict(
ann_file=None,
image_class_labels_file=None,
train_test_split_file=None),
}
extra_kwargs = extra_kwargs_settings.get(dataset_name, dict())
# Test VOC year
if dataset_name == 'VOC':
dataset = dataset_class(
@@ -47,7 +56,8 @@ def test_datasets_override_default(dataset_name):
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
pipeline=[],
classes=('bus', 'car'),
test_mode=True)
test_mode=True,
**extra_kwargs)
assert dataset.CLASSES == ('bus', 'car')
# Test get_cat_ids
@@ -61,17 +71,21 @@ def test_datasets_override_default(dataset_name):
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
pipeline=[],
classes=['bus', 'car'],
test_mode=True)
test_mode=True,
**extra_kwargs)
assert dataset.CLASSES == ['bus', 'car']
# Test setting classes through a file
classes_file = osp.join(
osp.dirname(__file__), '../../data/dataset/classes.txt')
tmp_file = tempfile.NamedTemporaryFile()
with open(tmp_file.name, 'w') as f:
f.write('bus\ncar\n')
dataset = dataset_class(
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
pipeline=[],
classes=classes_file,
test_mode=True)
classes=tmp_file.name,
test_mode=True,
**extra_kwargs)
tmp_file.close()
assert dataset.CLASSES == ['bus', 'car']
@@ -80,12 +94,15 @@ def test_datasets_override_default(dataset_name):
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
pipeline=[],
classes=['foo'],
test_mode=True)
test_mode=True,
**extra_kwargs)
assert dataset.CLASSES == ['foo']
# Test default behavior
dataset = dataset_class(
data_prefix='VOC2007' if dataset_name == 'VOC' else '', pipeline=[])
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
pipeline=[],
**extra_kwargs)
if dataset_name == 'VOC':
assert dataset.data_prefix == 'VOC2007'
@@ -308,3 +325,54 @@ def test_dataset_imagenet21k():
dataset = ImageNet21k(**dataset_cfg)
assert len(dataset) == 3
assert isinstance(dataset[0], dict)
def test_dataset_cub():
tmp_ann_file = tempfile.NamedTemporaryFile()
tmp_image_class_labels_file = tempfile.NamedTemporaryFile()
tmp_train_test_split_file = tempfile.NamedTemporaryFile()
with open(tmp_ann_file.name, 'w') as f:
f.write('1 1.txt \n2 2.txt \n')
with open(tmp_image_class_labels_file.name, 'w') as f:
f.write('1 1 \n2 2 \n')
with open(tmp_train_test_split_file.name, 'w') as f:
f.write('1 0 \n2 1 \n')
# test in train mode
dataset = CUB(
data_prefix='',
pipeline=[],
test_mode=False,
ann_file=tmp_ann_file.name,
image_class_labels_file=tmp_image_class_labels_file.name,
train_test_split_file=tmp_train_test_split_file.name)
assert len(dataset) == 1
# test in test mode
dataset = CUB(
data_prefix='',
pipeline=[],
test_mode=True,
ann_file=tmp_ann_file.name,
image_class_labels_file=tmp_image_class_labels_file.name,
train_test_split_file=tmp_train_test_split_file.name)
assert len(dataset) == 1
# test with different items in three files
with open(tmp_train_test_split_file.name, 'w') as f:
f.write('1 0 \n')
with pytest.raises(AssertionError, match='should have same length'):
dataset = CUB(
data_prefix='',
pipeline=[],
test_mode=True,
ann_file=tmp_ann_file.name,
image_class_labels_file=tmp_image_class_labels_file.name,
train_test_split_file=tmp_train_test_split_file.name)
tmp_ann_file.close()
tmp_image_class_labels_file.close()
tmp_train_test_split_file.close()