implementation of datasets and dataloader (#2)

* implementation of datasets and dataloader

* add testing file

* fix MergeDataset flag bug and QueryAwareDataset one shot setting bug

* Set all flags to 0 in the three dataset wrappers

* Fix NwayKshotDataloader sampler bug and fix some review comments

* add pytest file

* add voc test data for pytest

* finish test file for few shot custom dataset

* finish test file for few shot custom dataset

* finish test file for few shot custom dataset

* finish test file for merge dataset

* finish test file nwaykshot dataset

* cover more corner cases in both datasets and add test files for query aware dataset and nwaykshot dataset

* finish test file of dataloader and fix all random seeds in test files

* remove config

* avoid ann info change

* fix voc comments

* fix voc comments

* Lyq dataset dataloader (#4)

* fix voc comments

* fix voc comments

* fix voc comments

* fix comment and refactoring FewShotCustomDataset FewShotVOCDataset

* add coco dataset and test file

* Lyq dataset dataloader (#6)

* fix comments

* fix comments

* Lyq dataset dataloader (#7)

* fix comments

* fix comments

* fix comments
pull/1/head
Linyiqi 2021-05-24 17:07:43 +08:00 committed by GitHub
parent 6f15e33ab9
commit 85933cb556
67 changed files with 4348 additions and 118 deletions

View File

@ -1,49 +0,0 @@
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')

View File

@ -0,0 +1,116 @@
ALL_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}
NOVEL_CLASSES = {
1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
BASE_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor')
}
# dataset settings
data_root = 'data/VOCdevkit/'
split = 1
all_classes = ALL_CLASSES[split]
base_classes = BASE_CLASSES[split]
novel_classes = NOVEL_CLASSES[split]
num_base_shot = 1
num_novel_shot = 1
# load few shot data:
# each ann file corresponds to one class
# all files should use the same image prefix
ann_file_root = 'data/few_shot_voc_split/'
ann_file_per_class = [] # file path
img_prefix_per_class = [] # image prefix
ann_shot_filter_per_class = [] # ann filter for each ann file
for class_name in base_classes:
ann_file_per_class.append(
ann_file_root +
f'{num_base_shot}shot/box_{num_base_shot}shot_{class_name}_train.txt')
img_prefix_per_class.append(data_root)
ann_shot_filter_per_class.append({class_name: num_base_shot})
for class_name in novel_classes:
ann_file_per_class.append(
ann_file_root +
f'{num_novel_shot}shot/box_{num_novel_shot}shot_{class_name}_train.txt'
)
img_prefix_per_class.append(data_root)
ann_shot_filter_per_class.append({class_name: num_novel_shot})
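# e.g. with split = 1 and num_base_shot = 1, the first entries will be
# ann_file_per_class[0]: 'data/few_shot_voc_split/1shot/box_1shot_aeroplane_train.txt'
# img_prefix_per_class[0]: 'data/VOCdevkit/'
# ann_shot_filter_per_class[0]: {'aeroplane': 1}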
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=[(1333, 480), (1333, 800)], keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='RepeatDataset',
times=3,
dataset=dict(
type='FewShotVOCDataset',
ann_file=ann_file_per_class,
img_prefix=img_prefix_per_class,
ann_shot_filter=ann_shot_filter_per_class,
pipeline=train_pipeline,
classes=all_classes,
merge_dataset=True)),
val=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=novel_classes),
test=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=novel_classes))
evaluation = dict(interval=1, metric='mAP')

View File

@ -0,0 +1,92 @@
ALL_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}
NOVEL_CLASSES = {
1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
BASE_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor')
}
# dataset settings
data_root = 'data/VOCdevkit/'
# few shot setting
split = 1
base_classes = BASE_CLASSES[split]
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=[(1333, 480), (1333, 800)], keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='RepeatDataset',
times=3,
dataset=dict(
type='VOCDataset',
ann_file=[
data_root + 'VOC2007/ImageSets/Main/trainval.txt',
data_root + 'VOC2012/ImageSets/Main/trainval.txt'
],
img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
pipeline=train_pipeline,
classes=base_classes)),
val=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=base_classes),
test=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=base_classes))
evaluation = dict(interval=1, metric='mAP')

View File

@ -0,0 +1,128 @@
ALL_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}
NOVEL_CLASSES = {
1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
BASE_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor')
}
# dataset settings
data_root = 'data/VOCdevkit/'
split = 1
all_classes = ALL_CLASSES[split]
base_classes = BASE_CLASSES[split]
novel_classes = NOVEL_CLASSES[split]
num_base_shot = 1
num_novel_shot = 1
# load few shot data:
# each ann file corresponds to one class
# all files should use the same image prefix
ann_file_root = 'data/few_shot_voc_split/'
ann_file_per_class = [] # file path
img_prefix_per_class = [] # image prefix
ann_shot_filter_per_class = [] # ann filter for each ann file
for class_name in base_classes:
ann_file_per_class.append(
ann_file_root +
f'{num_base_shot}shot/box_{num_base_shot}shot_{class_name}_train.txt')
img_prefix_per_class.append(data_root)
ann_shot_filter_per_class.append({class_name: num_base_shot})
for class_name in novel_classes:
ann_file_per_class.append(
ann_file_root +
f'{num_novel_shot}shot/box_{num_novel_shot}shot_{class_name}_train.txt'
)
img_prefix_per_class.append(data_root)
ann_shot_filter_per_class.append({class_name: num_novel_shot})
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
query=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
],
support=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1000, 600),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='NwayKshotDataset',
support_way=20,
support_shot=1,
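# each support batch holds support_way * support_shot = 20 annotated instances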
dataset=dict(
type='FewShotVOCDataset',
ann_file=ann_file_per_class,
img_prefix=img_prefix_per_class,
ann_shot_filter=ann_shot_filter_per_class,
pipeline=train_pipeline,
classes=all_classes,
merge_dataset=True)),
val=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=novel_classes),
test=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=novel_classes))
evaluation = dict(interval=10, metric='mAP')

View File

@ -0,0 +1,106 @@
ALL_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}
NOVEL_CLASSES = {
1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
BASE_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor')
}
# dataset settings
data_root = 'data/VOCdevkit/'
# few shot setting
split = 1
base_classes = BASE_CLASSES[split]
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
query=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
],
support=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1000, 600),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='NwayKshotDataset',
support_way=15,
support_shot=1,
dataset=dict(
type='FewShotVOCDataset',
ann_file=[
data_root + 'VOC2007/ImageSets/Main/trainval.txt',
data_root + 'VOC2012/ImageSets/Main/trainval.txt'
],
img_prefix=[data_root, data_root],
pipeline=train_pipeline,
classes=base_classes,
merge_dataset=True,
)),
val=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=base_classes),
test=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=base_classes))
evaluation = dict(interval=1, metric='mAP')

View File

@ -0,0 +1,128 @@
ALL_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}
NOVEL_CLASSES = {
1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
BASE_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor'),
3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
'tvmonitor')
}
# dataset settings
data_root = 'data/VOCdevkit/'
split = 1
all_classes = ALL_CLASSES[split]
base_classes = BASE_CLASSES[split]
novel_classes = NOVEL_CLASSES[split]
num_base_shot = 1
num_novel_shot = 1
# load few shot data:
# each ann file corresponds to one class
# all files should use the same image prefix
ann_file_root = 'data/few_shot_voc_split/'
ann_file_per_class = [] # file path
img_prefix_per_class = [] # image prefix
ann_shot_filter_per_class = [] # ann filter for each ann file
for class_name in base_classes:
ann_file_per_class.append(
ann_file_root +
f'{num_base_shot}shot/box_{num_base_shot}shot_{class_name}_train.txt')
img_prefix_per_class.append(data_root)
ann_shot_filter_per_class.append({class_name: num_base_shot})
for class_name in novel_classes:
ann_file_per_class.append(
ann_file_root +
f'{num_novel_shot}shot/box_{num_novel_shot}shot_{class_name}_train.txt'
)
img_prefix_per_class.append(data_root)
ann_shot_filter_per_class.append({class_name: num_novel_shot})
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
query=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
],
support=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1000, 600),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='QueryAwareDataset',
support_way=2,
support_shot=1,
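# support_way / support_shot control how many classes and how many instances
# per class are sampled as support for each query image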
dataset=dict(
type='FewShotVOCDataset',
ann_file=ann_file_per_class,
img_prefix=img_prefix_per_class,
ann_shot_filter=ann_shot_filter_per_class,
pipeline=train_pipeline,
classes=all_classes,
merge_dataset=True)),
val=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=novel_classes),
test=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=novel_classes))
evaluation = dict(interval=1, metric='mAP')

View File

@ -0,0 +1,164 @@
VOC_FEW_SHOT_SPLIT_ALL_CLASSES = {
1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: (
'aeroplane',
'bicycle',
'bird',
'bottle',
'bus',
'car',
'chair',
'cow',
'diningtable',
'dog',
'horse',
'person',
'pottedplant',
'train',
'tvmonitor',
'boat',
'cat',
'motorbike',
'sheep',
'sofa',
),
}
VOC_FEW_SHOT_SPLIT_NOVEL_CLASSES = {
1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
VOC_FEW_SHOT_SPLIT_BASE_CLASSES = {
1: (
'aeroplane',
'bicycle',
'boat',
'bottle',
'car',
'cat',
'chair',
'diningtable',
'dog',
'horse',
'person',
'pottedplant',
'sheep',
'train',
'tvmonitor',
),
2: (
'bicycle',
'bird',
'boat',
'bus',
'car',
'cat',
'chair',
'diningtable',
'dog',
'motorbike',
'person',
'pottedplant',
'sheep',
'train',
'tvmonitor',
),
3: (
'aeroplane',
'bicycle',
'bird',
'bottle',
'bus',
'car',
'chair',
'cow',
'diningtable',
'dog',
'horse',
'person',
'pottedplant',
'train',
'tvmonitor',
),
}
# dataset settings
data_root = 'data/VOCdevkit/'
# few shot setting
split = 1
base_classes = VOC_FEW_SHOT_SPLIT_BASE_CLASSES[split]
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
query=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
],
support=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
])
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1000, 600),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='QueryAwareDataset',
support_way=2,
support_shot=5,
dataset=dict(
type='FewShotVOCDataset',
ann_file=[
data_root + 'VOC2007/ImageSets/Main/trainval.txt',
data_root + 'VOC2012/ImageSets/Main/trainval.txt'
],
img_prefix=[data_root, data_root],
pipeline=train_pipeline,
classes=base_classes,
merge_dataset=True)),
val=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=base_classes),
test=dict(
type='VOCDataset',
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline,
classes=base_classes))
evaluation = dict(interval=1, metric='mAP')

View File

@ -1,55 +0,0 @@
# dataset settings
dataset_type = 'VOCDataset'
data_root = 'data/VOCdevkit/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1000, 600),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='RepeatDataset',
times=3,
dataset=dict(
type=dataset_type,
ann_file=[
data_root + 'VOC2007/ImageSets/Main/trainval.txt',
data_root + 'VOC2012/ImageSets/Main/trainval.txt'
],
img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
img_prefix=data_root + 'VOC2007/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='mAP')

View File

@ -3,7 +3,8 @@ import random
import numpy as np
import torch
from mmcls.apis.train import train_model as train_classifier
from mmdet.apis.train import train_detector
from mmfewshot.detection.apis.train import train_detector
def set_random_seed(seed, deterministic=False):

View File

@ -1,7 +1,11 @@
# this file is only for unit tests
from mmcls.datasets.builder import build_dataloader as build_cls_dataloader
from mmcls.datasets.builder import build_dataset as build_cls_dataset
from mmdet.datasets.builder import build_dataloader as build_det_dataloader
from mmdet.datasets.builder import build_dataset as build_det_dataset
from mmfewshot.detection.datasets.builder import \
build_dataloader as build_det_dataloader
from mmfewshot.detection.datasets.builder import \
build_dataset as build_det_dataset
def build_dataloader(dataset=None, task_type='mmdet', round_up=True, **kwargs):

View File

@ -1,3 +1,4 @@
# this file is only for unit tests
from mmcls.models.builder import build_classifier as build_cls_model
from mmdet.models.builder import build_detector as build_det_model

View File

@ -0,0 +1,3 @@
from .train import get_root_logger, set_random_seed, train_detector
__all__ = ['get_root_logger', 'set_random_seed', 'train_detector']

View File

@ -0,0 +1,170 @@
import random
import warnings
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import replace_ImageToTensor
from mmdet.utils import get_root_logger
from mmfewshot.detection.datasets import build_dataloader, build_dataset
def set_random_seed(seed, deterministic=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
logger = get_root_logger(cfg.log_level)
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
if 'imgs_per_gpu' in cfg.data:
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
logger.warning(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiment')
else:
logger.warning(
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiment')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
len(cfg.gpu_ids),
dist=distributed,
seed=cfg.seed) for ds in dataset
]
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
if 'runner' not in cfg:
cfg.runner = {
'type': 'EpochBasedRunner',
'max_epochs': cfg.total_epochs
}
warnings.warn(
'config is now expected to have a `runner` section, '
'please set `runner` in your config.', UserWarning)
else:
if 'total_epochs' in cfg:
assert cfg.total_epochs == cfg.runner.max_epochs
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
# an ugly workaround to make .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
if distributed:
if isinstance(runner, EpochBasedRunner):
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
# Support batch_size > 1 in validation
val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
if val_samples_per_gpu > 1:
# Replace 'ImageToTensor' with 'DefaultFormatBundle'
cfg.data.val.pipeline = replace_ImageToTensor(
cfg.data.val.pipeline)
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
val_dataset,
samples_per_gpu=val_samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
assert isinstance(custom_hooks, list), \
f'custom_hooks expect list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow)
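A minimal usage sketch of the exported training APIs. The config path and the exact flow are illustrative assumptions, not part of this commit; the loaded config is also expected to define work_dir, optimizer, lr_config, runner, gpu_ids, etc.:

from mmcv import Config
from mmdet.models import build_detector
from mmfewshot.detection.apis.train import set_random_seed, train_detector
from mmfewshot.detection.datasets import build_dataset

cfg = Config.fromfile('configs/some_few_shot_config.py')  # hypothetical path
set_random_seed(0, deterministic=False)
model = build_detector(cfg.model)
datasets = [build_dataset(cfg.data.train)]
train_detector(model, datasets, cfg, distributed=False, validate=True)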

View File

@ -1,3 +1,18 @@
from .base_meta_learning_dataset import BaseMetaLearingDataset
from .builder import build_dataloader, build_dataset
from .dataloader_wrappers import NwayKshotDataloader
from .dataset_wrappers import MergeDataset, NwayKshotDataset, QueryAwareDataset
from .few_shot_custom import FewShotCustomDataset
from .utils import query_support_collate_fn
from .voc import FewShotVOCDataset
__all__ = ['BaseMetaLearingDataset']
__all__ = [
'build_dataloader',
'build_dataset',
'MergeDataset',
'QueryAwareDataset',
'NwayKshotDataset',
'NwayKshotDataloader',
'query_support_collate_fn',
'FewShotCustomDataset',
'FewShotVOCDataset',
]

View File

@ -1,8 +0,0 @@
# just an example
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
@DATASETS.register_module()
class BaseMetaLearingDataset(CustomDataset):
pass

View File

@ -0,0 +1,226 @@
import copy
from functools import partial
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import build_from_cfg
from mmdet.datasets.builder import DATASETS, worker_init_fn
from mmdet.datasets.dataset_wrappers import (ClassBalancedDataset,
ConcatDataset, RepeatDataset)
from mmdet.datasets.samplers import (DistributedGroupSampler,
DistributedSampler, GroupSampler)
from torch.utils.data import DataLoader
from .dataset_wrappers import MergeDataset, NwayKshotDataset, QueryAwareDataset
def _concat_dataset(cfg, default_args=None):
ann_files = cfg['ann_file']
img_prefixes = cfg.get('img_prefix', None)
seg_prefixes = cfg.get('seg_prefix', None)
proposal_files = cfg.get('proposal_file', None)
separate_eval = cfg.get('separate_eval', True)
merge_dataset = cfg.get('merge_dataset', False)
ann_shot_filter = cfg.get('ann_shot_filter', None)
if ann_shot_filter is not None:
assert merge_dataset, 'when using ann_shot_filter to load annotation ' \
'files in FewShotDataset, merge_dataset should be set to True.'
datasets = []
num_dset = len(ann_files)
for i in range(num_dset):
data_cfg = copy.deepcopy(cfg)
# pop 'separate_eval' since it is not a valid key for common datasets.
if 'separate_eval' in data_cfg:
data_cfg.pop('separate_eval')
if 'merge_dataset' in data_cfg:
data_cfg.pop('merge_dataset')
data_cfg['ann_file'] = ann_files[i]
if isinstance(img_prefixes, (list, tuple)):
data_cfg['img_prefix'] = img_prefixes[i]
if isinstance(seg_prefixes, (list, tuple)):
data_cfg['seg_prefix'] = seg_prefixes[i]
if isinstance(proposal_files, (list, tuple)):
data_cfg['proposal_file'] = proposal_files[i]
if isinstance(ann_shot_filter, (list, tuple)):
data_cfg['ann_shot_filter'] = ann_shot_filter[i]
datasets.append(build_dataset(data_cfg, default_args))
if merge_dataset:
return MergeDataset(datasets)
else:
return ConcatDataset(datasets, separate_eval)
def build_dataset(cfg, default_args=None):
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'ConcatDataset':
dataset = ConcatDataset(
[build_dataset(c, default_args) for c in cfg['datasets']],
cfg.get('separate_eval', True))
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif cfg['type'] == 'ClassBalancedDataset':
dataset = ClassBalancedDataset(
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
elif cfg['type'] == 'QueryAwareDataset':
dataset = QueryAwareDataset(
build_dataset(cfg['dataset'], default_args), cfg['support_way'],
cfg['support_shot'])
elif cfg['type'] == 'NwayKshotDataset':
dataset = NwayKshotDataset(
build_dataset(cfg['dataset'], default_args), cfg['support_way'],
cfg['support_shot'])
elif isinstance(cfg.get('ann_file'), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
return dataset
def build_dataloader(dataset,
samples_per_gpu,
workers_per_gpu,
num_gpus=1,
dist=True,
shuffle=True,
seed=None,
**kwargs):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
In non-distributed training, there is only one dataloader for all GPUs.
Args:
dataset (Dataset): A PyTorch dataset.
samples_per_gpu (int): Number of training samples on each GPU, i.e.,
batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed training.
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
seed (int): Random seed.
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
DataLoader: A PyTorch dataloader.
"""
rank, world_size = get_dist_info()
(sampler, batch_size, num_workers) \
= build_sampler(dist=dist,
shuffle=shuffle,
dataset=dataset,
num_gpus=num_gpus,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=workers_per_gpu,
seed=seed, )
init_fn = partial(
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
if isinstance(dataset, QueryAwareDataset):
from .utils import query_support_collate_fn
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(
query_support_collate_fn, samples_per_gpu=samples_per_gpu),
pin_memory=False,
worker_init_fn=init_fn,
**kwargs)
elif isinstance(dataset, NwayKshotDataset):
from .dataloader_wrappers import NwayKshotDataloader
from .utils import query_support_collate_fn
# init query dataloader
query_data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(
query_support_collate_fn, samples_per_gpu=samples_per_gpu),
pin_memory=False,
worker_init_fn=init_fn,
**kwargs)
# create the support dataset from the query dataset and
# pre-sample batch indices with the same length as the query dataloader
support_dataset = copy.deepcopy(dataset)
support_dataset.convert_query_to_support(len(query_data_loader))
(support_sampler, _, _) \
= build_sampler(dist=dist,
shuffle=shuffle,
dataset=support_dataset,
num_gpus=num_gpus,
samples_per_gpu=1,
workers_per_gpu=workers_per_gpu,
seed=seed,
)
data_loader = NwayKshotDataloader(
query_data_loader=query_data_loader,
support_dataset=support_dataset,
support_sampler=support_sampler,
num_workers=num_workers,
support_collate_fn=partial(
query_support_collate_fn, samples_per_gpu=1),
pin_memory=False,
worker_init_fn=init_fn,
**kwargs)
else:
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=False,
worker_init_fn=init_fn,
**kwargs)
return data_loader
def build_sampler(dist, shuffle, dataset, num_gpus, samples_per_gpu,
workers_per_gpu, seed):
"""Build pytorch sampler for dataLoader.
Args:
dist (bool): Distributed training/test or not.
shuffle (bool): Whether to shuffle the data at every epoch.
dataset (Dataset): A PyTorch dataset.
num_gpus (int): Number of GPUs. Only used in non-distributed training.
samples_per_gpu (int): Number of training samples on each GPU, i.e.,
batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
seed (int): Random seed.
Returns:
tuple: (sampler, batch_size, num_workers) for initializing the DataLoader.
"""
rank, world_size = get_dist_info()
if dist:
# DistributedGroupSampler will definitely shuffle the data to satisfy
# that images on each GPU are in the same group
if shuffle:
sampler = DistributedGroupSampler(
dataset, samples_per_gpu, world_size, rank, seed=seed)
else:
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=False, seed=seed)
batch_size = samples_per_gpu
num_workers = workers_per_gpu
else:
sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
batch_size = num_gpus * samples_per_gpu
num_workers = num_gpus * workers_per_gpu
return sampler, batch_size, num_workers
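For context, a minimal sketch of how these builders compose for the QueryAwareDataset base-training config above; train_pipeline and base_classes refer to the definitions in that config, the values are copied from it, and the described flow is an inference from this file:

from mmfewshot.detection.datasets.builder import build_dataloader, build_dataset

data_root = 'data/VOCdevkit/'
train_cfg = dict(
    type='QueryAwareDataset',
    support_way=2,
    support_shot=5,
    dataset=dict(
        type='FewShotVOCDataset',
        ann_file=[
            data_root + 'VOC2007/ImageSets/Main/trainval.txt',
            data_root + 'VOC2012/ImageSets/Main/trainval.txt'
        ],
        img_prefix=[data_root, data_root],
        pipeline=train_pipeline,  # dict(query=[...], support=[...]) as in the config above
        classes=base_classes,
        merge_dataset=True))
# ann_file is a list, so the inner dataset goes through _concat_dataset and,
# with merge_dataset=True, becomes a MergeDataset before being wrapped by
# QueryAwareDataset(support_way=2, support_shot=5)
dataset = build_dataset(train_cfg)
# a QueryAwareDataset gets a DataLoader using query_support_collate_fn; an
# NwayKshotDataset config would instead return an NwayKshotDataloader that
# pairs a query dataloader with a pre-sampled support dataloader
data_loader = build_dataloader(
    dataset, samples_per_gpu=2, workers_per_gpu=2, dist=False)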

View File

@ -0,0 +1,536 @@
import itertools
import logging
import os.path as osp
import tempfile
import warnings
from collections import OrderedDict
import mmcv
import numpy as np
from mmcv.utils import print_log
from mmdet.core import eval_recalls
from mmdet.datasets.api_wrappers import COCO, COCOeval
from mmdet.datasets.builder import DATASETS
from terminaltables import AsciiTable
from .few_shot_custom import FewShotCustomDataset
@DATASETS.register_module()
class FewShotCocoDataset(FewShotCustomDataset):
def __init__(self, **kwargs):
assert self.CLASSES or kwargs.get('classes', None),\
'CLASSES in `FewShotCocoDataset` can not be None.'
super(FewShotCocoDataset, self).__init__(**kwargs)
def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.
Args:
ann_file (str): Path of annotation file.
Returns:
list[dict]: Annotation info from COCO api.
"""
self.coco = COCO(ann_file)
self.cat_ids = []
self.cat2label = {}
# to keep the label order equal to the order in CLASSES
for i, class_name in enumerate(self.CLASSES):
cat_id = self.coco.get_cat_ids(cat_names=[class_name])[0]
self.cat_ids.append(cat_id)
self.cat2label[cat_id] = i
self.img_ids = self.coco.get_img_ids()
data_infos = []
total_ann_ids = []
for i in self.img_ids:
info = self.coco.load_imgs([i])[0]
info['filename'] = info['file_name']
info['ann'] = self._get_ann_info(info)
data_infos.append(info)
ann_ids = self.coco.get_ann_ids(img_ids=[i])
total_ann_ids.extend(ann_ids)
assert len(set(total_ann_ids)) == len(
total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
return data_infos
def _get_ann_info(self, data_info):
"""Get COCO annotation by index.
Args:
data_info (dict): Data info.
Returns:
dict: Annotation info of specified index.
"""
img_id = data_info['id']
ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids)
return self._parse_ann_info(data_info, ann_info)
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
# obtain images that contain annotation
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
# obtain images that contain annotations of the required categories
ids_in_cat = set()
for i, class_id in enumerate(self.cat_ids):
ids_in_cat |= set(self.coco.cat_img_map[class_id])
# merge the image id sets of the two conditions and use the merged set
# to filter out images if self.filter_empty_gt=True
ids_in_cat &= ids_with_ann
valid_img_ids = []
for i, img_info in enumerate(self.data_infos):
img_id = self.img_ids[i]
if self.filter_empty_gt and img_id not in ids_in_cat:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
valid_img_ids.append(img_id)
self.img_ids = valid_img_ids
return valid_inds
def _parse_ann_info(self, img_info, ann_info):
"""Parse bbox and mask annotation.
Args:
img_info (dict): Image info.
ann_info (list[dict]): Annotation info of an image.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,\
labels, masks, seg_map. "masks" are raw annotations and not \
decoded into binary masks.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_masks_ann = []
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
if inter_w * inter_h == 0:
continue
if ann['area'] <= 0 or w < 1 or h < 1:
continue
if ann['category_id'] not in self.cat_ids:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if ann.get('iscrowd', False):
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])
gt_masks_ann.append(ann.get('segmentation', None))
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
seg_map = img_info['filename'].replace('jpg', 'png')
ann = dict(
bboxes=gt_bboxes,
labels=gt_labels,
bboxes_ignore=gt_bboxes_ignore,
masks=gt_masks_ann,
seg_map=seg_map)
return ann
def xyxy2xywh(self, bbox):
"""Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
evaluation.
Args:
bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
``xyxy`` order.
Returns:
list[float]: The converted bounding boxes, in ``xywh`` order.
"""
_bbox = bbox.tolist()
return [
_bbox[0],
_bbox[1],
_bbox[2] - _bbox[0],
_bbox[3] - _bbox[1],
]
def _proposal2json(self, results):
"""Convert proposal results to COCO json style."""
json_results = []
for idx in range(len(self)):
img_id = self.img_ids[idx]
bboxes = results[idx]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = 1
json_results.append(data)
return json_results
def _det2json(self, results):
"""Convert detection results to COCO json style."""
json_results = []
for idx in range(len(self)):
img_id = self.img_ids[idx]
result = results[idx]
for label in range(len(result)):
bboxes = result[label]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = self.cat_ids[label]
json_results.append(data)
return json_results
def _segm2json(self, results):
"""Convert instance segmentation results to COCO json style."""
bbox_json_results = []
segm_json_results = []
for idx in range(len(self)):
img_id = self.img_ids[idx]
det, seg = results[idx]
for label in range(len(det)):
# bbox results
bboxes = det[label]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(bboxes[i][4])
data['category_id'] = self.cat_ids[label]
bbox_json_results.append(data)
# segm results
# some detectors use different scores for bbox and mask
if isinstance(seg, tuple):
segms = seg[0][label]
mask_score = seg[1][label]
else:
segms = seg[label]
mask_score = [bbox[4] for bbox in bboxes]
for i in range(bboxes.shape[0]):
data = dict()
data['image_id'] = img_id
data['bbox'] = self.xyxy2xywh(bboxes[i])
data['score'] = float(mask_score[i])
data['category_id'] = self.cat_ids[label]
if isinstance(segms[i]['counts'], bytes):
segms[i]['counts'] = segms[i]['counts'].decode()
data['segmentation'] = segms[i]
segm_json_results.append(data)
return bbox_json_results, segm_json_results
def results2json(self, results, outfile_prefix):
"""Dump the detection results to a COCO style json file.
There are 3 types of results: proposals, bbox predictions, mask
predictions, and they have different data types. This method will
automatically recognize the type, and dump them to json files.
Args:
results (list[list | tuple | ndarray]): Testing results of the
dataset.
outfile_prefix (str): The filename prefix of the json files. If the
prefix is "somepath/xxx", the json files will be named
"somepath/xxx.bbox.json", "somepath/xxx.segm.json",
"somepath/xxx.proposal.json".
Returns:
dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \
values are corresponding filenames.
"""
result_files = dict()
if isinstance(results[0], list):
json_results = self._det2json(results)
result_files['bbox'] = f'{outfile_prefix}.bbox.json'
result_files['proposal'] = f'{outfile_prefix}.bbox.json'
mmcv.dump(json_results, result_files['bbox'])
elif isinstance(results[0], tuple):
json_results = self._segm2json(results)
result_files['bbox'] = f'{outfile_prefix}.bbox.json'
result_files['proposal'] = f'{outfile_prefix}.bbox.json'
result_files['segm'] = f'{outfile_prefix}.segm.json'
mmcv.dump(json_results[0], result_files['bbox'])
mmcv.dump(json_results[1], result_files['segm'])
elif isinstance(results[0], np.ndarray):
json_results = self._proposal2json(results)
result_files['proposal'] = f'{outfile_prefix}.proposal.json'
mmcv.dump(json_results, result_files['proposal'])
else:
raise TypeError('invalid type of results')
return result_files
def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
gt_bboxes = []
for i in range(len(self.img_ids)):
ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
ann_info = self.coco.load_anns(ann_ids)
if len(ann_info) == 0:
gt_bboxes.append(np.zeros((0, 4)))
continue
bboxes = []
for ann in ann_info:
if ann.get('ignore', False) or ann['iscrowd']:
continue
x1, y1, w, h = ann['bbox']
bboxes.append([x1, y1, x1 + w, y1 + h])
bboxes = np.array(bboxes, dtype=np.float32)
if bboxes.shape[0] == 0:
bboxes = np.zeros((0, 4))
gt_bboxes.append(bboxes)
recalls = eval_recalls(
gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
ar = recalls.mean(axis=1)
return ar
def format_results(self, results, jsonfile_prefix=None, **kwargs):
"""Format the results to json (standard format for COCO evaluation).
Args:
results (list[tuple | numpy.ndarray]): Testing results of the
dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing \
the json filepaths, tmp_dir is the temporal directory created \
for saving json files when jsonfile_prefix is not specified.
"""
assert isinstance(results, list), 'results must be a list'
assert len(results) == len(self), (
'The length of results is not equal to the dataset len: {} != {}'.
format(len(results), len(self)))
if jsonfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
jsonfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
result_files = self.results2json(results, jsonfile_prefix)
return result_files, tmp_dir
def evaluate(self,
results,
metric='bbox',
logger=None,
jsonfile_prefix=None,
classwise=False,
proposal_nums=(100, 300, 1000),
iou_thrs=None,
metric_items=None):
"""Evaluation in COCO protocol.
Args:
results (list[list | tuple]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. Options are
'bbox', 'segm', 'proposal', 'proposal_fast'.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
classwise (bool): Whether to evaluate the AP for each class.
proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000.
Default: (100, 300, 1000).
iou_thrs (Sequence[float], optional): IoU threshold used for
evaluating recalls/mAPs. If set to a list, the average of all
IoUs will also be computed. If not specified, [0.50, 0.55,
0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
Default: None.
metric_items (list[str] | str, optional): Metric items that will
be returned. If not specified, ``['AR@100', 'AR@300',
'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
``metric=='bbox' or metric=='segm'``.
Returns:
dict[str, float]: COCO style evaluation metric.
"""
metrics = metric if isinstance(metric, list) else [metric]
allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
for metric in metrics:
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')
if iou_thrs is None:
iou_thrs = np.linspace(
.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
if metric_items is not None:
if not isinstance(metric_items, list):
metric_items = [metric_items]
result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
eval_results = OrderedDict()
cocoGt = self.coco
for metric in metrics:
msg = f'Evaluating {metric}...'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)
if metric == 'proposal_fast':
ar = self.fast_eval_recall(
results, proposal_nums, iou_thrs, logger='silent')
log_msg = []
for i, num in enumerate(proposal_nums):
eval_results[f'AR@{num}'] = ar[i]
log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
log_msg = ''.join(log_msg)
print_log(log_msg, logger=logger)
continue
iou_type = 'bbox' if metric == 'proposal' else metric
if metric not in result_files:
raise KeyError(f'{metric} is not in results')
try:
predictions = mmcv.load(result_files[metric])
if iou_type == 'segm':
# Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
# When evaluating mask AP, if the results contain bbox,
# cocoapi will use the box area instead of the mask area
# for calculating the instance area. Though the overall AP
# is not affected, this leads to different
# small/medium/large mask AP results.
for x in predictions:
x.pop('bbox')
warnings.simplefilter('once')
warnings.warn(
'The key "bbox" is deleted for more accurate mask AP '
'of small/medium/large instances since v2.12.0. This '
'does not change the overall mAP calculation.',
UserWarning)
cocoDt = cocoGt.loadRes(predictions)
except IndexError:
print_log(
'The testing results of the whole dataset are empty.',
logger=logger,
level=logging.ERROR)
break
cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
cocoEval.params.catIds = self.cat_ids
cocoEval.params.imgIds = self.img_ids
cocoEval.params.maxDets = list(proposal_nums)
cocoEval.params.iouThrs = iou_thrs
# mapping of cocoEval.stats
coco_metric_names = {
'mAP': 0,
'mAP_50': 1,
'mAP_75': 2,
'mAP_s': 3,
'mAP_m': 4,
'mAP_l': 5,
'AR@100': 6,
'AR@300': 7,
'AR@1000': 8,
'AR_s@1000': 9,
'AR_m@1000': 10,
'AR_l@1000': 11
}
if metric_items is not None:
for metric_item in metric_items:
if metric_item not in coco_metric_names:
raise KeyError(
f'metric item {metric_item} is not supported')
if metric == 'proposal':
cocoEval.params.useCats = 0
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
if metric_items is None:
metric_items = [
'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
'AR_m@1000', 'AR_l@1000'
]
for item in metric_items:
val = float(
f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
eval_results[item] = val
else:
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
if classwise: # Compute per-category AP
# Compute per-category AP
# from https://github.com/facebookresearch/detectron2/
precisions = cocoEval.eval['precision']
# precision: (iou, recall, cls, area range, max dets)
assert len(self.cat_ids) == precisions.shape[2]
results_per_category = []
for idx, catId in enumerate(self.cat_ids):
# area range index 0: all area ranges
# max dets index -1: typically 100 per image
nm = self.coco.loadCats(catId)[0]
precision = precisions[:, :, idx, 0, -1]
precision = precision[precision > -1]
if precision.size:
ap = np.mean(precision)
else:
ap = float('nan')
results_per_category.append(
(f'{nm["name"]}', f'{float(ap):0.3f}'))
num_columns = min(6, len(results_per_category) * 2)
results_flatten = list(
itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(*[
results_flatten[i::num_columns]
for i in range(num_columns)
])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
print_log('\n' + table.table, logger=logger)
if metric_items is None:
metric_items = [
'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
]
for metric_item in metric_items:
key = f'{metric}_{metric_item}'
val = float(
f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
)
eval_results[key] = val
ap = cocoEval.stats[:6]
eval_results[f'{metric}_mAP_copypaste'] = (
f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
f'{ap[4]:.3f} {ap[5]:.3f}')
if tmp_dir is not None:
tmp_dir.cleanup()
return eval_results

View File

@ -0,0 +1,64 @@
from torch.utils.data import DataLoader
class NwayKshotDataloader(object):
"""A dataloader wrapper of NwayKshotDataset dataset. Create a iterator to
generate query and support batch simultaneously. Each batch return a batch
of query data (batch_size) and support data.
(support_way * support_shot).
Args:
datasets (list[:obj:`NwayKshotDataset`]): A list of datasets.
batch_size (int): How many query samples per batch to load.
sampler (Sampler): Sampler for query dataloader only.
num_workers (int): Num workers for both support and query dataloader.
collate_fn (callable): Collate function for query dataloader.
pin_memory (bool): Pin memory for both support and query dataloader.
worker_init_fn (callable): Worker init function for both
support and query dataloader.
kwargs: any keyword argument to be used to initialize DataLoader.
"""
def __init__(self, query_data_loader, support_dataset, support_sampler,
num_workers, support_collate_fn, pin_memory, worker_init_fn,
**kwargs):
self.dataset = query_data_loader.dataset
self.query_data_loader = query_data_loader
self.support_dataset = support_dataset
self.support_sampler = support_sampler
self.num_workers = num_workers
self.support_collate_fn = support_collate_fn
self.pin_memory = pin_memory
self.worker_init_fn = worker_init_fn
self.kwargs = kwargs
def __iter__(self):
# generate different support batch indexes for each epoch
self.support_dataset.shuffle_support()
# init the support dataloader with batch_size 1:
# each batch is pre-sampled in the dataset, and the collate
# function assembles it into support_way * support_shot samples
self.support_data_loader = DataLoader(
self.support_dataset,
batch_size=1,
sampler=self.support_sampler,
num_workers=self.num_workers,
collate_fn=self.support_collate_fn,
pin_memory=self.pin_memory,
worker_init_fn=self.worker_init_fn,
**self.kwargs)
# init iterator for query and support
self.query_iter = iter(self.query_data_loader)
self.support_iter = iter(self.support_data_loader)
return self
def __next__(self):
# call query and support iterator
query_data = next(self.query_iter)
support_data = next(self.support_iter)
return {'query_data': query_data, 'support_data': support_data}
def __len__(self):
return len(self.query_data_loader)
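
A minimal consumption sketch for this wrapper (a sketch only: `nway_kshot_dataset` is assumed to be a dataset built from an NwayKshotDataset config, and `build_dataloader` is assumed to return this wrapper for that dataset type, as exercised in the tests added in this change):

from mmfewshot.detection.datasets.builder import build_dataloader

nway_kshot_dataloader = build_dataloader(
    nway_kshot_dataset,
    samples_per_gpu=2,
    workers_per_gpu=0,
    num_gpus=1,
    dist=False,
    shuffle=True,
    seed=2021)
for data_batch in nway_kshot_dataloader:
    # a query batch of samples_per_gpu images ...
    query_data = data_batch['query_data']
    # ... paired with a support batch of support_way * support_shot instances
    support_data = data_batch['support_data']
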

View File

@ -0,0 +1,453 @@
import copy
import warnings
import numpy as np
from mmdet.datasets.builder import DATASETS
@DATASETS.register_module()
class MergeDataset(object):
"""A wrapper of merge dataset.
This dataset wrapper would be called when using multiple annotation
files for NwayKshotDataset, QueryAwareDataset, and FewShotCustomDataset.
It would merge the data info of input datasets, because different
annotations of same image will cross different datasets.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
"""
def __init__(self, datasets):
self.dataset = copy.deepcopy(datasets[0])
self.CLASSES = self.dataset.CLASSES
for dataset in datasets:
assert dataset.img_prefix == self.dataset.img_prefix, \
'when using MergeDataset all img_prefix should be the same'
self.img_prefix = self.dataset.img_prefix
# merge datainfos for all datasets
concat_data_infos = sum([dataset.data_infos for dataset in datasets],
[])
merge_data_dict = {}
for i, data_info in enumerate(concat_data_infos):
if merge_data_dict.get(data_info['id'], None) is None:
merge_data_dict[data_info['id']] = data_info
else:
merge_data_dict[data_info['id']]['ann'] = \
self.merge_ann(merge_data_dict[data_info['id']]['ann'],
data_info['ann'])
self.dataset.data_infos = [
merge_data_dict[key] for key in merge_data_dict.keys()
]
# Disable the groupsampler, because in few shot setting,
# one group may only have two or three images.
if hasattr(datasets[0], 'flag'):
self.flag = np.zeros(len(self.dataset), dtype=np.uint8)
def get_cat_ids(self, idx):
"""Get category ids of merge dataset by index.
Args:
idx (int): Index of data.
Returns:
list[int]: All categories in the image of specified index.
"""
return self.dataset.get_cat_ids(idx)
def prepare_train_img(self, idx, pipeline_key=None, gt_idx=None):
"""Get training data and annotations after pipeline.
Args:
idx (int): Index of data.
pipeline_key (str): Name of pipeline
gt_idx (list[int]): Index of used annotation.
Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""
return self.dataset.prepare_train_img(idx, pipeline_key, gt_idx)
def get_ann_info(self, idx):
"""Get annotation by index.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
return self.dataset.get_ann_info(idx)
def __getitem__(self, idx):
return self.dataset[idx]
def __len__(self):
"""Dataset length after merge."""
return len(self.dataset)
def __repr__(self):
return self.dataset.__repr__()
def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[list | tuple]): Testing results of the dataset.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str: float]: AP results of the merged dataset.
"""
eval_results = self.dataset.evaluate(results, logger=logger, **kwargs)
return eval_results
@staticmethod
def merge_ann(ann_a, ann_b):
"""Merge two annotations.
Args:
ann_a (dict): Dict of annotation.
ann_b (dict): Dict of annotation.
Returns:
dict: Merged annotation.
"""
assert sorted(ann_a.keys()) == sorted(ann_b.keys()), \
'cannot merge different types of annotations'
return {
'bboxes': np.concatenate((ann_a['bboxes'], ann_b['bboxes'])),
'labels': np.concatenate((ann_a['labels'], ann_b['labels'])),
'bboxes_ignore': ann_a['bboxes_ignore'],
'labels_ignore': ann_a['labels_ignore']
}
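# A tiny sketch of the merge rule above (illustrative arrays only; as in the
# code, the ignore fields are taken from the first annotation):
import numpy as np

ann_a = dict(
    bboxes=np.array([[10., 10., 100., 100.]]),
    labels=np.array([0]),
    bboxes_ignore=np.zeros((0, 4)),
    labels_ignore=np.zeros((0, )))
ann_b = dict(
    bboxes=np.array([[20., 20., 200., 200.]]),
    labels=np.array([1]),
    bboxes_ignore=np.zeros((0, 4)),
    labels_ignore=np.zeros((0, )))
merged = MergeDataset.merge_ann(ann_a, ann_b)
# merged['bboxes'] stacks both boxes; merged['labels'] is [0, 1]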
@DATASETS.register_module()
class QueryAwareDataset(object):
"""A wrapper of query aware dataset.
For each item in dataset, there will be one query image and
(num_support_way * num_support_shot) support images.
The support images are sampled according to the selected
query image and include positive class (random one class
in query image) and negative class (any classes not appear in
query image).
Args:
datasets (obj:`FewShotDataset`, `MergeDataset`):
The dataset to be wrapped.
num_support_way (int): The number of classes for support data,
the first one always be the positive class.
num_support_shot (int): The number of shot for each support class.
"""
def __init__(self, dataset, num_support_way, num_support_shot):
self.dataset = dataset
self.num_support_way = num_support_way
self.num_support_shot = num_support_shot
self.CLASSES = dataset.CLASSES
assert self.num_support_way <= len(self.CLASSES), \
'Please set the num_support_way smaller than the ' \
'number of classes.'
# build data index (idx, gt_idx) by class.
self.data_infos_by_class = {i: [] for i in range(len(self.CLASSES))}
# count the max number of annotations of each class in a single image,
# which decides whether repeated instances need to be sampled or not.
self.max_anns_per_image_by_class = [
0 for _ in range(len(self.CLASSES))
]
# count the number of images containing each class; when a novel class
# only has one image, the positive support may be sampled from that image itself.
self.num_image_by_class = [0 for _ in range(len(self.CLASSES))]
for idx in range(len(self.dataset)):
labels = self.dataset.get_ann_info(idx)['labels']
class_count = [0 for _ in range(len(self.CLASSES))]
for gt_idx, gt in enumerate(labels):
self.data_infos_by_class[gt].append((idx, gt_idx))
class_count[gt] += 1
for i in range(len(self.CLASSES)):
# number of images for each class
if class_count[i] > 0:
self.num_image_by_class[i] += 1
# max number of one class annotations in one image
if class_count[i] > self.max_anns_per_image_by_class[i]:
self.max_anns_per_image_by_class[i] = class_count[i]
for i in range(len(self.CLASSES)):
assert len(self.data_infos_by_class[i]) > 0, \
f'Class {self.CLASSES[i]} has zero annotation'
if len(self.data_infos_by_class[i]) <= self.num_support_shot - \
self.max_anns_per_image_by_class[i]:
warnings.warn(
f'During training, instances of class {self.CLASSES[i]} '
f'may be fewer than the number of support shots, which '
f'causes some instances to be sampled multiple times')
if self.num_image_by_class[i] == 1:
warnings.warn(f'Class {self.CLASSES[i]} only has one '
f'image; query and support will be sampled '
f'from instances of the same image')
# Disable the groupsampler, because in few shot setting,
# one group may only have two or three images.
if hasattr(dataset, 'flag'):
self.flag = np.zeros(len(self.dataset), dtype=np.uint8)
def __getitem__(self, idx):
# sample query data
try_time = 0
while True:
try_time += 1
cat_ids = self.dataset.get_cat_ids(idx)
# make sure the query image leaves enough absent classes to serve
# as negative support classes; otherwise re-sample another query image.
if len(self.CLASSES) - len(cat_ids) >= self.num_support_way - 1:
break
else:
idx = self._rand_another(idx)
assert try_time < 100, \
'Not enough negative support classes for query image,' \
' please try a smaller support way.'
query_class = np.random.choice(cat_ids)
query_gt_idx = [
i for i in range(len(cat_ids)) if cat_ids[i] == query_class
]
query_data = self.dataset.prepare_train_img(idx, 'query', query_gt_idx)
query_data['query_class'] = [query_class]
# sample negative support classes, which do not appear in the query image
support_class = [
i for i in range(len(self.CLASSES)) if i not in cat_ids
]
support_class = np.random.choice(
support_class,
min(self.num_support_way - 1, len(support_class)),
replace=False)
support_idxes = self.generate_support(idx, query_class, support_class)
support_data = [
self.dataset.prepare_train_img(idx, 'support', [gt_idx])
for (idx, gt_idx) in support_idxes
]
return {'query_data': query_data, 'support_data': support_data}
def __len__(self):
"""Length after repetition."""
return len(self.dataset)
def _rand_another(self, idx):
"""Get another random index from the same group as the given index."""
pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool)
def generate_support(self, idx, query_class, support_classes):
"""Generate support indexes of query images.
Args:
idx (int): Index of query data.
query_class (int): Query class.
support_classes (list[int]): Classes of support data.
Returns:
list[(int, int)]: A batch (num_support_way * num_support_shot)
of support data (idx, gt_idx).
"""
support_idxes = []
if self.num_image_by_class[query_class] == 1:
# only one image is available, so instances are sampled from the query image itself
pos_support_idxes = self.sample_support_shots(
idx, query_class, allow_same_image=True)
else:
# instances are sampled from images other than the query image
pos_support_idxes = self.sample_support_shots(idx, query_class)
support_idxes.extend(pos_support_idxes)
for support_class in support_classes:
neg_support_idxes = self.sample_support_shots(idx, support_class)
support_idxes.extend(neg_support_idxes)
return support_idxes
def sample_support_shots(self, idx, class_id, allow_same_image=False):
"""Generate positive support indexes by class id.
Args:
idx (int): Index of query data.
class_id (int): Support class.
allow_same_image: Allow instance sampled from same image
as query image. Default: False.
Returns:
list[(int, int)]: Support data (num_support_shot)
of specific class.
"""
support_idxes = []
num_total_shot = len(self.data_infos_by_class[class_id])
num_ignore_shot = self.count_class_id(idx, class_id)
# set num_sample_shots for each time of sampling
if num_total_shot - num_ignore_shot < self.num_support_shot:
# not enough support data, so repeated instances are allowed
num_sample_shots = num_total_shot
allow_repeat = True
else:
# enough support data, so repeated instances are not allowed
num_sample_shots = self.num_support_shot
allow_repeat = False
while len(support_idxes) < self.num_support_shot:
selected_gt_idxes = np.random.choice(
num_total_shot, num_sample_shots, replace=False)
selected_gts = [
self.data_infos_by_class[class_id][selected_gt_idx]
for selected_gt_idx in selected_gt_idxes
]
for selected_gt in selected_gts:
# filter out query annotations
if selected_gt[0] == idx:
if not allow_same_image:
continue
if allow_repeat:
support_idxes.append(selected_gt)
elif selected_gt not in support_idxes:
support_idxes.append(selected_gt)
if len(support_idxes) == self.num_support_shot:
break
# update the number of samples to draw in the next sampling round
num_sample_shots = min(self.num_support_shot - len(support_idxes),
num_sample_shots)
return support_idxes
def count_class_id(self, idx, class_id):
"""Count number of instance of specific."""
cat_ids = self.dataset.get_cat_ids(idx)
cat_ids_of_class = [
i for i in range(len(cat_ids)) if cat_ids[i] == class_id
]
return len(cat_ids_of_class)
@DATASETS.register_module()
class NwayKshotDataset(object):
"""A dataset wrapper of NwayKshotDataset.
Based on incoming dataset, query dataset will sample batch data as
regular dataset, while support dataset will pre sample batch data
indexes. Each batch index contain (num_support_way * num_support_shot)
samples. The default format of NwayKshotDataset is query dataset and
the query dataset will convert into support dataset by using convert
function.
Args:
datasets (obj:`FewShotDataset`, `MergeDataset`):
The dataset to be wrapped.
num_support_way (int):
The number of classes in support data batch.
num_support_shot (int):
The number of shots for each class in support data batch.
"""
def __init__(self, dataset, num_support_way, num_support_shot):
self.dataset = dataset
self.CLASSES = dataset.CLASSES
# The data_type determines the behavior of fetching data.
# The default data_type is 'query', which behaves like a regular
# dataset. To convert the dataset into a 'support' dataset, simply
# call the function convert_query_to_support().
self.data_type = 'query'
self.num_support_way = num_support_way
assert num_support_way <= len(self.CLASSES), \
'support way cannot be larger than the number of classes'
self.num_support_shot = num_support_shot
self.batch_index = []
self.data_infos_by_class = {i: [] for i in range(len(self.CLASSES))}
# Disable the groupsampler, because in few shot setting,
# one group may only have two or three images.
if hasattr(dataset, 'flag'):
self.flag = np.zeros(len(self.dataset), dtype=np.uint8)
def __getitem__(self, idx):
if self.data_type == 'query':
# load one sample through the query pipeline
return self.dataset.prepare_train_img(idx, 'query')
elif self.data_type == 'support':
# load one pre-sampled batch of data through the support pipeline
b_idx = self.batch_index[idx]
batch_data = [
self.dataset.prepare_train_img(idx, 'support', [gt_idx])
for (idx, gt_idx) in b_idx
]
return batch_data
else:
raise ValueError(f'unsupported data type {self.data_type}')
def __len__(self):
"""Length of dataset."""
if self.data_type == 'query':
return len(self.dataset)
elif self.data_type == 'support':
return len(self.batch_index)
else:
raise ValueError(f'unsupported data type {self.data_type}')
def shuffle_support(self):
"""Generate new batch indexes."""
if self.data_type == 'query':
raise ValueError('shuffle_support is only valid for the support data_type')
self.batch_index = self.generate_batch_index(len(self.batch_index))
def convert_query_to_support(self, support_dataset_len):
"""Convert query dataset to support dataset.
Args:
support_dataset_len (int): Length of pre sample batch indexes.
"""
# create lookup table for annotations in same class
for idx in range(len(self.dataset)):
labels = self.dataset.get_ann_info(idx)['labels']
for gt_idx, gt in enumerate(labels):
self.data_infos_by_class[gt].append((idx, gt_idx))
# make sure all class index lists have enough
# instances (length > num_support_shot)
for i in range(len(self.CLASSES)):
num_gts = len(self.data_infos_by_class[i])
if num_gts < self.num_support_shot:
self.data_infos_by_class[i] = self.data_infos_by_class[i] * \
(self.num_support_shot // num_gts + 1)
self.batch_index = self.generate_batch_index(support_dataset_len)
self.data_type = 'support'
if hasattr(self, 'flag'):
self.flag = np.zeros(support_dataset_len, dtype=np.uint8)
def generate_batch_index(self, dataset_len):
"""Generate batch index [length of datasets * [support way * support shots]].
Args:
dataset_len: Length of pre sample batch indexes.
Returns:
List[List[(data_idx, gt_idx)]]: Pre sample batch indexes.
"""
total_batch_index = []
for _ in range(dataset_len):
batch_index = []
selected_classes = np.random.choice(
len(self.CLASSES), self.num_support_way, replace=False)
for cls in selected_classes:
num_gts = len(self.data_infos_by_class[cls])
selected_gts_idx = np.random.choice(
num_gts, self.num_support_shot, replace=False)
selected_gts = [
self.data_infos_by_class[cls][gt_idx]
for gt_idx in selected_gts_idx
]
batch_index.extend(selected_gts)
total_batch_index.append(batch_index)
return total_batch_index
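
A short sketch of how the two wrappers above are meant to be indexed (a sketch only: `few_shot_dataset` stands for an already-built FewShotCustomDataset or MergeDataset whose class count covers the chosen support way; the real wiring is done by the dataset and dataloader builders):

import copy

# query-aware sampling: each item pairs one query with its support set
query_aware = QueryAwareDataset(
    few_shot_dataset, num_support_way=2, num_support_shot=1)
item = query_aware[0]
item['query_data']      # one query sample carrying its 'query_class'
item['support_data']    # num_support_way * num_support_shot support samples

# N-way K-shot sampling: one query view and one pre-sampled support view
query_dataset = NwayKshotDataset(
    few_shot_dataset, num_support_way=2, num_support_shot=1)
support_dataset = copy.deepcopy(query_dataset)
support_dataset.convert_query_to_support(len(query_dataset))
query_sample = query_dataset[0]       # a single query sample
support_batch = support_dataset[0]    # num_support_way * num_support_shot samples
support_dataset.shuffle_support()     # re-draw support batches for a new epoch
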

View File

@ -0,0 +1,264 @@
import copy
import os.path as osp
import warnings
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmdet.datasets.pipelines import Compose
@DATASETS.register_module()
class FewShotCustomDataset(CustomDataset):
"""Custom dataset for few shot detection.
It allows either a single pipeline (the normal, fully supervised setting)
or two pipelines (query-support fashion) for data processing.
When the annotation shot filter is used, it makes sure the accessible
annotations meet the few shot setting with an exact number of instances.
The annotation format is shown as follows. The `ann` field
is optional for testing.
.. code-block:: none
[
{
'id': '0000001'
'filename': 'a.jpg',
'width': 1280,
'height': 720,
'ann': {
'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
'labels': <np.ndarray> (n, ),
'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
'labels_ignore': <np.ndarray> (k, ) (optional field)
}
},
...
]
Args:
ann_file (str): Annotation file path.
pipeline (list[dict] | dict): Processing pipeline.
If it is a list[dict], all data will pass through this pipeline.
If it is a dict, query and support data will be processed with
two different pipelines and the dict should contain two keys:
- 'query': list[dict]
- 'support': list[dict]
classes (str | Sequence[str]): Classes used for model training,
providing a fixed label for each class.
data_root (str, optional): Data root for ``ann_file``,
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
test_mode (bool, optional): If set True, annotation will not be loaded.
filter_empty_gt (bool, optional): If set true, images without bounding
boxes of the dataset's classes will be filtered out. This option
only works when `test_mode=False`, i.e., we never filter images
during tests.
ann_shot_filter (dict, optional): If set to None, all annotations from
the ann file will be loaded. If not None, the annotation shot filter
specifies which classes to load and the maximum number of instances
of each class to keep. For example: {'dog': 10, 'person': 5}.
Default: None.
"""
CLASSES = None
def __init__(
self,
ann_file,
pipeline,
classes,
data_root=None,
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False,
filter_empty_gt=True,
ann_shot_filter=None,
):
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.proposal_file = proposal_file
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
self.CLASSES = self.get_classes(classes)
self.ann_shot_filter = ann_shot_filter
if self.ann_shot_filter is not None:
for class_name in list(self.ann_shot_filter.keys()):
assert class_name in self.CLASSES, \
f'class {class_name} from ' \
f'ann_shot_filter not in CLASSES, '
# join paths if data_root is specified
if self.data_root is not None:
if not osp.isabs(self.ann_file):
self.ann_file = osp.join(self.data_root, self.ann_file)
if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
self.img_prefix = osp.join(self.data_root, self.img_prefix)
if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
if not (self.proposal_file is None
or osp.isabs(self.proposal_file)):
self.proposal_file = osp.join(self.data_root,
self.proposal_file)
# load annotations (and proposals)
self.data_infos = self.load_annotations(self.ann_file)
# filter annotations according to ann_shot_filter
if self.ann_shot_filter is not None:
self.data_infos = self._filter_annotations(self.data_infos,
self.ann_shot_filter)
if self.proposal_file is not None:
self.proposals = self.load_proposals(self.proposal_file)
else:
self.proposals = None
# filter images too small and containing no annotations
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds]
# set group flag for the sampler
self._set_group_flag()
# processing pipeline: if two pipelines are given, the one to use
# is determined by the key name, 'query' or 'support'
if isinstance(pipeline, dict):
self.pipeline = {}
for key in pipeline.keys():
self.pipeline[key] = Compose(pipeline[key])
else:
self.pipeline = Compose(pipeline)
def get_ann_info(self, idx):
"""Get annotation by index.
When overriding this function, please make sure the same annotations
are used throughout training.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
return copy.deepcopy(self.data_infos[idx]['ann'])
def prepare_train_img(self, idx, pipeline_key=None, gt_idx=None):
"""Get training data and annotations after pipeline.
Args:
idx (int): Index of data.
pipeline_key (str): Name of pipeline
gt_idx (list[int]): Index of used annotation.
Returns:
dict: Training data and annotation after pipeline with new keys \
introduced by pipeline.
"""
img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx)
# annotation filter
if gt_idx is not None:
selected_ann_info = {
'bboxes': ann_info['bboxes'][gt_idx],
'labels': ann_info['labels'][gt_idx]
}
# keep img_info consistent with the selected annotations
new_img_info = copy.deepcopy(img_info)
new_img_info['ann'] = selected_ann_info
results = dict(img_info=new_img_info, ann_info=selected_ann_info)
else:
results = dict(img_info=copy.deepcopy(img_info), ann_info=ann_info)
if self.proposals is not None:
results['proposals'] = self.proposals[idx]
self.pre_pipeline(results)
if pipeline_key is None:
return self.pipeline(results)
else:
return self.pipeline[pipeline_key](results)
def _filter_annotations(self, data_infos, ann_shot_filter):
"""Filter out annotations not in class_masks and excess annotations of
specific class, while annotations of other classes in class_masks
remain unchanged.
Args:
data_infos (list[dict]): Annotation infos.
ann_shot_filter (dict): Specific which class and how many
instances of each class to load from annotation file.
For example: {'dog': 10, 'cat': 10, 'person': 5} Default: None.
Returns:
list[dict]: Annotation infos in which the number of instances of each
specified class is less than or equal to the predefined number.
"""
# build instance indexes of (img_id, gt_idx)
total_instance_dict = {key: [] for key in ann_shot_filter.keys()}
for data_info in data_infos:
img_id = data_info['id']
ann = data_info['ann']
for i in range(ann['labels'].shape[0]):
instance_class_name = self.CLASSES[ann['labels'][i]]
if instance_class_name in ann_shot_filter.keys():
total_instance_dict[instance_class_name].append(
(img_id, i))
total_instance_indexes = []
for class_name in ann_shot_filter.keys():
num_shot = ann_shot_filter[class_name]
instance_indexes = total_instance_dict[class_name]
# randomly sample from all instances to get the exact
# number of instances
if len(instance_indexes) > num_shot:
random_select = np.random.choice(
len(instance_indexes), num_shot, replace=False)
total_instance_indexes += \
[instance_indexes[i] for i in random_select]
# the number of shots is less than the predefined number,
# which may cause performance degradation
elif len(instance_indexes) < num_shot:
warning = f'number of {class_name} instance ' \
f'is {len(instance_indexes)} which is ' \
f'less than predefined shots {num_shot}.'
warnings.warn(warning)
total_instance_indexes += instance_indexes
else:
total_instance_indexes += instance_indexes
new_data_infos = []
for data_info in data_infos:
img_id = data_info['id']
selected_instance_index = \
sorted([instance[1] for instance in total_instance_indexes
if instance[0] == img_id])
ann = data_info['ann']
if len(selected_instance_index) == 0:
continue
selected_ann = dict(
bboxes=ann['bboxes'][selected_instance_index],
labels=ann['labels'][selected_instance_index],
)
if ann.get('bboxes_ignore') is not None:
selected_ann['bboxes_ignore'] = ann['bboxes_ignore']
if ann.get('labels_ignore') is not None:
selected_ann['labels_ignore'] = ann['labels_ignore']
new_data_infos.append(
dict(
id=img_id,
filename=data_info['filename'],
width=data_info['width'],
height=data_info['height'],
ann=selected_ann))
return new_data_infos
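
A small sketch of the ann_shot_filter behaviour described above (a sketch only, using the VOC subclass added later in this change and the test split files; annotations of classes missing from the filter are dropped, and each listed class keeps at most the given number of instances):

from mmfewshot.detection.datasets.voc import FewShotVOCDataset

dataset = FewShotVOCDataset(
    ann_file='tests/data/few_shot_voc_split/1.txt',
    img_prefix='tests/data/VOCdevkit/',
    ann_shot_filter=dict(aeroplane=2),
    pipeline=dict(
        query=[dict(type='LoadImageFromFile')],
        support=[dict(type='LoadImageFromFile')]),
    classes=('car', 'dog', 'chair', 'aeroplane'))
# only aeroplane boxes survive, capped at 2 instances in total
num_boxes = sum(len(info['ann']['bboxes']) for info in dataset.data_infos)
assert num_boxes == 2
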

View File

@ -0,0 +1,92 @@
# Copyright (c) Open-MMLab. All rights reserved.
from collections.abc import Mapping, Sequence
import torch
import torch.nn.functional as F
from mmcv.parallel.data_container import DataContainer
from torch.utils.data.dataloader import default_collate
def query_support_collate_fn(batch, samples_per_gpu=1):
"""Puts each data field into a tensor/DataContainer with outer dimension
batch size.
Extend default_collate to add support for
:type:`~mmcv.parallel.DataContainer`. There are 3 cases.
1. cpu_only = True, e.g., meta data
2. cpu_only = False, stack = True, e.g., images tensors
3. cpu_only = False, stack = False, e.g., gt bboxes
"""
if not isinstance(batch, Sequence):
raise TypeError(f'{batch.dtype} is not supported.')
# the support batch data comes as a nested list (List[List[DataContainer]]);
# flatten it and scale samples_per_gpu by the inner batch size
if isinstance(batch[0], Sequence):
samples_per_gpu = len(batch[0]) * samples_per_gpu
batch = sum(batch, [])
if isinstance(batch[0], DataContainer):
stacked = []
if batch[0].cpu_only:
for i in range(0, len(batch), samples_per_gpu):
stacked.append(
[sample.data for sample in batch[i:i + samples_per_gpu]])
return DataContainer(
stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
elif batch[0].stack:
for i in range(0, len(batch), samples_per_gpu):
assert isinstance(batch[i].data, torch.Tensor)
if batch[i].pad_dims is not None:
ndim = batch[i].dim()
assert ndim > batch[i].pad_dims
max_shape = [0 for _ in range(batch[i].pad_dims)]
for dim in range(1, batch[i].pad_dims + 1):
max_shape[dim - 1] = batch[i].size(-dim)
for sample in batch[i:i + samples_per_gpu]:
for dim in range(0, ndim - batch[i].pad_dims):
assert batch[i].size(dim) == sample.size(dim)
for dim in range(1, batch[i].pad_dims + 1):
max_shape[dim - 1] = max(max_shape[dim - 1],
sample.size(-dim))
padded_samples = []
for sample in batch[i:i + samples_per_gpu]:
pad = [0 for _ in range(batch[i].pad_dims * 2)]
for dim in range(1, batch[i].pad_dims + 1):
pad[2 * dim -
1] = max_shape[dim - 1] - sample.size(-dim)
padded_samples.append(
F.pad(
sample.data, pad, value=sample.padding_value))
stacked.append(default_collate(padded_samples))
elif batch[i].pad_dims is None:
stacked.append(
default_collate([
sample.data
for sample in batch[i:i + samples_per_gpu]
]))
else:
raise ValueError(
'pad_dims should be either None or integers (1-3)')
else:
for i in range(0, len(batch), samples_per_gpu):
stacked.append(
[sample.data for sample in batch[i:i + samples_per_gpu]])
return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
elif isinstance(batch[0], Sequence):
transposed = zip(*batch)
return [
query_support_collate_fn(samples, samples_per_gpu)
for samples in transposed
]
elif isinstance(batch[0], Mapping):
return {
key: query_support_collate_fn([d[key] for d in batch],
samples_per_gpu)
for key in batch[0]
}
else:
return default_collate(batch)
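
A small wiring sketch (a sketch only: `support_dataset` is assumed to be an NwayKshotDataset already converted to the 'support' data_type, so each item is itself a pre-sampled batch; the actual dataloader builder may wire this differently):

from functools import partial
from torch.utils.data import DataLoader

support_loader = DataLoader(
    support_dataset,
    batch_size=1,          # each dataset item is already a full support batch
    collate_fn=partial(query_support_collate_fn, samples_per_gpu=1),
    num_workers=0)
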

View File

@ -0,0 +1,183 @@
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from mmdet.datasets.builder import DATASETS
from .few_shot_custom import FewShotCustomDataset
@DATASETS.register_module()
class FewShotVOCDataset(FewShotCustomDataset):
"""VOC dataset for few shot detection.
FewShotVOCDataset supports filtering annotations while they are loaded.
The ann file may list either image ids or image paths. For example:
.. code-block:: none
ann_image_id.txt:
000001
000002
ann_image_path.txt:
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
Args:
min_size (int | float, optional): The minimum size of bounding
boxes in the images. If the size of a bounding box is less than
``min_size``, it will be added to the ignore field. Default: None.
"""
def __init__(self, min_size=None, **kwargs):
assert self.CLASSES or kwargs.get(
'classes', None), 'CLASSES in `FewShotVOCDataset` cannot be None.'
self.min_size = min_size
super(FewShotVOCDataset, self).__init__(**kwargs)
def load_annotations(self, ann_file):
"""Load annotation from XML style ann_file.
Args:
ann_file (str): Path of XML file.
Returns:
list[dict]: Annotation info from XML file.
"""
self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
data_infos = []
img_names = mmcv.list_from_file(ann_file)
for img_name in img_names:
# ann file in image path format
if 'VOC2007' in img_name:
dataset_year = 'VOC2007'
img_id = img_name.split('/')[-1].split('.')[0]
filename = img_name
# ann file in image path format
elif 'VOC2012' in img_name:
dataset_year = 'VOC2012'
img_id = img_name.split('/')[-1].split('.')[0]
filename = img_name
# ann file in image id format
elif 'VOC2007' in ann_file:
dataset_year = 'VOC2007'
img_id = img_name
filename = f'VOC2007/JPEGImages/{img_name}.jpg'
# ann file in image id format
elif 'VOC2012' in ann_file:
dataset_year = 'VOC2012'
img_id = img_name
filename = f'VOC2012/JPEGImages/{img_name}.jpg'
else:
raise ValueError('Cannot infer the dataset year from the ann file or image path')
xml_path = osp.join(self.img_prefix, dataset_year, 'Annotations',
f'{img_id}.xml')
tree = ET.parse(xml_path)
root = tree.getroot()
size = root.find('size')
if size is not None:
width = int(size.find('width').text)
height = int(size.find('height').text)
else:
img_path = osp.join(self.img_prefix, dataset_year,
'JPEGImages', '{}.jpg'.format(img_id))
img = mmcv.imread(img_path)
# mmcv.imread returns an ndarray, so take the shape instead of PIL's size
height, width = img.shape[:2]
# save annotation infos into data infos here, because not all
# annotations will be used for training and the used annotations
# should stay the same throughout training.
ann_info = self._get_ann_info(dataset_year, img_id)
data_infos.append(
dict(
id=img_id,
filename=filename,
width=width,
height=height,
ann=ann_info))
return data_infos
def _get_ann_info(self, dataset_year, img_id):
"""Get annotation from XML file by img_id.
Args:
dataset_year (str): Year of voc dataset. Options are
'VOC2007', 'VOC2012'
img_id (str): Id of image.
Returns:
dict: Annotation info of the specified image.
"""
bboxes = []
labels = []
bboxes_ignore = []
labels_ignore = []
xml_path = osp.join(self.img_prefix, dataset_year, 'Annotations',
f'{img_id}.xml')
tree = ET.parse(xml_path)
root = tree.getroot()
for obj in root.findall('object'):
name = obj.find('name').text
if name not in self.CLASSES:
continue
label = self.cat2label[name]
difficult = obj.find('difficult')
difficult = 0 if difficult is None else int(difficult.text)
bnd_box = obj.find('bndbox')
# TODO: check whether it is necessary to use int
# Coordinates may be float type
bbox = [
int(float(bnd_box.find('xmin').text)),
int(float(bnd_box.find('ymin').text)),
int(float(bnd_box.find('xmax').text)),
int(float(bnd_box.find('ymax').text))
]
ignore = False
if self.min_size:
assert not self.test_mode
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
if w < self.min_size or h < self.min_size:
ignore = True
if difficult or ignore:
bboxes_ignore.append(bbox)
labels_ignore.append(label)
else:
bboxes.append(bbox)
labels.append(label)
if not bboxes:
bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))
else:
bboxes = np.array(bboxes, ndmin=2) - 1
labels = np.array(labels)
if not bboxes_ignore:
bboxes_ignore = np.zeros((0, 4))
labels_ignore = np.zeros((0, ))
else:
bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
labels_ignore = np.array(labels_ignore)
ann_info = dict(
bboxes=bboxes.astype(np.float32),
labels=labels.astype(np.int64),
bboxes_ignore=bboxes_ignore.astype(np.float32),
labels_ignore=labels_ignore.astype(np.int64))
return ann_info
def _filter_imgs(self, min_size=32):
"""Filter images too small or without annotation."""
valid_inds = []
for i, img_info in enumerate(self.data_infos):
if min(img_info['width'], img_info['height']) < min_size:
continue
if self.filter_empty_gt:
cat_ids = img_info['ann']['labels'].astype(np.int64).tolist()
if len(cat_ids) > 0:
valid_inds.append(i)
else:
valid_inds.append(i)
return valid_inds
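
A short sketch of the ann-file handling and the query/support pipelines above (paths refer to the test data added in this change; gt_idx and the pipeline keys behave as documented in FewShotCustomDataset):

from mmfewshot.detection.datasets.voc import FewShotVOCDataset

pipeline = dict(
    query=[dict(type='LoadImageFromFile')],
    support=[dict(type='LoadImageFromFile')])
# ann file listing image ids; the dataset year is inferred from the file path
dataset = FewShotVOCDataset(
    ann_file='tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
    img_prefix='tests/data/VOCdevkit/',
    pipeline=pipeline,
    classes=('car', 'dog', 'chair'))
# run image 0 through the query pipeline, keeping only its first instance
query_sample = dataset.prepare_train_img(0, 'query', gt_idx=[0])
# the same image through the support pipeline, with all of its instances
support_sample = dataset.prepare_train_img(0, 'support')
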

View File

@ -3,7 +3,8 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmfewshot
known_third_party = mmcls,mmcv,mmdet,numpy,pytest,torch
known_third_party = mmcls,mmcv,mmdet,numpy,pytest,terminaltables,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -0,0 +1,44 @@
<annotation>
<folder>VOC2007</folder>
<filename>000001.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>341012865</flickrid>
</source>
<owner>
<flickrid>Fried Camels</flickrid>
<name>Jinky the Fruit Bat</name>
</owner>
<size>
<width>353</width>
<height>500</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>dog</name>
<pose>Left</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>48</xmin>
<ymin>240</ymin>
<xmax>195</xmax>
<ymax>371</ymax>
</bndbox>
</object>
<object>
<name>person</name>
<pose>Left</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>8</xmin>
<ymin>12</ymin>
<xmax>352</xmax>
<ymax>498</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,32 @@
<annotation>
<folder>VOC2007</folder>
<filename>000002.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>329145082</flickrid>
</source>
<owner>
<flickrid>hiromori2</flickrid>
<name>Hiroyuki Mori</name>
</owner>
<size>
<width>335</width>
<height>500</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>train</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>139</xmin>
<ymin>200</ymin>
<xmax>207</xmax>
<ymax>301</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,44 @@
<annotation>
<folder>VOC2007</folder>
<filename>000003.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>138563409</flickrid>
</source>
<owner>
<flickrid>RandomEvent101</flickrid>
<name>?</name>
</owner>
<size>
<width>500</width>
<height>375</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>sofa</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>123</xmin>
<ymin>155</ymin>
<xmax>215</xmax>
<ymax>195</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Left</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>239</xmin>
<ymin>156</ymin>
<xmax>307</xmax>
<ymax>205</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,104 @@
<annotation>
<folder>VOC2007</folder>
<filename>000004.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>322032655</flickrid>
</source>
<owner>
<flickrid>paytonc</flickrid>
<name>Payton Chung</name>
</owner>
<size>
<width>500</width>
<height>406</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>car</name>
<pose>Frontal</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>13</xmin>
<ymin>311</ymin>
<xmax>84</xmax>
<ymax>362</ymax>
</bndbox>
</object>
<object>
<name>car</name>
<pose>Unspecified</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>362</xmin>
<ymin>330</ymin>
<xmax>500</xmax>
<ymax>389</ymax>
</bndbox>
</object>
<object>
<name>car</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>235</xmin>
<ymin>328</ymin>
<xmax>334</xmax>
<ymax>375</ymax>
</bndbox>
</object>
<object>
<name>car</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>175</xmin>
<ymin>327</ymin>
<xmax>252</xmax>
<ymax>364</ymax>
</bndbox>
</object>
<object>
<name>car</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>139</xmin>
<ymin>320</ymin>
<xmax>189</xmax>
<ymax>359</ymax>
</bndbox>
</object>
<object>
<name>car</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>108</xmin>
<ymin>325</ymin>
<xmax>150</xmax>
<ymax>353</ymax>
</bndbox>
</object>
<object>
<name>car</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>84</xmin>
<ymin>323</ymin>
<xmax>121</xmax>
<ymax>350</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,80 @@
<annotation>
<folder>VOC2007</folder>
<filename>000005.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>325991873</flickrid>
</source>
<owner>
<flickrid>archintent louisville</flickrid>
<name>?</name>
</owner>
<size>
<width>500</width>
<height>375</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>chair</name>
<pose>Rear</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>263</xmin>
<ymin>211</ymin>
<xmax>324</xmax>
<ymax>339</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>165</xmin>
<ymin>264</ymin>
<xmax>253</xmax>
<ymax>372</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Unspecified</pose>
<truncated>1</truncated>
<difficult>1</difficult>
<bndbox>
<xmin>5</xmin>
<ymin>244</ymin>
<xmax>67</xmax>
<ymax>374</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>241</xmin>
<ymin>194</ymin>
<xmax>295</xmax>
<ymax>299</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Unspecified</pose>
<truncated>1</truncated>
<difficult>1</difficult>
<bndbox>
<xmin>277</xmin>
<ymin>186</ymin>
<xmax>312</xmax>
<ymax>220</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,5 @@
000001
000002
000003
000004
000005

View File

@ -0,0 +1,5 @@
000001
000002
000003
000004
000005

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,63 @@
<annotation>
<folder>VOC2012</folder>
<filename>2007_000027.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>486</width>
<height>500</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>person</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>174</xmin>
<ymin>101</ymin>
<xmax>349</xmax>
<ymax>351</ymax>
</bndbox>
<part>
<name>head</name>
<bndbox>
<xmin>169</xmin>
<ymin>104</ymin>
<xmax>209</xmax>
<ymax>146</ymax>
</bndbox>
</part>
<part>
<name>hand</name>
<bndbox>
<xmin>278</xmin>
<ymin>210</ymin>
<xmax>297</xmax>
<ymax>233</ymax>
</bndbox>
</part>
<part>
<name>foot</name>
<bndbox>
<xmin>273</xmin>
<ymin>333</ymin>
<xmax>297</xmax>
<ymax>354</ymax>
</bndbox>
</part>
<part>
<name>foot</name>
<bndbox>
<xmin>319</xmin>
<ymin>307</ymin>
<xmax>340</xmax>
<ymax>326</ymax>
</bndbox>
</part>
</object>
</annotation>

View File

@ -0,0 +1,63 @@
<annotation>
<folder>VOC2012</folder>
<filename>2007_000032.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>281</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
<object>
<name>aeroplane</name>
<pose>Frontal</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>104</xmin>
<ymin>78</ymin>
<xmax>375</xmax>
<ymax>183</ymax>
</bndbox>
</object>
<object>
<name>aeroplane</name>
<pose>Left</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>133</xmin>
<ymin>88</ymin>
<xmax>197</xmax>
<ymax>123</ymax>
</bndbox>
</object>
<object>
<name>person</name>
<pose>Rear</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>195</xmin>
<ymin>180</ymin>
<xmax>213</xmax>
<ymax>229</ymax>
</bndbox>
</object>
<object>
<name>person</name>
<pose>Rear</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>26</xmin>
<ymin>189</ymin>
<xmax>44</xmax>
<ymax>238</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,51 @@
<annotation>
<folder>VOC2012</folder>
<filename>2007_000033.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>366</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
<object>
<name>aeroplane</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>9</xmin>
<ymin>107</ymin>
<xmax>499</xmax>
<ymax>263</ymax>
</bndbox>
</object>
<object>
<name>aeroplane</name>
<pose>Left</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>421</xmin>
<ymin>200</ymin>
<xmax>482</xmax>
<ymax>226</ymax>
</bndbox>
</object>
<object>
<name>aeroplane</name>
<pose>Left</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>325</xmin>
<ymin>188</ymin>
<xmax>411</xmax>
<ymax>223</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,27 @@
<annotation>
<folder>VOC2012</folder>
<filename>2007_000039.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>375</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
<object>
<name>tvmonitor</name>
<pose>Frontal</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>156</xmin>
<ymin>89</ymin>
<xmax>344</xmax>
<ymax>279</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,39 @@
<annotation>
<folder>VOC2012</folder>
<filename>2007_000042.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>335</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
<object>
<name>train</name>
<pose>Unspecified</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>263</xmin>
<ymin>32</ymin>
<xmax>500</xmax>
<ymax>295</ymax>
</bndbox>
</object>
<object>
<name>train</name>
<pose>Unspecified</pose>
<truncated>1</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>1</xmin>
<ymin>36</ymin>
<xmax>235</xmax>
<ymax>299</ymax>
</bndbox>
</object>
</annotation>

View File

@ -0,0 +1,5 @@
2007_000027
2007_000032
2007_000033
2007_000039
2007_000042

View File

@ -0,0 +1,5 @@
2007_000027
2007_000032
2007_000033
2007_000039
2007_000042

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,77 @@
{
"images": [
{
"file_name": "fake1.jpg",
"height": 800,
"width": 800,
"id": 0
},
{
"file_name": "fake2.jpg",
"height": 800,
"width": 800,
"id": 1
},
{
"file_name": "fake3.jpg",
"height": 800,
"width": 800,
"id": 2
}
],
"annotations": [
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 1,
"image_id": 0
},
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 2,
"id": 2,
"image_id": 0
},
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 3,
"image_id": 1
}
],
"categories": [
{
"id": 1,
"name": "bus",
"supercategory": "none"
},
{
"id": 2,
"name": "car",
"supercategory": "none"
}
],
"licenses": [],
"info": null
}

Binary file not shown.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg

View File

@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg

View File

@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg

View File

@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg

View File

@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg

BIN
tests/data/gray.jpg 100644

Binary file not shown.

View File

@ -0,0 +1,247 @@
import torch
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import (build_dataloader,
build_dataset)
def test_dataloader():
set_random_seed(2021)
# test regular and few shot annotations
dataconfigs = [{
'type': 'NwayKshotDataset',
'support_way': 5,
'support_shot': 1,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'pipeline': {
'query': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
],
'support': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane', 'train'),
'merge_dataset':
True
}
}, {
'type': 'NwayKshotDataset',
'support_way': 5,
'support_shot': 1,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/few_shot_voc_split/1.txt',
'tests/data/few_shot_voc_split/2.txt',
'tests/data/few_shot_voc_split/3.txt',
'tests/data/few_shot_voc_split/4.txt',
'tests/data/few_shot_voc_split/5.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'ann_shot_filter': [{
'person': 2
}, {
'dog': 2
}, {
'chair': 3
}, {
'car': 3
}, {
'aeroplane': 3
}],
'pipeline': {
'query': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
],
'support': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
'merge_dataset':
True
}
}]
for dataconfig in dataconfigs:
nway_kshot_dataset = build_dataset(cfg=dataconfig)
nway_kshot_dataloader = build_dataloader(
nway_kshot_dataset,
samples_per_gpu=2,
workers_per_gpu=0,
num_gpus=1,
dist=False,
shuffle=True,
seed=2021)
for i, data_batch in enumerate(nway_kshot_dataloader):
assert len(data_batch['query_data']['img_metas'].data[0]) == 2
assert len(nway_kshot_dataloader.query_data_loader) == \
len(nway_kshot_dataloader.support_data_loader)
support_labels = data_batch['support_data']['gt_labels'].data[0]
assert len(set(torch.cat(
support_labels).tolist())) == dataconfig['support_way']
assert len(torch.cat(support_labels).tolist()) == \
dataconfig['support_way'] * dataconfig['support_shot']
dataconfigs = [{
'type': 'QueryAwareDataset',
'support_way': 3,
'support_shot': 5,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'pipeline': {
'query': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
],
'support': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
},
'classes': ('dog', 'chair', 'car'),
'merge_dataset':
True
}
}, {
'type': 'QueryAwareDataset',
'support_way': 3,
'support_shot': 2,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/few_shot_voc_split/1.txt',
'tests/data/few_shot_voc_split/2.txt',
'tests/data/few_shot_voc_split/3.txt',
'tests/data/few_shot_voc_split/4.txt',
'tests/data/few_shot_voc_split/5.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'ann_shot_filter': [{
'person': 1
}, {
'dog': 1
}, {
'chair': 2
}, {
'car': 2
}, {
'aeroplane': 2
}],
'pipeline': {
'query': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
],
'support': [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
'merge_dataset':
True
}
}]
for dataconfig in dataconfigs:
query_aware_dataset = build_dataset(cfg=dataconfig)
query_aware_dataloader = build_dataloader(
query_aware_dataset,
samples_per_gpu=2,
workers_per_gpu=0,
num_gpus=1,
dist=False,
shuffle=True,
seed=2021)
for i, data_batch in enumerate(query_aware_dataloader):
assert len(data_batch['query_data']['img_metas'].data[0]) == 2
assert len(data_batch['query_data']['query_class'].tolist()) == 2
support_labels = data_batch['support_data']['gt_labels'].data[0]
half_batch = len(support_labels) // 2
assert len(set(torch.cat(support_labels[:half_batch]).tolist())) \
== dataconfig['support_way']
assert len(set(torch.cat(support_labels[half_batch:]).tolist())) \
== dataconfig['support_way']

View File

@ -0,0 +1,49 @@
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.coco import FewShotCocoDataset
def test_few_shot_coco_dataset():
set_random_seed(2021)
# test regular annotation loading
dataconfig = {
'ann_file': 'tests/data/coco_sample.json',
'img_prefix': '',
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('bus', 'car')
}
few_shot_custom_dataset = FewShotCocoDataset(**dataconfig)
# images without labels of the given classes are filtered out
assert len(few_shot_custom_dataset.data_infos) == 2
assert few_shot_custom_dataset.CLASSES == ('bus', 'car')
# test loading annotation with specific class
dataconfig = {
'ann_file': 'tests/data/few_shot_coco_split/bus.json',
'img_prefix': '',
'ann_shot_filter': {
'bus': 5
},
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('bus', 'dog', 'car'),
}
few_shot_custom_dataset = FewShotCocoDataset(**dataconfig)
count = 0
for datainfo in few_shot_custom_dataset.data_infos:
count += len(datainfo['ann']['labels'])
for i in range(len(datainfo['ann']['labels'])):
assert datainfo['ann']['labels'][i] == 0
assert count == 5

View File

@ -0,0 +1,112 @@
import copy
from unittest.mock import MagicMock, patch
import numpy as np
from mmfewshot.detection.datasets import FewShotCustomDataset
data_infos = [
{
'id': '1',
'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000001.jpg',
'width': 800,
'height': 720,
'ann': {
'bboxes': np.array([[10, 10, 100, 100], [20, 20, 200, 200]]),
'labels': np.array([0, 1])
}
},
{
'id': '2',
'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000002.jpg',
'width': 800,
'height': 720,
'ann': {
'bboxes': np.array([[11, 11, 100, 100], [20, 20, 200, 200]]),
'labels': np.array([1, 1])
}
},
{
'id': '3',
'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000003.jpg',
'width': 800,
'height': 720,
'ann': {
'bboxes':
np.array([[11, 11, 100, 100], [20, 20, 200, 200],
[20, 20, 200, 200], [20, 20, 200, 200]]),
'labels':
np.array([2, 3, 3, 4])
}
},
{
'id': '4',
'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000004.jpg',
'width': 800,
'height': 720,
'ann': {
'bboxes':
np.array([[11, 11, 100, 100], [20, 20, 200, 200],
[20, 20, 200, 200], [20, 20, 200, 200]]),
'labels':
np.array([2, 2, 4, 4])
}
},
]
@patch('mmfewshot.detection.datasets.FewShotCustomDataset.load_annotations',
MagicMock(return_value=data_infos))
def test_few_shot_custom_dataset():
dataconfig = {
'ann_file': '',
'img_prefix': '',
'ann_shot_filter': {
'cat': 10,
'dog': 10,
'person': 2,
'car': 2,
},
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('cat', 'dog', 'person', 'car', 'bird')
}
few_shot_custom_dataset = FewShotCustomDataset(**dataconfig)
original_data_infos = copy.deepcopy(few_shot_custom_dataset.data_infos)
# test prepare_train_img()
data = few_shot_custom_dataset.prepare_train_img(0, 'query')
assert (data['img_info']['ann']['bboxes'] == np.array([[10, 10, 100, 100],
[20, 20, 200,
200]])).all()
assert (data['img_info']['ann']['labels'] == np.array([0, 1])).all()
data = few_shot_custom_dataset.prepare_train_img(1, 'support')
assert (data['img_info']['ann']['bboxes'] == np.array([[11, 11, 100, 100],
[20, 20, 200,
200]])).all()
assert (data['img_info']['ann']['labels'] == np.array([1, 1])).all()
data = few_shot_custom_dataset.prepare_train_img(0, 'query', [0])
assert (data['img_info']['ann']['bboxes'] == np.array([[10, 10, 100,
100]])).all()
assert (data['img_info']['ann']['labels'] == np.array([0])).all()
data = few_shot_custom_dataset.prepare_train_img(0, 'support', [1])
assert (data['img_info']['ann']['bboxes'] == np.array([[20, 20, 200,
200]])).all()
assert (data['img_info']['ann']['labels'] == np.array([1])).all()
# test whether data_infos have been accidentally changed or not
for i in range(len(few_shot_custom_dataset)):
assert (original_data_infos[i]['ann']['bboxes'] ==
few_shot_custom_dataset.data_infos[i]['ann']['bboxes']).all()
assert (original_data_infos[i]['ann']['labels'] ==
few_shot_custom_dataset.data_infos[i]['ann']['labels']).all()

View File

@ -0,0 +1,70 @@
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.voc import FewShotVOCDataset
def test_few_shot_voc_dataset():
set_random_seed(2021)
# test regular annotation loading
dataconfig = {
'ann_file': 'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
'img_prefix': 'tests/data/VOCdevkit/',
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('car', 'dog', 'chair')
}
few_shot_custom_dataset = FewShotVOCDataset(**dataconfig)
# images without labels of the given classes are filtered out
assert len(few_shot_custom_dataset.data_infos) == 4
assert few_shot_custom_dataset.CLASSES == ('car', 'dog', 'chair')
    # test loading annotations of a specific class
dataconfig = {
'ann_file': 'tests/data/few_shot_voc_split/1.txt',
'img_prefix': 'tests/data/VOCdevkit/',
'ann_shot_filter': {
'aeroplane': 10
},
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('car', 'dog', 'chair', 'aeroplane'),
}
    few_shot_voc_dataset = FewShotVOCDataset(**dataconfig)
    count = 0
    for datainfo in few_shot_voc_dataset.data_infos:
count += len(datainfo['ann']['bboxes'])
assert count == 5
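    # the split file is assumed to contain only 5 aeroplane instances, so a
    # shot filter of 10 keeps all of them; the next config caps it at 2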
    # test loading annotations of a specific class with a specific number of shots
dataconfig = {
'ann_file': 'tests/data/few_shot_voc_split/1.txt',
'img_prefix': 'tests/data/VOCdevkit/',
'ann_shot_filter': {
'aeroplane': 2
},
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('car', 'dog', 'chair', 'aeroplane'),
}
    few_shot_voc_dataset = FewShotVOCDataset(**dataconfig)
    count = 0
    for datainfo in few_shot_voc_dataset.data_infos:
count += len(datainfo['ann']['bboxes'])
assert count == 2


@ -0,0 +1,137 @@
import numpy as np
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import build_dataset
def test_merge_dataset():
set_random_seed(2023)
    # test merge dataset loading regular annotations
dataconfig = {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
'merge_dataset':
True
}
merge_dataset = build_dataset(cfg=dataconfig)
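    # with merge_dataset=True the ann_files are expected to be wrapped into a
    # single dataset whose data_infos merge the annotations of images that
    # appear in several splits; the loop below counts instances per class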
count = [0 for _ in range(5)]
for data_info in merge_dataset.dataset.data_infos:
# test label merge
if data_info['id'] == '000001':
            assert (np.sort(data_info['ann']['labels']) == np.array(
                [0, 1])).all()
for label in data_info['ann']['labels']:
count[label] += 1
assert count == [4, 1, 4, 7, 5]
    # test merge dataset loading annotations by class
dataconfig = {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/few_shot_voc_split/1.txt',
'tests/data/few_shot_voc_split/2.txt',
'tests/data/few_shot_voc_split/3.txt',
'tests/data/few_shot_voc_split/4.txt',
'tests/data/few_shot_voc_split/5.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'ann_shot_filter': [{
'person': 2
}, {
'dog': 2
}, {
'chair': 3
}, {
'car': 3
}, {
'aeroplane': 3
}],
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
'merge_dataset':
True
}
merge_dataset = build_dataset(cfg=dataconfig)
count = [0 for _ in range(5)]
for data_info in merge_dataset.dataset.data_infos:
# test label merge
if data_info['id'] == '000001':
            assert (np.sort(data_info['ann']['labels']) == np.array(
                [0, 1])).all()
for label in data_info['ann']['labels']:
count[label] += 1
assert count == [2, 1, 3, 3, 3]
    # test merge dataset loading coco annotations by class with specific shots
dataconfig = {
'type':
'FewShotCocoDataset',
'ann_file': [
'tests/data/few_shot_coco_split/bus.json',
'tests/data/few_shot_coco_split/car.json',
'tests/data/few_shot_coco_split/cat.json',
'tests/data/few_shot_coco_split/dog.json',
'tests/data/few_shot_coco_split/person.json',
],
'img_prefix': ['', '', '', '', ''],
'ann_shot_filter': [{
'bus': 2
}, {
'car': 2
}, {
'cat': 3
}, {
'dog': 3
}, {
'person': 3
}],
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('bus', 'car', 'cat', 'dog', 'person'),
'merge_dataset':
True
}
merge_dataset = build_dataset(cfg=dataconfig)
count = [0 for _ in range(5)]
for data_info in merge_dataset.dataset.data_infos:
# test label merge
for label in data_info['ann']['labels']:
count[label] += 1
assert count == [2, 2, 3, 3, 3]


@ -0,0 +1,145 @@
import numpy as np
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import build_dataset
def test_nway_kshot_dataset():
set_random_seed(2021)
# test regular and few shot annotations
dataconfigs = [{
'type': 'NwayKshotDataset',
'support_way': 5,
'support_shot': 1,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
'merge_dataset':
True
}
}, {
'type': 'NwayKshotDataset',
'support_way': 5,
'support_shot': 1,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/few_shot_voc_split/1.txt',
'tests/data/few_shot_voc_split/2.txt',
'tests/data/few_shot_voc_split/3.txt',
'tests/data/few_shot_voc_split/4.txt',
'tests/data/few_shot_voc_split/5.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'ann_shot_filter': [{
'person': 2
}, {
'dog': 2
}, {
'chair': 3
}, {
'car': 3
}, {
'aeroplane': 3
}],
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
'merge_dataset':
True
}
}]
for dataconfig in dataconfigs:
# test query dataset with 5 way 1 shot
nway_kshot_dataset = build_dataset(cfg=dataconfig)
assert nway_kshot_dataset.data_type == 'query'
assert np.sum(nway_kshot_dataset.flag) == 0
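        # the aspect-ratio group flag is assumed to be reset to all zeros by
        # the dataset wrapper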
assert isinstance(nway_kshot_dataset[0], dict)
# test support dataset with 5 way 1 shot
nway_kshot_dataset.convert_query_to_support(support_dataset_len=2)
batch_index = nway_kshot_dataset.batch_index
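        # after convert_query_to_support(), batch_index is assumed to hold one
        # entry per support batch, each a list of support_way * support_shot
        # (img_idx, ann_idx) pairs; the shape checks below rely on this layout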
assert nway_kshot_dataset.data_type == 'support'
assert nway_kshot_dataset.flag.shape[0] == 2
assert len(batch_index) == 2
assert len(batch_index[0]) == 5
assert len(batch_index[0][0]) == 2
# test batch of support dataset with 5 way 1 shot
support_batch = nway_kshot_dataset[0]
assert isinstance(support_batch, list)
count_classes = [0 for _ in range(5)]
for item in support_batch:
count_classes[item['ann_info']['labels'][0]] += 1
for count in count_classes:
assert count == 1
        # test query dataset with 4 way 2 shot
dataconfig['support_way'] = 4
dataconfig['support_shot'] = 2
nway_kshot_dataset = build_dataset(cfg=dataconfig)
assert nway_kshot_dataset.data_type == 'query'
assert np.sum(nway_kshot_dataset.flag) == 0
assert isinstance(nway_kshot_dataset[0], dict)
# test support dataset with 4 way 2 shot
nway_kshot_dataset.convert_query_to_support(support_dataset_len=3)
batch_index = nway_kshot_dataset.batch_index
assert nway_kshot_dataset.data_type == 'support'
assert nway_kshot_dataset.flag.shape[0] == 3
assert len(batch_index) == 3
assert len(batch_index[0]) == 4 * 2
assert len(batch_index[0][0]) == 2
for i in range(len(nway_kshot_dataset.CLASSES)):
assert len(nway_kshot_dataset.data_infos_by_class[i]) >= 2
# test batch of support dataset with 4 way 2 shot
for idx in range(3):
support_batch = nway_kshot_dataset[idx]
assert isinstance(support_batch, list)
count_classes = [0 for _ in range(5)]
dog_ann = None
for item in support_batch:
label = item['ann_info']['labels'][0]
count_classes[label] += 1
                # test whether the single dog annotation is repeated
                # (only one dog instance is available)
if label == 1:
if dog_ann is None:
dog_ann = item['ann_info']['bboxes']
else:
assert (dog_ann == item['ann_info']['bboxes']).all()
            # test the number of classes sampled:
            # 4 classes have 2 shots each, 1 class has 0 shots
is_skip = False
for count in count_classes:
if count == 0:
assert not is_skip
is_skip = True
else:
assert count == 2


@ -0,0 +1,136 @@
import numpy as np
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import build_dataset
def test_query_aware_dataset():
set_random_seed(2023)
# test regular annotations
dataconfig = {
'type': 'QueryAwareDataset',
'support_way': 3,
'support_shot': 5,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('dog', 'chair', 'car'),
'merge_dataset':
True
}
}
    # test query aware dataset with 3 way 5 shot
query_aware_dataset = build_dataset(cfg=dataconfig)
assert np.sum(query_aware_dataset.flag) == 0
# print(query_aware_dataset.data_infos_by_class)
# self.data_infos_by_class = {
# 0: [(0, 0)],
# 1: [(1, 0), (3, 0), (3, 1), (3, 2)],
# 2: [(2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6)]
# }
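    # sample_support_shots is assumed to return support_shot (img_idx, ann_idx)
    # pairs of the given class, with the third flag allowing samples to come
    # from the query image itself; when that image holds the only instance of
    # the class, the same pair is simply repeated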
assert query_aware_dataset.sample_support_shots(0, 0, True) == \
[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]
support = query_aware_dataset.sample_support_shots(0, 1, False)
assert len(set(support)) == 4
support = query_aware_dataset.sample_support_shots(1, 1, False)
assert len(set(support)) == 3
support = query_aware_dataset.sample_support_shots(3, 1, False)
assert len(set(support)) == 1
support = query_aware_dataset.sample_support_shots(3, 2)
assert len(set(support)) == 5
support = query_aware_dataset.sample_support_shots(3, 0)
assert len(set(support)) == 1
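    # test annotations loaded with ann_shot_filter (few shot split)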
dataconfig = {
'type': 'QueryAwareDataset',
'support_way': 3,
'support_shot': 2,
'dataset': {
'type':
'FewShotVOCDataset',
'ann_file': [
'tests/data/few_shot_voc_split/1.txt',
'tests/data/few_shot_voc_split/2.txt',
'tests/data/few_shot_voc_split/3.txt',
'tests/data/few_shot_voc_split/4.txt',
'tests/data/few_shot_voc_split/5.txt'
],
'img_prefix': [
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
'tests/data/VOCdevkit/',
],
'ann_shot_filter': [{
'person': 1
}, {
'dog': 1
}, {
'chair': 2
}, {
'car': 2
}, {
'aeroplane': 2
}],
'pipeline': {
'query': [{
'type': 'LoadImageFromFile'
}],
'support': [{
'type': 'LoadImageFromFile'
}]
},
'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
'merge_dataset':
True
}
}
query_aware_dataset = build_dataset(cfg=dataconfig)
assert np.sum(query_aware_dataset.flag) == 0
# print(query_aware_dataset.data_infos_by_class)
# self.data_infos_by_class = {
# 0: [(0, 0)],
# 1: [(1, 0)],
# 2: [(2, 0), (2, 1)],
# 3: [(3, 0), (3, 1)],
# 4: [(4, 0), (5, 0)]}
assert query_aware_dataset.sample_support_shots(0, 0, True) == \
[(0, 0), (0, 0)]
support = query_aware_dataset.sample_support_shots(0, 1, False)
assert len(set(support)) == 1
support = query_aware_dataset.sample_support_shots(3, 0)
assert len(set(support)) == 1
assert len(support) == 2
support = query_aware_dataset.sample_support_shots(3, 2)
assert len(set(support)) == 2
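    # a sampled item is expected to contain one query image and
    # support_way * support_shot = 3 * 2 = 6 support samples, where the first
    # support_shot samples share the query image's class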
batch = query_aware_dataset[0]
assert len(batch['support_data']) == 6
assert batch['query_data']['ann_info']['labels'][0] == \
batch['support_data'][0]['ann_info']['labels'][0]
assert batch['query_data']['ann_info']['labels'][0] == \
batch['support_data'][1]['ann_info']['labels'][0]
assert batch['support_data'][2]['ann_info']['labels'][0] == \
batch['support_data'][3]['ann_info']['labels'][0]
assert batch['support_data'][4]['ann_info']['labels'][0] == \
batch['support_data'][5]['ann_info']['labels'][0]