[Feature] Support Stanford Cars dataset. (#893)
* feat: add stanford car dataset * feat: add stanford car dataset * feat: add stanford car dataset * feat: add stanford car dataset * feat: add stanford car dataset * feat: add stanford car dataset * Update links and using cars insteam of car * place ependency scipy from runtime to optional * Fix docstring Co-authored-by: Ezra-Yu <1105212286@qq.com> Co-authored-by: mzr1996 <mzr1996@163.com>pull/976/head
parent
e54cfd6951
commit
7b16bcdd9b
|
@ -0,0 +1,46 @@
|
|||
# dataset settings
|
||||
dataset_type = 'StanfordCars'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', size=512),
|
||||
dict(type='RandomCrop', size=448),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', size=512),
|
||||
dict(type='CenterCrop', crop_size=448),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
|
||||
data_root = 'data/stanfordcars'
|
||||
data = dict(
|
||||
samples_per_gpu=8,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix=data_root,
|
||||
test_mode=False,
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix=data_root,
|
||||
test_mode=True,
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_prefix=data_root,
|
||||
test_mode=True,
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(
|
||||
interval=1, metric='accuracy',
|
||||
save_best='auto') # save the checkpoint with highest accuracy
|
|
@ -0,0 +1,7 @@
|
|||
# optimizer
|
||||
optimizer = dict(
|
||||
type='SGD', lr=0.003, momentum=0.9, weight_decay=0.0005, nesterov=True)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[40, 70, 90])
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
|
@ -72,6 +72,12 @@ The pre-trained models on ImageNet-21k are used to fine-tune, and therefore don'
|
|||
| :-------: | :--------------------------------------------------: | :--------: | :-------: | :------: | :-------: | :------------------------------------------------: | :---------------------------------------------------: |
|
||||
| ResNet-50 | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 448x448 | 23.92 | 16.48 | 88.45 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb8_cub.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.log.json) |
|
||||
|
||||
### Stanford-Cars
|
||||
|
||||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Config | Download |
|
||||
| :-------: | :--------------------------------------------------: | :--------: | :-------: | :------: | :-------: | :------------------------------------------------: | :---------------------------------------------------: |
|
||||
| ResNet-50 | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 448x448 | 23.92 | 16.48 | 92.82 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb8_cars.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cars_20220812-9d85901a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cars_20220812-9d85901a.log.json) |
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
|
|
|
@ -350,3 +350,16 @@ Models:
|
|||
Pretrain: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cub_20220307-57840e60.pth
|
||||
Config: configs/resnet/resnet50_8xb8_cub.py
|
||||
- Name: resnet50_8xb8_cars
|
||||
Metadata:
|
||||
FLOPs: 16480000000
|
||||
Parameters: 23920000
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Dataset: StanfordCars
|
||||
Metrics:
|
||||
Top 1 Accuracy: 92.82
|
||||
Task: Image Classification
|
||||
Pretrain: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb8_cars_20220812-9d85901a.pth
|
||||
Config: configs/resnet/resnet50_8xb8_cars.py
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnet50.py',
|
||||
'../_base_/datasets/stanford_cars_bs8_448.py',
|
||||
'../_base_/schedules/stanford_cars_bs8.py', '../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# use pre-train weight converted from https://github.com/Alibaba-MIIL/ImageNet21K # noqa
|
||||
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
init_cfg=dict(
|
||||
type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
|
||||
head=dict(num_classes=196, ))
|
||||
|
||||
log_config = dict(interval=50)
|
||||
checkpoint_config = dict(
|
||||
interval=1, max_keep_ckpts=3) # save last three checkpoints
|
|
@ -39,6 +39,11 @@ VOC
|
|||
|
||||
.. autoclass:: VOC
|
||||
|
||||
StanfordCars Cars
|
||||
-----------------
|
||||
|
||||
.. autoclass:: StanfordCars
|
||||
|
||||
Base classes
|
||||
------------
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from .imagenet21k import ImageNet21k
|
|||
from .mnist import MNIST, FashionMNIST
|
||||
from .multi_label import MultiLabelDataset
|
||||
from .samplers import DistributedSampler, RepeatAugSampler
|
||||
from .stanford_cars import StanfordCars
|
||||
from .voc import VOC
|
||||
|
||||
__all__ = [
|
||||
|
@ -19,5 +20,6 @@ __all__ = [
|
|||
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
|
||||
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
|
||||
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
|
||||
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset'
|
||||
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB',
|
||||
'CustomDataset', 'StanfordCars'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,210 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .builder import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class StanfordCars(BaseDataset):
|
||||
"""`Stanford Cars`_ Dataset.
|
||||
|
||||
After downloading and decompression, the dataset
|
||||
directory structure is as follows.
|
||||
|
||||
Stanford Cars dataset directory::
|
||||
|
||||
Stanford Cars
|
||||
├── cars_train
|
||||
│ ├── 00001.jpg
|
||||
│ ├── 00002.jpg
|
||||
│ └── ...
|
||||
├── cars_test
|
||||
│ ├── 00001.jpg
|
||||
│ ├── 00002.jpg
|
||||
│ └── ...
|
||||
└── devkit
|
||||
├── cars_meta.mat
|
||||
├── cars_train_annos.mat
|
||||
├── cars_test_annos.mat
|
||||
├── cars_test_annoswithlabels.mat
|
||||
├── eval_train.m
|
||||
└── train_perfect_preds.txt
|
||||
|
||||
.. _Stanford Cars: https://ai.stanford.edu/~jkrause/cars/car_dataset.html
|
||||
|
||||
Args:
|
||||
data_prefix (str): the prefix of data path
|
||||
test_mode (bool): ``test_mode=True`` means in test phase. It determines
|
||||
to use the training set or test set.
|
||||
ann_file (str, optional): The annotation file. If is string, read
|
||||
samples paths from the ann_file. If is None, read samples path
|
||||
from cars_{train|test}_annos.mat file. Defaults to None.
|
||||
""" # noqa: E501
|
||||
|
||||
CLASSES = [
|
||||
'AM General Hummer SUV 2000', 'Acura RL Sedan 2012',
|
||||
'Acura TL Sedan 2012', 'Acura TL Type-S 2008', 'Acura TSX Sedan 2012',
|
||||
'Acura Integra Type R 2001', 'Acura ZDX Hatchback 2012',
|
||||
'Aston Martin V8 Vantage Convertible 2012',
|
||||
'Aston Martin V8 Vantage Coupe 2012',
|
||||
'Aston Martin Virage Convertible 2012',
|
||||
'Aston Martin Virage Coupe 2012', 'Audi RS 4 Convertible 2008',
|
||||
'Audi A5 Coupe 2012', 'Audi TTS Coupe 2012', 'Audi R8 Coupe 2012',
|
||||
'Audi V8 Sedan 1994', 'Audi 100 Sedan 1994', 'Audi 100 Wagon 1994',
|
||||
'Audi TT Hatchback 2011', 'Audi S6 Sedan 2011',
|
||||
'Audi S5 Convertible 2012', 'Audi S5 Coupe 2012', 'Audi S4 Sedan 2012',
|
||||
'Audi S4 Sedan 2007', 'Audi TT RS Coupe 2012',
|
||||
'BMW ActiveHybrid 5 Sedan 2012', 'BMW 1 Series Convertible 2012',
|
||||
'BMW 1 Series Coupe 2012', 'BMW 3 Series Sedan 2012',
|
||||
'BMW 3 Series Wagon 2012', 'BMW 6 Series Convertible 2007',
|
||||
'BMW X5 SUV 2007', 'BMW X6 SUV 2012', 'BMW M3 Coupe 2012',
|
||||
'BMW M5 Sedan 2010', 'BMW M6 Convertible 2010', 'BMW X3 SUV 2012',
|
||||
'BMW Z4 Convertible 2012',
|
||||
'Bentley Continental Supersports Conv. Convertible 2012',
|
||||
'Bentley Arnage Sedan 2009', 'Bentley Mulsanne Sedan 2011',
|
||||
'Bentley Continental GT Coupe 2012',
|
||||
'Bentley Continental GT Coupe 2007',
|
||||
'Bentley Continental Flying Spur Sedan 2007',
|
||||
'Bugatti Veyron 16.4 Convertible 2009',
|
||||
'Bugatti Veyron 16.4 Coupe 2009', 'Buick Regal GS 2012',
|
||||
'Buick Rainier SUV 2007', 'Buick Verano Sedan 2012',
|
||||
'Buick Enclave SUV 2012', 'Cadillac CTS-V Sedan 2012',
|
||||
'Cadillac SRX SUV 2012', 'Cadillac Escalade EXT Crew Cab 2007',
|
||||
'Chevrolet Silverado 1500 Hybrid Crew Cab 2012',
|
||||
'Chevrolet Corvette Convertible 2012', 'Chevrolet Corvette ZR1 2012',
|
||||
'Chevrolet Corvette Ron Fellows Edition Z06 2007',
|
||||
'Chevrolet Traverse SUV 2012', 'Chevrolet Camaro Convertible 2012',
|
||||
'Chevrolet HHR SS 2010', 'Chevrolet Impala Sedan 2007',
|
||||
'Chevrolet Tahoe Hybrid SUV 2012', 'Chevrolet Sonic Sedan 2012',
|
||||
'Chevrolet Express Cargo Van 2007',
|
||||
'Chevrolet Avalanche Crew Cab 2012', 'Chevrolet Cobalt SS 2010',
|
||||
'Chevrolet Malibu Hybrid Sedan 2010', 'Chevrolet TrailBlazer SS 2009',
|
||||
'Chevrolet Silverado 2500HD Regular Cab 2012',
|
||||
'Chevrolet Silverado 1500 Classic Extended Cab 2007',
|
||||
'Chevrolet Express Van 2007', 'Chevrolet Monte Carlo Coupe 2007',
|
||||
'Chevrolet Malibu Sedan 2007',
|
||||
'Chevrolet Silverado 1500 Extended Cab 2012',
|
||||
'Chevrolet Silverado 1500 Regular Cab 2012', 'Chrysler Aspen SUV 2009',
|
||||
'Chrysler Sebring Convertible 2010',
|
||||
'Chrysler Town and Country Minivan 2012', 'Chrysler 300 SRT-8 2010',
|
||||
'Chrysler Crossfire Convertible 2008',
|
||||
'Chrysler PT Cruiser Convertible 2008', 'Daewoo Nubira Wagon 2002',
|
||||
'Dodge Caliber Wagon 2012', 'Dodge Caliber Wagon 2007',
|
||||
'Dodge Caravan Minivan 1997', 'Dodge Ram Pickup 3500 Crew Cab 2010',
|
||||
'Dodge Ram Pickup 3500 Quad Cab 2009', 'Dodge Sprinter Cargo Van 2009',
|
||||
'Dodge Journey SUV 2012', 'Dodge Dakota Crew Cab 2010',
|
||||
'Dodge Dakota Club Cab 2007', 'Dodge Magnum Wagon 2008',
|
||||
'Dodge Challenger SRT8 2011', 'Dodge Durango SUV 2012',
|
||||
'Dodge Durango SUV 2007', 'Dodge Charger Sedan 2012',
|
||||
'Dodge Charger SRT-8 2009', 'Eagle Talon Hatchback 1998',
|
||||
'FIAT 500 Abarth 2012', 'FIAT 500 Convertible 2012',
|
||||
'Ferrari FF Coupe 2012', 'Ferrari California Convertible 2012',
|
||||
'Ferrari 458 Italia Convertible 2012', 'Ferrari 458 Italia Coupe 2012',
|
||||
'Fisker Karma Sedan 2012', 'Ford F-450 Super Duty Crew Cab 2012',
|
||||
'Ford Mustang Convertible 2007', 'Ford Freestar Minivan 2007',
|
||||
'Ford Expedition EL SUV 2009', 'Ford Edge SUV 2012',
|
||||
'Ford Ranger SuperCab 2011', 'Ford GT Coupe 2006',
|
||||
'Ford F-150 Regular Cab 2012', 'Ford F-150 Regular Cab 2007',
|
||||
'Ford Focus Sedan 2007', 'Ford E-Series Wagon Van 2012',
|
||||
'Ford Fiesta Sedan 2012', 'GMC Terrain SUV 2012',
|
||||
'GMC Savana Van 2012', 'GMC Yukon Hybrid SUV 2012',
|
||||
'GMC Acadia SUV 2012', 'GMC Canyon Extended Cab 2012',
|
||||
'Geo Metro Convertible 1993', 'HUMMER H3T Crew Cab 2010',
|
||||
'HUMMER H2 SUT Crew Cab 2009', 'Honda Odyssey Minivan 2012',
|
||||
'Honda Odyssey Minivan 2007', 'Honda Accord Coupe 2012',
|
||||
'Honda Accord Sedan 2012', 'Hyundai Veloster Hatchback 2012',
|
||||
'Hyundai Santa Fe SUV 2012', 'Hyundai Tucson SUV 2012',
|
||||
'Hyundai Veracruz SUV 2012', 'Hyundai Sonata Hybrid Sedan 2012',
|
||||
'Hyundai Elantra Sedan 2007', 'Hyundai Accent Sedan 2012',
|
||||
'Hyundai Genesis Sedan 2012', 'Hyundai Sonata Sedan 2012',
|
||||
'Hyundai Elantra Touring Hatchback 2012', 'Hyundai Azera Sedan 2012',
|
||||
'Infiniti G Coupe IPL 2012', 'Infiniti QX56 SUV 2011',
|
||||
'Isuzu Ascender SUV 2008', 'Jaguar XK XKR 2012',
|
||||
'Jeep Patriot SUV 2012', 'Jeep Wrangler SUV 2012',
|
||||
'Jeep Liberty SUV 2012', 'Jeep Grand Cherokee SUV 2012',
|
||||
'Jeep Compass SUV 2012', 'Lamborghini Reventon Coupe 2008',
|
||||
'Lamborghini Aventador Coupe 2012',
|
||||
'Lamborghini Gallardo LP 570-4 Superleggera 2012',
|
||||
'Lamborghini Diablo Coupe 2001', 'Land Rover Range Rover SUV 2012',
|
||||
'Land Rover LR2 SUV 2012', 'Lincoln Town Car Sedan 2011',
|
||||
'MINI Cooper Roadster Convertible 2012',
|
||||
'Maybach Landaulet Convertible 2012', 'Mazda Tribute SUV 2011',
|
||||
'McLaren MP4-12C Coupe 2012',
|
||||
'Mercedes-Benz 300-Class Convertible 1993',
|
||||
'Mercedes-Benz C-Class Sedan 2012',
|
||||
'Mercedes-Benz SL-Class Coupe 2009',
|
||||
'Mercedes-Benz E-Class Sedan 2012', 'Mercedes-Benz S-Class Sedan 2012',
|
||||
'Mercedes-Benz Sprinter Van 2012', 'Mitsubishi Lancer Sedan 2012',
|
||||
'Nissan Leaf Hatchback 2012', 'Nissan NV Passenger Van 2012',
|
||||
'Nissan Juke Hatchback 2012', 'Nissan 240SX Coupe 1998',
|
||||
'Plymouth Neon Coupe 1999', 'Porsche Panamera Sedan 2012',
|
||||
'Ram C/V Cargo Van Minivan 2012',
|
||||
'Rolls-Royce Phantom Drophead Coupe Convertible 2012',
|
||||
'Rolls-Royce Ghost Sedan 2012', 'Rolls-Royce Phantom Sedan 2012',
|
||||
'Scion xD Hatchback 2012', 'Spyker C8 Convertible 2009',
|
||||
'Spyker C8 Coupe 2009', 'Suzuki Aerio Sedan 2007',
|
||||
'Suzuki Kizashi Sedan 2012', 'Suzuki SX4 Hatchback 2012',
|
||||
'Suzuki SX4 Sedan 2012', 'Tesla Model S Sedan 2012',
|
||||
'Toyota Sequoia SUV 2012', 'Toyota Camry Sedan 2012',
|
||||
'Toyota Corolla Sedan 2012', 'Toyota 4Runner SUV 2012',
|
||||
'Volkswagen Golf Hatchback 2012', 'Volkswagen Golf Hatchback 1991',
|
||||
'Volkswagen Beetle Hatchback 2012', 'Volvo C30 Hatchback 2012',
|
||||
'Volvo 240 Sedan 1993', 'Volvo XC90 SUV 2007',
|
||||
'smart fortwo Convertible 2012'
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
data_prefix: str,
|
||||
test_mode: bool,
|
||||
ann_file: Optional[str] = None,
|
||||
**kwargs):
|
||||
if test_mode:
|
||||
if ann_file is not None:
|
||||
self.test_ann_file = ann_file
|
||||
else:
|
||||
self.test_ann_file = osp.join(
|
||||
data_prefix, 'devkit/cars_test_annos_withlabels.mat')
|
||||
data_prefix = osp.join(data_prefix, 'cars_test')
|
||||
else:
|
||||
if ann_file is not None:
|
||||
self.train_ann_file = ann_file
|
||||
else:
|
||||
self.train_ann_file = osp.join(data_prefix,
|
||||
'devkit/cars_train_annos.mat')
|
||||
data_prefix = osp.join(data_prefix, 'cars_train')
|
||||
super(StanfordCars, self).__init__(
|
||||
ann_file=ann_file,
|
||||
data_prefix=data_prefix,
|
||||
test_mode=test_mode,
|
||||
**kwargs)
|
||||
|
||||
def load_annotations(self):
|
||||
try:
|
||||
import scipy.io as sio
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'please run `pip install scipy` to install package `scipy`.')
|
||||
|
||||
data_infos = []
|
||||
if self.test_mode:
|
||||
data = sio.loadmat(self.test_ann_file)
|
||||
else:
|
||||
data = sio.loadmat(self.train_ann_file)
|
||||
for img in data['annotations'][0]:
|
||||
info = {'img_prefix': self.data_prefix}
|
||||
# The organization of each record is as follows,
|
||||
# 0: bbox_x1 of each image
|
||||
# 1: bbox_y1 of each image
|
||||
# 2: bbox_x2 of each image
|
||||
# 3: bbox_y2 of each image
|
||||
# 4: class_id, start from 0, so
|
||||
# here we need to '- 1' to let them start from 0
|
||||
# 5: file name of each image
|
||||
info['img_info'] = {'filename': img[5][0]}
|
||||
info['gt_label'] = np.array(img[4][0][0] - 1, dtype=np.int64)
|
||||
data_infos.append(info)
|
||||
return data_infos
|
|
@ -2,3 +2,4 @@ albumentations>=0.3.2 --no-binary qudida,albumentations
|
|||
colorama
|
||||
requests
|
||||
rich
|
||||
scipy
|
||||
|
|
|
@ -761,3 +761,151 @@ class TestCUB(TestBaseDataset):
|
|||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
|
||||
|
||||
class TestStanfordCars(TestBaseDataset):
|
||||
DATASET_TYPE = 'StanfordCars'
|
||||
|
||||
def test_initialize(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
with patch.object(dataset_class, 'load_annotations'):
|
||||
# Test with test_mode=False, ann_file is None
|
||||
cfg = {**self.DEFAULT_ARGS, 'test_mode': False, 'ann_file': None}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
|
||||
self.assertFalse(dataset.test_mode)
|
||||
self.assertIsNone(dataset.ann_file)
|
||||
self.assertIsNotNone(dataset.train_ann_file)
|
||||
|
||||
# Test with test_mode=False, ann_file is not None
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS, 'test_mode': False,
|
||||
'ann_file': 'train_ann_file.mat'
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
|
||||
self.assertFalse(dataset.test_mode)
|
||||
self.assertIsNotNone(dataset.ann_file)
|
||||
self.assertEqual(dataset.ann_file, 'train_ann_file.mat')
|
||||
self.assertIsNotNone(dataset.train_ann_file)
|
||||
|
||||
# Test with test_mode=True, ann_file is None
|
||||
cfg = {**self.DEFAULT_ARGS, 'test_mode': True, 'ann_file': None}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
|
||||
self.assertTrue(dataset.test_mode)
|
||||
self.assertIsNone(dataset.ann_file)
|
||||
self.assertIsNotNone(dataset.test_ann_file)
|
||||
|
||||
# Test with test_mode=True, ann_file is not None
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS, 'test_mode': True,
|
||||
'ann_file': 'test_ann_file.mat'
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
|
||||
self.assertTrue(dataset.test_mode)
|
||||
self.assertIsNotNone(dataset.ann_file)
|
||||
self.assertEqual(dataset.ann_file, 'test_ann_file.mat')
|
||||
self.assertIsNotNone(dataset.test_ann_file)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
|
||||
tmpdir = tempfile.TemporaryDirectory()
|
||||
cls.tmpdir = tmpdir
|
||||
cls.data_prefix = tmpdir.name
|
||||
cls.ann_file = None
|
||||
devkit = osp.join(cls.data_prefix, 'devkit')
|
||||
if not osp.exists(devkit):
|
||||
os.mkdir(devkit)
|
||||
cls.train_ann_file = osp.join(devkit, 'cars_train_annos.mat')
|
||||
cls.test_ann_file = osp.join(devkit, 'cars_test_annos_withlabels.mat')
|
||||
cls.DEFAULT_ARGS = dict(
|
||||
data_prefix=cls.data_prefix, pipeline=[], test_mode=False)
|
||||
|
||||
try:
|
||||
import scipy.io as sio
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'please run `pip install scipy` to install package `scipy`.')
|
||||
|
||||
sio.savemat(
|
||||
cls.train_ann_file, {
|
||||
'annotations': [(
|
||||
(np.array([1]), np.array([10]), np.array(
|
||||
[20]), np.array([50]), 15, np.array(['001.jpg'])),
|
||||
(np.array([2]), np.array([15]), np.array(
|
||||
[240]), np.array([250]), 15, np.array(['002.jpg'])),
|
||||
(np.array([89]), np.array([150]), np.array(
|
||||
[278]), np.array([388]), 150, np.array(['012.jpg'])),
|
||||
)]
|
||||
})
|
||||
|
||||
sio.savemat(
|
||||
cls.test_ann_file, {
|
||||
'annotations':
|
||||
[((np.array([89]), np.array([150]), np.array(
|
||||
[278]), np.array([388]), 150, np.array(['025.jpg'])),
|
||||
(np.array([155]), np.array([10]), np.array(
|
||||
[200]), np.array([233]), 0, np.array(['111.jpg'])),
|
||||
(np.array([25]), np.array([115]), np.array(
|
||||
[240]), np.array([360]), 15, np.array(['265.jpg'])))]
|
||||
})
|
||||
|
||||
def test_load_annotations(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test with test_mode=False and ann_file=None
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
self.assertEqual(dataset.CLASSES, dataset_class.CLASSES)
|
||||
|
||||
data_info = dataset[0]
|
||||
np.testing.assert_equal(data_info['img_prefix'],
|
||||
osp.join(self.data_prefix, 'cars_train'))
|
||||
np.testing.assert_equal(data_info['img_info'], {'filename': '001.jpg'})
|
||||
np.testing.assert_equal(data_info['gt_label'], 15 - 1)
|
||||
|
||||
# Test with test_mode=True and ann_file=None
|
||||
cfg = {**self.DEFAULT_ARGS, 'test_mode': True}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
|
||||
data_info = dataset[0]
|
||||
np.testing.assert_equal(data_info['img_prefix'],
|
||||
osp.join(self.data_prefix, 'cars_test'))
|
||||
np.testing.assert_equal(data_info['img_info'], {'filename': '025.jpg'})
|
||||
np.testing.assert_equal(data_info['gt_label'], 150 - 1)
|
||||
|
||||
# Test with test_mode=False, ann_file is not None
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS, 'test_mode': False,
|
||||
'ann_file': self.train_ann_file
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
data_info = dataset[0]
|
||||
np.testing.assert_equal(data_info['img_prefix'],
|
||||
osp.join(self.data_prefix, 'cars_train'))
|
||||
np.testing.assert_equal(data_info['img_info'], {'filename': '001.jpg'})
|
||||
np.testing.assert_equal(data_info['gt_label'], 15 - 1)
|
||||
|
||||
# Test with test_mode=True, ann_file is not None
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS, 'test_mode': True,
|
||||
'ann_file': self.test_ann_file
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
|
||||
data_info = dataset[0]
|
||||
np.testing.assert_equal(data_info['img_prefix'],
|
||||
osp.join(self.data_prefix, 'cars_test'))
|
||||
np.testing.assert_equal(data_info['img_info'], {'filename': '025.jpg'})
|
||||
np.testing.assert_equal(data_info['gt_label'], 150 - 1)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
|
|
Loading…
Reference in New Issue