mirror of
https://github.com/open-mmlab/mmfewshot.git
synced 2025-06-03 14:49:43 +08:00
* fix init * fix test api fix test api bug * add metarcnn fsdetview config * add pr * add metatestparallel comments * add test code and fix typos * add test code of model frozen * update test det forward code * update pr * update doc str
63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest.mock import MagicMock, patch

from mmfewshot.classification.datasets import (CUBDataset, MiniImageNetDataset,
                                               TieredImageNetDataset)
|
|
|
|
|
|
def test_cub_dataset():
    """Check the number of classes CUBDataset exposes for each subset.

    ``load_annotations`` is replaced by a mock that returns an empty list,
    so only the ``CLASSES`` attribute selected by ``subset`` is exercised.
    """
    # (subset name, expected number of classes)
    expected = (('train', 100), ('val', 50), ('test', 50))
    # patch.object restores the real ``load_annotations`` when the block
    # exits, so the mock does not leak into other tests that construct
    # CUBDataset later in the same session.
    with patch.object(CUBDataset, 'load_annotations',
                      MagicMock(return_value=[])):
        for subset, num_classes in expected:
            dataset = CUBDataset(
                data_prefix='',
                subset=subset,
                pipeline=[dict(type='LoadImageFromFile')])
            assert len(dataset.CLASSES) == num_classes, \
                f'unexpected number of classes for subset {subset!r}'
|
|
|
|
|
|
def test_mini_imagenet_dataset():
    """Check the number of classes MiniImageNetDataset exposes per subset.

    ``load_annotations`` is replaced by a mock that returns an empty list,
    so only the ``CLASSES`` attribute selected by ``subset`` is exercised.
    """
    # (subset name, expected number of classes)
    expected = (('train', 64), ('val', 16), ('test', 20))
    # patch.object restores the real ``load_annotations`` when the block
    # exits, so the mock does not leak into other tests that construct
    # MiniImageNetDataset later in the same session.
    with patch.object(MiniImageNetDataset, 'load_annotations',
                      MagicMock(return_value=[])):
        for subset, num_classes in expected:
            dataset = MiniImageNetDataset(
                data_prefix='',
                subset=subset,
                pipeline=[dict(type='LoadImageFromFile')])
            assert len(dataset.CLASSES) == num_classes, \
                f'unexpected number of classes for subset {subset!r}'
|
|
|
|
|
|
def test_tiered_imagenet_dataset():
    """Check the number of classes TieredImageNetDataset exposes per subset.

    ``load_annotations`` is replaced by a mock that returns an empty list,
    so only the ``CLASSES`` attribute selected by ``subset`` is exercised.
    """
    # (subset name, expected number of classes)
    expected = (('train', 351), ('val', 97), ('test', 160))
    # patch.object restores the real ``load_annotations`` when the block
    # exits, so the mock does not leak into other tests that construct
    # TieredImageNetDataset later in the same session.
    with patch.object(TieredImageNetDataset, 'load_annotations',
                      MagicMock(return_value=[])):
        for subset, num_classes in expected:
            dataset = TieredImageNetDataset(
                data_prefix='',
                subset=subset,
                pipeline=[dict(type='LoadImageFromFile')])
            assert len(dataset.CLASSES) == num_classes, \
                f'unexpected number of classes for subset {subset!r}'
|