mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Feature] Support CUB dataset. (#703)
* support cub dataset * support cub dataset * fix train lint error * add docs * fix class label Co-authored-by: Ezra-Yu <1105212286@qq.com> * del debug code * skip docformatter problem * add unit tests * add CUB baseline configs and chpts * fix some typos * fix name style * update flops Co-authored-by: Ezra-Yu <1105212286@qq.com>
This commit is contained in:
parent
c1534f9126
commit
aa522f4309
54
configs/_base_/datasets/cub_bs8_384.py
Normal file
54
configs/_base_/datasets/cub_bs8_384.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Dataset settings: CUB-200-2011, 384x384 network input, 8 samples per GPU.
dataset_type = 'CUB'

# Per-channel normalisation arguments forwarded verbatim to ``Normalize``.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
)

# Training: resize with size=510, random 384 crop, horizontal flip with
# probability 0.5, then normalise and pack image + label.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=510),
    dict(type='RandomCrop', size=384),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label']),
]

# Evaluation: same resize, deterministic centre crop, image only.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=510),
    dict(type='CenterCrop', crop_size=384),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img']),
]

data_root = 'data/CUB_200_2011/'

# All three splits read the same three index files; ``test_mode`` selects
# which side of the train/test split the dataset keeps. ``val`` and ``test``
# are configured identically.
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'images.txt',
        image_class_labels_file=data_root + 'image_class_labels.txt',
        train_test_split_file=data_root + 'train_test_split.txt',
        data_prefix=data_root + 'images',
        pipeline=train_pipeline,
    ),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'images.txt',
        image_class_labels_file=data_root + 'image_class_labels.txt',
        train_test_split_file=data_root + 'train_test_split.txt',
        data_prefix=data_root + 'images',
        test_mode=True,
        pipeline=test_pipeline,
    ),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'images.txt',
        image_class_labels_file=data_root + 'image_class_labels.txt',
        train_test_split_file=data_root + 'train_test_split.txt',
        data_prefix=data_root + 'images',
        test_mode=True,
        pipeline=test_pipeline,
    ),
)

# save the checkpoint with highest accuracy
evaluation = dict(interval=1, metric='accuracy', save_best='auto')
|
54
configs/_base_/datasets/cub_bs8_448.py
Normal file
54
configs/_base_/datasets/cub_bs8_448.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Dataset settings: CUB-200-2011, 448x448 network input, 8 samples per GPU.
dataset_type = 'CUB'

# Per-channel normalisation arguments forwarded verbatim to ``Normalize``.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
)

# Training: resize with size=600, random 448 crop, horizontal flip with
# probability 0.5, then normalise and pack image + label.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=600),
    dict(type='RandomCrop', size=448),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label']),
]

# Evaluation: same resize, deterministic centre crop, image only.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=600),
    dict(type='CenterCrop', crop_size=448),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img']),
]

data_root = 'data/CUB_200_2011/'

# All three splits read the same three index files; ``test_mode`` selects
# which side of the train/test split the dataset keeps. ``val`` and ``test``
# are configured identically.
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'images.txt',
        image_class_labels_file=data_root + 'image_class_labels.txt',
        train_test_split_file=data_root + 'train_test_split.txt',
        data_prefix=data_root + 'images',
        pipeline=train_pipeline,
    ),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'images.txt',
        image_class_labels_file=data_root + 'image_class_labels.txt',
        train_test_split_file=data_root + 'train_test_split.txt',
        data_prefix=data_root + 'images',
        test_mode=True,
        pipeline=test_pipeline,
    ),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'images.txt',
        image_class_labels_file=data_root + 'image_class_labels.txt',
        train_test_split_file=data_root + 'train_test_split.txt',
        data_prefix=data_root + 'images',
        test_mode=True,
        pipeline=test_pipeline,
    ),
)

# save the checkpoint with highest accuracy
evaluation = dict(interval=1, metric='accuracy', save_best='auto')
|
13
configs/_base_/schedules/cub_bs64.py
Normal file
13
configs/_base_/schedules/cub_bs64.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Optimizer: SGD with Nesterov momentum; gradient clipping is disabled.
optimizer = dict(
    type='SGD',
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)
optimizer_config = dict(grad_clip=None)

# Learning-rate policy: cosine annealing down to 0 after a linear warmup.
# NOTE(review): with ``warmup_by_epoch=True`` the ``warmup_iters=5`` value is
# presumably counted in epochs - confirm against the LR updater hook.
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=5,
    warmup_ratio=0.01,
    warmup_by_epoch=True,
)

# Epoch-based training for 100 epochs.
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
@ -15,21 +15,29 @@ The depth of representations is of central importance for many visual recognitio
|
||||
|
||||
## Results and models
|
||||
|
||||
The pre-trained models on ImageNet-21k are used to fine-tune, and therefore don't have evaluation results.
|
||||
|
||||
| Model | resolution | Params(M) | Flops(G) | Download |
|
||||
|:---------------:|:-----------:|:---------:|:---------:|:--------:|
|
||||
| ResNet-50-mill | 224x224 | 86.74 | 15.14 | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth)|
|
||||
|
||||
*The "mill" means using the multi-label pretrain weight from [ImageNet-21K Pretraining for the Masses](https://github.com/Alibaba-MIIL/ImageNet21K).*
|
||||
|
||||
### Cifar10
|
||||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:---------:|:--------:|
|
||||
| ResNet-18-b16x8 | 11.17 | 0.56 | 94.82 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.log.json) |
|
||||
| ResNet-34-b16x8 | 21.28 | 1.16 | 95.34 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.log.json) |
|
||||
| ResNet-50-b16x8 | 23.52 | 1.31 | 95.55 | 99.91 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.log.json) |
|
||||
| ResNet-101-b16x8 | 42.51 | 2.52 | 95.58 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.log.json) |
|
||||
| ResNet-152-b16x8 | 58.16 | 3.74 | 95.76 | 99.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.log.json) |
|
||||
| ResNet-18 | 11.17 | 0.56 | 94.82 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.log.json) |
|
||||
| ResNet-34 | 21.28 | 1.16 | 95.34 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.log.json) |
|
||||
| ResNet-50 | 23.52 | 1.31 | 95.55 | 99.91 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.log.json) |
|
||||
| ResNet-101 | 42.51 | 2.52 | 95.58 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.log.json) |
|
||||
| ResNet-152 | 58.16 | 3.74 | 95.76 | 99.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.log.json) |
|
||||
|
||||
### Cifar100
|
||||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:---------:|:--------:|
|
||||
| ResNet-50-b16x8 | 23.71 | 1.31 | 79.90 | 95.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar100.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.log.json) |
|
||||
| ResNet-50 | 23.71 | 1.31 | 79.90 | 95.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb16_cifar100.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar100_20210528-67b58a1b.log.json) |
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
@ -57,6 +65,13 @@ The depth of representations is of central importance for many visual recognitio
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/pytorch/vision). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
### CUB-200-2011
|
||||
|
||||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Config | Download |
|
||||
|:---------------------:|:------------:|:---------:|:---------:|:--------:|:---------:|:---------:|:---------:|
|
||||
| ResNet-50 | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth) | 448x448 | 23.92 | 16.48 | 88.45 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb8_cub.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.log.json) |
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
|
@ -375,3 +375,16 @@ Models:
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1c152_8xb32_in1k_20220214-c013291f.pth
|
||||
Config: configs/resnet/resnetv1c152_8xb32_in1k.py
|
||||
- Name: resnet50_8xb8_cub
|
||||
Metadata:
|
||||
FLOPs: 16480000000
|
||||
Parameters: 23920000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: CUB-200-2011
|
||||
Metrics:
|
||||
Top 1 Accuracy: 88.45
|
||||
Task: Image Classification
|
||||
Pretrain: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth
|
||||
Config: configs/resnet/resnet50_8xb8_cub.py
|
||||
|
19
configs/resnet/resnet50_8xb8_cub.py
Normal file
19
configs/resnet/resnet50_8xb8_cub.py
Normal file
@ -0,0 +1,19 @@
|
||||
# ResNet-50 fine-tuned on CUB-200-2011 (448x448 input); everything not set
# below comes from the listed base configs.
_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/cub_bs8_448.py',
    '../_base_/schedules/cub_bs64.py',
    '../_base_/default_runtime.py',
]

# use pre-train weight converted from https://github.com/Alibaba-MIIL/ImageNet21K # noqa
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_mill_3rdparty_in21k_20220307-bdb3a68b.pth'  # noqa

# Initialise the backbone from the checkpoint above and size the head for the
# 200 CUB classes.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint=checkpoint,
            prefix='backbone',
        ),
    ),
    head=dict(num_classes=200),
)

# log every 20 iterations
log_config = dict(interval=20)

# save last three checkpoints
checkpoint_config = dict(interval=1, max_keep_ckpts=3)
|
@ -41,6 +41,13 @@ The pre-trained models on ImageNet-21k are used to fine-tune, and therefore don'
|
||||
|
||||
*Models with \* are converted from the [official repo](https://github.com/microsoft/Swin-Transformer#main-results-on-imagenet-with-pretrained-models). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
### CUB-200-2011
|
||||
|
||||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Config | Download |
|
||||
|:----------------:|:------------:|:---------:|:---------:|:--------:|:---------:|:---------:|:---------:|
|
||||
| Swin-L | [ImageNet-21k](https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth) | 384x384 | 195.51 | 100.04 | 91.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin-large_8xb8_cub_384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin-large_8xb8_cub_384px_20220307-1bbaee6a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin-large_8xb8_cub_384px_20220307-1bbaee6a.log.json) |
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
|
@ -186,3 +186,16 @@ Models:
|
||||
Weights: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth
|
||||
Code: https://github.com/microsoft/Swin-Transformer/blob/777f6c66604bb5579086c4447efe3620344d95a9/models/swin_transformer.py#L458
|
||||
Config: configs/swin_transformer/swin-large_16xb64_in1k-384px.py
|
||||
- Name: swin-large_8xb8_cub_384px
|
||||
Metadata:
|
||||
FLOPs: 100040000000
|
||||
Parameters: 195510000
|
||||
In Collection: Swin-Transformer
|
||||
Results:
|
||||
- Dataset: CUB-200-2011
|
||||
Metrics:
|
||||
Top 1 Accuracy: 91.87
|
||||
Task: Image Classification
|
||||
Pretrain: https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin-large_8xb8_cub_384px_20220307-1bbaee6a.pth
|
||||
Config: configs/swin_transformer/swin-large_8xb8_cub_384px.py
|
||||
|
37
configs/swin_transformer/swin-large_8xb8_cub_384px.py
Normal file
37
configs/swin_transformer/swin-large_8xb8_cub_384px.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Swin-L fine-tuned on CUB-200-2011 (384x384 input); everything not set below
# comes from the listed base configs.
_base_ = [
    '../_base_/models/swin_transformer/large_384.py',
    '../_base_/datasets/cub_bs8_384.py',
    '../_base_/schedules/cub_bs64.py',
    '../_base_/default_runtime.py',
]

# model settings
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth'  # noqa

# Initialise the backbone from the checkpoint above and size the head for the
# 200 CUB classes.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint=checkpoint,
            prefix='backbone',
        ),
    ),
    head=dict(num_classes=200),
)

# Zero weight-decay multiplier for norm layers, biases and the two listed
# position-embedding parameter groups.
paramwise_cfg = dict(
    norm_decay_mult=0.0,
    bias_decay_mult=0.0,
    custom_keys={
        '.absolute_pos_embed': dict(decay_mult=0.0),
        '.relative_position_bias_table': dict(decay_mult=0.0),
    },
)

# ``_delete_=True`` asks the config system to replace, rather than merge
# with, the optimizer settings inherited from the schedule base config.
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=5e-6,
    weight_decay=0.0005,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=dict(max_norm=5.0), _delete_=True)

# log every 20 iterations
log_config = dict(interval=20)

# save last three checkpoints
checkpoint_config = dict(interval=1, max_keep_ckpts=3)
|
@ -3,6 +3,7 @@ from .base_dataset import BaseDataset
|
||||
from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader,
|
||||
build_dataset, build_sampler)
|
||||
from .cifar import CIFAR10, CIFAR100
|
||||
from .cub import CUB
|
||||
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
|
||||
KFoldDataset, RepeatDataset)
|
||||
from .imagenet import ImageNet
|
||||
@ -17,5 +18,5 @@ __all__ = [
|
||||
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
|
||||
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
|
||||
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
|
||||
'build_sampler', 'RepeatAugSampler', 'KFoldDataset'
|
||||
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB'
|
||||
]
|
||||
|
129
mmcls/datasets/cub.py
Normal file
129
mmcls/datasets/cub.py
Normal file
@ -0,0 +1,129 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class CUB(BaseDataset):
|
||||
"""The CUB-200-2011 Dataset.
|
||||
|
||||
Support the `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
|
||||
Comparing with the `CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ Dataset,
|
||||
there are many more pictures in `CUB-200-2011`.
|
||||
|
||||
Args:
|
||||
ann_file (str): the annotation file.
|
||||
images.txt in CUB.
|
||||
image_class_labels_file (str): the label file.
|
||||
image_class_labels.txt in CUB.
|
||||
train_test_split_file (str): the split file.
|
||||
train_test_split.txt in CUB.
|
||||
""" # noqa: E501
|
||||
|
||||
CLASSES = [
|
||||
'Black_footed_Albatross', 'Laysan_Albatross', 'Sooty_Albatross',
|
||||
'Groove_billed_Ani', 'Crested_Auklet', 'Least_Auklet',
|
||||
'Parakeet_Auklet', 'Rhinoceros_Auklet', 'Brewer_Blackbird',
|
||||
'Red_winged_Blackbird', 'Rusty_Blackbird', 'Yellow_headed_Blackbird',
|
||||
'Bobolink', 'Indigo_Bunting', 'Lazuli_Bunting', 'Painted_Bunting',
|
||||
'Cardinal', 'Spotted_Catbird', 'Gray_Catbird', 'Yellow_breasted_Chat',
|
||||
'Eastern_Towhee', 'Chuck_will_Widow', 'Brandt_Cormorant',
|
||||
'Red_faced_Cormorant', 'Pelagic_Cormorant', 'Bronzed_Cowbird',
|
||||
'Shiny_Cowbird', 'Brown_Creeper', 'American_Crow', 'Fish_Crow',
|
||||
'Black_billed_Cuckoo', 'Mangrove_Cuckoo', 'Yellow_billed_Cuckoo',
|
||||
'Gray_crowned_Rosy_Finch', 'Purple_Finch', 'Northern_Flicker',
|
||||
'Acadian_Flycatcher', 'Great_Crested_Flycatcher', 'Least_Flycatcher',
|
||||
'Olive_sided_Flycatcher', 'Scissor_tailed_Flycatcher',
|
||||
'Vermilion_Flycatcher', 'Yellow_bellied_Flycatcher', 'Frigatebird',
|
||||
'Northern_Fulmar', 'Gadwall', 'American_Goldfinch',
|
||||
'European_Goldfinch', 'Boat_tailed_Grackle', 'Eared_Grebe',
|
||||
'Horned_Grebe', 'Pied_billed_Grebe', 'Western_Grebe', 'Blue_Grosbeak',
|
||||
'Evening_Grosbeak', 'Pine_Grosbeak', 'Rose_breasted_Grosbeak',
|
||||
'Pigeon_Guillemot', 'California_Gull', 'Glaucous_winged_Gull',
|
||||
'Heermann_Gull', 'Herring_Gull', 'Ivory_Gull', 'Ring_billed_Gull',
|
||||
'Slaty_backed_Gull', 'Western_Gull', 'Anna_Hummingbird',
|
||||
'Ruby_throated_Hummingbird', 'Rufous_Hummingbird', 'Green_Violetear',
|
||||
'Long_tailed_Jaeger', 'Pomarine_Jaeger', 'Blue_Jay', 'Florida_Jay',
|
||||
'Green_Jay', 'Dark_eyed_Junco', 'Tropical_Kingbird', 'Gray_Kingbird',
|
||||
'Belted_Kingfisher', 'Green_Kingfisher', 'Pied_Kingfisher',
|
||||
'Ringed_Kingfisher', 'White_breasted_Kingfisher',
|
||||
'Red_legged_Kittiwake', 'Horned_Lark', 'Pacific_Loon', 'Mallard',
|
||||
'Western_Meadowlark', 'Hooded_Merganser', 'Red_breasted_Merganser',
|
||||
'Mockingbird', 'Nighthawk', 'Clark_Nutcracker',
|
||||
'White_breasted_Nuthatch', 'Baltimore_Oriole', 'Hooded_Oriole',
|
||||
'Orchard_Oriole', 'Scott_Oriole', 'Ovenbird', 'Brown_Pelican',
|
||||
'White_Pelican', 'Western_Wood_Pewee', 'Sayornis', 'American_Pipit',
|
||||
'Whip_poor_Will', 'Horned_Puffin', 'Common_Raven',
|
||||
'White_necked_Raven', 'American_Redstart', 'Geococcyx',
|
||||
'Loggerhead_Shrike', 'Great_Grey_Shrike', 'Baird_Sparrow',
|
||||
'Black_throated_Sparrow', 'Brewer_Sparrow', 'Chipping_Sparrow',
|
||||
'Clay_colored_Sparrow', 'House_Sparrow', 'Field_Sparrow',
|
||||
'Fox_Sparrow', 'Grasshopper_Sparrow', 'Harris_Sparrow',
|
||||
'Henslow_Sparrow', 'Le_Conte_Sparrow', 'Lincoln_Sparrow',
|
||||
'Nelson_Sharp_tailed_Sparrow', 'Savannah_Sparrow', 'Seaside_Sparrow',
|
||||
'Song_Sparrow', 'Tree_Sparrow', 'Vesper_Sparrow',
|
||||
'White_crowned_Sparrow', 'White_throated_Sparrow',
|
||||
'Cape_Glossy_Starling', 'Bank_Swallow', 'Barn_Swallow',
|
||||
'Cliff_Swallow', 'Tree_Swallow', 'Scarlet_Tanager', 'Summer_Tanager',
|
||||
'Artic_Tern', 'Black_Tern', 'Caspian_Tern', 'Common_Tern',
|
||||
'Elegant_Tern', 'Forsters_Tern', 'Least_Tern', 'Green_tailed_Towhee',
|
||||
'Brown_Thrasher', 'Sage_Thrasher', 'Black_capped_Vireo',
|
||||
'Blue_headed_Vireo', 'Philadelphia_Vireo', 'Red_eyed_Vireo',
|
||||
'Warbling_Vireo', 'White_eyed_Vireo', 'Yellow_throated_Vireo',
|
||||
'Bay_breasted_Warbler', 'Black_and_white_Warbler',
|
||||
'Black_throated_Blue_Warbler', 'Blue_winged_Warbler', 'Canada_Warbler',
|
||||
'Cape_May_Warbler', 'Cerulean_Warbler', 'Chestnut_sided_Warbler',
|
||||
'Golden_winged_Warbler', 'Hooded_Warbler', 'Kentucky_Warbler',
|
||||
'Magnolia_Warbler', 'Mourning_Warbler', 'Myrtle_Warbler',
|
||||
'Nashville_Warbler', 'Orange_crowned_Warbler', 'Palm_Warbler',
|
||||
'Pine_Warbler', 'Prairie_Warbler', 'Prothonotary_Warbler',
|
||||
'Swainson_Warbler', 'Tennessee_Warbler', 'Wilson_Warbler',
|
||||
'Worm_eating_Warbler', 'Yellow_Warbler', 'Northern_Waterthrush',
|
||||
'Louisiana_Waterthrush', 'Bohemian_Waxwing', 'Cedar_Waxwing',
|
||||
'American_Three_toed_Woodpecker', 'Pileated_Woodpecker',
|
||||
'Red_bellied_Woodpecker', 'Red_cockaded_Woodpecker',
|
||||
'Red_headed_Woodpecker', 'Downy_Woodpecker', 'Bewick_Wren',
|
||||
'Cactus_Wren', 'Carolina_Wren', 'House_Wren', 'Marsh_Wren',
|
||||
'Rock_Wren', 'Winter_Wren', 'Common_Yellowthroat'
|
||||
]
|
||||
|
||||
def __init__(self, *args, ann_file, image_class_labels_file,
             train_test_split_file, **kwargs):
    """Store the CUB-specific file paths and run the base initializer.

    Args:
        ann_file (str): Path of the image list file.
        image_class_labels_file (str): Path of the per-image label file.
        train_test_split_file (str): Path of the per-image split file.
        *args, **kwargs: Forwarded unchanged to the parent initializer.
    """
    # Keep these assignments ahead of the parent call: ``load_annotations``
    # reads both attributes and is presumably triggered while the parent
    # initializer runs.
    self.train_test_split_file = train_test_split_file
    self.image_class_labels_file = image_class_labels_file
    super().__init__(*args, ann_file=ann_file, **kwargs)
|
||||
|
||||
def load_annotations(self):
    """Build the sample list of the active split from the three CUB files.

    ``self.ann_file``, ``self.image_class_labels_file`` and
    ``self.train_test_split_file`` are read line by line. From every
    non-blank line the second whitespace-separated field is taken (file
    name, 1-based class id and split flag respectively) and the three
    files are matched by line position.

    Returns:
        list[dict]: One dict per kept sample with the keys ``img_prefix``
        (``self.data_prefix``), ``img_info`` (``{'filename': ...}``) and
        ``gt_label`` (0-based label as an ``np.int64`` array). Samples
        with a non-zero split flag are kept when ``self.test_mode`` is
        false; the remaining ones are kept when it is true.

    Raises:
        AssertionError: If the three files do not yield the same number
            of entries.
    """

    def read_second_column(path):
        # ``split()`` without arguments tolerates tabs and repeated
        # spaces, and blank lines are skipped, so a stray empty line no
        # longer ends in an IndexError.
        with open(path) as f:
            return [line.split()[1] for line in f if line.strip()]

    samples = read_second_column(self.ann_file)
    # in the official CUB-200-2011 dataset, labels in
    # image_class_labels_file are started from 1, so
    # here we need to '- 1' to let them start from 0.
    gt_labels = [
        int(x) - 1
        for x in read_second_column(self.image_class_labels_file)
    ]
    splits = [
        int(x) for x in read_second_column(self.train_test_split_file)
    ]

    if not len(samples) == len(gt_labels) == len(splits):
        # Raised explicitly instead of through an ``assert`` statement so
        # the check also runs under ``python -O``; the exception type that
        # callers and tests see is unchanged.
        raise AssertionError(
            f'samples({len(samples)}), gt_labels({len(gt_labels)}) and '
            f'splits({len(splits)}) should have same length.')

    # A non-zero split flag marks a training sample.
    keep_train = not self.test_mode
    data_infos = []
    for filename, gt_label, split in zip(samples, gt_labels, splits):
        if bool(split) != keep_train:
            # sample belongs to the other split
            continue
        data_infos.append({
            'img_prefix': self.data_prefix,
            'img_info': {'filename': filename},
            'gt_label': np.array(gt_label, dtype=np.int64),
        })
    return data_infos
|
@ -1,18 +1,18 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.datasets import (DATASETS, BaseDataset, ImageNet21k,
|
||||
from mmcls.datasets import (CUB, DATASETS, BaseDataset, ImageNet21k,
|
||||
MultiLabelDataset)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset_name', [
|
||||
'MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'ImageNet', 'VOC',
|
||||
'ImageNet21k'
|
||||
'ImageNet21k', 'CUB'
|
||||
])
|
||||
def test_datasets_override_default(dataset_name):
|
||||
dataset_class = DATASETS.get(dataset_name)
|
||||
@ -27,6 +27,15 @@ def test_datasets_override_default(dataset_name):
|
||||
|
||||
original_classes = dataset_class.CLASSES
|
||||
|
||||
# some datasets need extra argument to init
|
||||
extra_kwargs_settings = {
|
||||
'CUB':
|
||||
dict(
|
||||
ann_file=None,
|
||||
image_class_labels_file=None,
|
||||
train_test_split_file=None),
|
||||
}
|
||||
extra_kwargs = extra_kwargs_settings.get(dataset_name, dict())
|
||||
# Test VOC year
|
||||
if dataset_name == 'VOC':
|
||||
dataset = dataset_class(
|
||||
@ -47,7 +56,8 @@ def test_datasets_override_default(dataset_name):
|
||||
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
|
||||
pipeline=[],
|
||||
classes=('bus', 'car'),
|
||||
test_mode=True)
|
||||
test_mode=True,
|
||||
**extra_kwargs)
|
||||
assert dataset.CLASSES == ('bus', 'car')
|
||||
|
||||
# Test get_cat_ids
|
||||
@ -61,17 +71,21 @@ def test_datasets_override_default(dataset_name):
|
||||
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
|
||||
pipeline=[],
|
||||
classes=['bus', 'car'],
|
||||
test_mode=True)
|
||||
test_mode=True,
|
||||
**extra_kwargs)
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
# Test setting classes through a file
|
||||
classes_file = osp.join(
|
||||
osp.dirname(__file__), '../../data/dataset/classes.txt')
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
with open(tmp_file.name, 'w') as f:
|
||||
f.write('bus\ncar\n')
|
||||
dataset = dataset_class(
|
||||
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
|
||||
pipeline=[],
|
||||
classes=classes_file,
|
||||
test_mode=True)
|
||||
classes=tmp_file.name,
|
||||
test_mode=True,
|
||||
**extra_kwargs)
|
||||
tmp_file.close()
|
||||
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
@ -80,12 +94,15 @@ def test_datasets_override_default(dataset_name):
|
||||
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
|
||||
pipeline=[],
|
||||
classes=['foo'],
|
||||
test_mode=True)
|
||||
test_mode=True,
|
||||
**extra_kwargs)
|
||||
assert dataset.CLASSES == ['foo']
|
||||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(
|
||||
data_prefix='VOC2007' if dataset_name == 'VOC' else '', pipeline=[])
|
||||
data_prefix='VOC2007' if dataset_name == 'VOC' else '',
|
||||
pipeline=[],
|
||||
**extra_kwargs)
|
||||
|
||||
if dataset_name == 'VOC':
|
||||
assert dataset.data_prefix == 'VOC2007'
|
||||
@ -308,3 +325,54 @@ def test_dataset_imagenet21k():
|
||||
dataset = ImageNet21k(**dataset_cfg)
|
||||
assert len(dataset) == 3
|
||||
assert isinstance(dataset[0], dict)
|
||||
|
||||
|
||||
def test_dataset_cub():
    """Check CUB split filtering and the file-length consistency check."""
    # A temporary directory is used instead of NamedTemporaryFile objects:
    # reopening a still-open named temporary file by name is not portable
    # (it fails on Windows), and the context manager removes the files even
    # when an assertion below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ann_file = osp.join(tmp_dir, 'images.txt')
        image_class_labels_file = osp.join(tmp_dir,
                                           'image_class_labels.txt')
        train_test_split_file = osp.join(tmp_dir, 'train_test_split.txt')

        with open(ann_file, 'w') as f:
            f.write('1 1.txt \n2 2.txt \n')
        with open(image_class_labels_file, 'w') as f:
            f.write('1 1 \n2 2 \n')
        # one test sample (flag 0) and one train sample (flag 1)
        with open(train_test_split_file, 'w') as f:
            f.write('1 0 \n2 1 \n')

        dataset_cfg = dict(
            data_prefix='',
            pipeline=[],
            ann_file=ann_file,
            image_class_labels_file=image_class_labels_file,
            train_test_split_file=train_test_split_file)

        # test in train mode
        dataset = CUB(test_mode=False, **dataset_cfg)
        assert len(dataset) == 1

        # test in test mode
        dataset = CUB(test_mode=True, **dataset_cfg)
        assert len(dataset) == 1

        # test with different items in three files
        with open(train_test_split_file, 'w') as f:
            f.write('1 0 \n')
        with pytest.raises(AssertionError, match='should have same length'):
            CUB(test_mode=True, **dataset_cfg)
|
||||
|
Loading…
x
Reference in New Issue
Block a user