implementation of datasets and dataloader (#2)
* implementation of datasets and dataloader
* add testing file
* fix MergeDataset flag bug and QueryAwareDataset one-shot setting bug
* set all flags to 0 in the three dataset wrappers
* fix NwayKshotDataloader sampler bug and address some review comments
* add pytest file and VOC test data for pytest
* finish test files for few shot custom dataset, merge dataset and nwaykshot dataset
* cover more corner cases in both datasets and add test files for query aware dataset and nwaykshot dataset
* finish test file of dataloader and fix all random seeds in test files
* remove config
* avoid ann info change
* fix voc comments
* Lyq dataset dataloader (#4): fix voc comments; fix comments and refactor FewShotCustomDataset and FewShotVOCDataset; add coco dataset and test file
* Lyq dataset dataloader (#6): fix comments
* Lyq dataset dataloader (#7): fix comments
@@ -1,49 +0,0 @@
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
@@ -0,0 +1,116 @@
ALL_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}

NOVEL_CLASSES = {
    1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}

BASE_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor')
}

# dataset settings
data_root = 'data/VOCdevkit/'

split = 1
all_classes = ALL_CLASSES[split]
base_classes = BASE_CLASSES[split]
novel_classes = NOVEL_CLASSES[split]
num_base_shot = 1
num_novel_shot = 1
# load few shot data:
# each ann file corresponds to one class
# all files should use the same image prefix
ann_file_root = 'data/few_shot_voc_split/'
ann_file_per_class = []  # file path
img_prefix_per_class = []  # image prefix
ann_shot_filter_per_class = []  # ann filter for each ann file

for class_name in base_classes:
    ann_file_per_class.append(
        ann_file_root +
        f'{num_base_shot}shot/box_{num_base_shot}shot_{class_name}_train.txt')
    img_prefix_per_class.append(data_root)
    ann_shot_filter_per_class.append({class_name: num_base_shot})

for class_name in novel_classes:
    ann_file_per_class.append(
        ann_file_root +
        f'{num_novel_shot}shot/box_{num_novel_shot}shot_{class_name}_train.txt'
    )
    img_prefix_per_class.append(data_root)
    ann_shot_filter_per_class.append({class_name: num_novel_shot})

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=[(1333, 480), (1333, 800)], keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type='FewShotVOCDataset',
            ann_file=ann_file_per_class,
            img_prefix=img_prefix_per_class,
            ann_masks=ann_shot_filter_per_class,
            pipeline=train_pipeline,
            classes=all_classes,
            merge_dataset=True)),
    val=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=novel_classes),
    test=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=novel_classes))
evaluation = dict(interval=1, metric='mAP')
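The `ann_shot_filter_per_class` entries built above each map one class name to its shot budget, one filter per annotation file. A minimal sketch of what such a filter could do, assuming annotations carry a class name; the function and field names here are hypothetical and only illustrate the idea, the real filtering lives inside `FewShotVOCDataset`:

def apply_ann_shot_filter(anns, ann_shot_filter):
    # anns: list of dicts like {'class': 'bird', 'bbox': [...]}.
    # ann_shot_filter: dict mapping class name -> max instances to keep.
    kept, counts = [], {}
    for ann in anns:
        cls = ann['class']
        quota = ann_shot_filter.get(cls)
        if quota is None:
            continue  # classes outside the filter are dropped
        if counts.get(cls, 0) < quota:
            kept.append(ann)
            counts[cls] = counts.get(cls, 0) + 1
    return kept


# e.g. a 1-shot filter for 'bird' keeps only the first bird instance
print(apply_ann_shot_filter(
    [{'class': 'bird'}, {'class': 'bird'}, {'class': 'dog'}], {'bird': 1}))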
@@ -0,0 +1,92 @@
ALL_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}

NOVEL_CLASSES = {
    1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}

BASE_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor')
}

# dataset settings
data_root = 'data/VOCdevkit/'

# few shot setting
split = 1
base_classes = BASE_CLASSES[split]

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=[(1333, 480), (1333, 800)], keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type='VOCDataset',
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                data_root + 'VOC2012/ImageSets/Main/trainval.txt'
            ],
            img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
            pipeline=train_pipeline,
            classes=base_classes)),
    val=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=base_classes),
    test=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=base_classes))
evaluation = dict(interval=1, metric='mAP')
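These config files are plain Python executed by mmcv, so they can be loaded and inspected directly. A hedged sketch of doing so; the file path is hypothetical:

from mmcv import Config

cfg = Config.fromfile(
    'configs/detection/voc_base_training_split1.py')  # hypothetical path
print(cfg.data.train.type)                   # 'RepeatDataset'
print(cfg.data.train.dataset.type)           # 'VOCDataset'
print(len(cfg.data.train.dataset.ann_file))  # 2 (VOC2007 + VOC2012)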
@@ -0,0 +1,128 @@
ALL_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}

NOVEL_CLASSES = {
    1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}

BASE_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor')
}

# dataset settings
data_root = 'data/VOCdevkit/'

split = 1
all_classes = ALL_CLASSES[split]
base_classes = BASE_CLASSES[split]
novel_classes = NOVEL_CLASSES[split]
num_base_shot = 1
num_novel_shot = 1
# load few shot data:
# each ann file corresponds to one class
# all files should use the same image prefix
ann_file_root = 'data/few_shot_voc_split/'
ann_file_per_class = []  # file path
img_prefix_per_class = []  # image prefix
ann_shot_filter_per_class = []  # ann filter for each ann file

for class_name in base_classes:
    ann_file_per_class.append(
        ann_file_root +
        f'{num_base_shot}shot/box_{num_base_shot}shot_{class_name}_train.txt')
    img_prefix_per_class.append(data_root)
    ann_shot_filter_per_class.append({class_name: num_base_shot})

for class_name in novel_classes:
    ann_file_per_class.append(
        ann_file_root +
        f'{num_novel_shot}shot/box_{num_novel_shot}shot_{class_name}_train.txt'
    )
    img_prefix_per_class.append(data_root)
    ann_shot_filter_per_class.append({class_name: num_novel_shot})

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
    query=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ],
    support=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ])
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1000, 600),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='NwayKshotDataset',
        support_way=20,
        support_shot=1,
        dataset=dict(
            type='FewShotVOCDataset',
            ann_file=ann_file_per_class,
            img_prefix=img_prefix_per_class,
            ann_masks=ann_shot_filter_per_class,
            pipeline=train_pipeline,
            classes=all_classes,
            merge_dataset=True)),
    val=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=novel_classes),
    test=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=novel_classes))
evaluation = dict(interval=10, metric='mAP')
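An N-way K-shot episode here pairs each query batch with a support batch of `support_way` classes times `support_shot` instances. A minimal sketch of that sampling, illustrative only; the real logic lives in `NwayKshotDataset`, and `data_infos_by_class` is an assumed map from class name to dataset indices holding that class:

import random


def sample_support_batch(data_infos_by_class, support_way, support_shot):
    # pick N classes, then K instance indices per class
    classes = random.sample(list(data_infos_by_class), support_way)
    batch = []
    for cls in classes:
        batch.extend(random.sample(data_infos_by_class[cls], support_shot))
    return batch


# Example: 2-way 1-shot from a toy index map.
toy = {'bird': [0, 3], 'bus': [1], 'cow': [2, 4]}
print(sample_support_batch(toy, support_way=2, support_shot=1))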
@@ -0,0 +1,106 @@
ALL_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}

NOVEL_CLASSES = {
    1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}

BASE_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor')
}

# dataset settings
data_root = 'data/VOCdevkit/'

# few shot setting
split = 1
base_classes = BASE_CLASSES[split]

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
    query=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ],
    support=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ])
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1000, 600),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='NwayKshotDataset',
        support_way=15,
        support_shot=1,
        dataset=dict(
            type='FewShotVOCDataset',
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                data_root + 'VOC2012/ImageSets/Main/trainval.txt'
            ],
            img_prefix=[data_root, data_root],
            pipeline=train_pipeline,
            classes=base_classes,
            merge_dataset=True)),
    val=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=base_classes),
    test=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=base_classes))
evaluation = dict(interval=1, metric='mAP')
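Note that `support_way=15` in this base-training config matches the size of the split-1 base set, so every base class appears in each support batch:

BASE_SPLIT1 = ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat',
               'chair', 'diningtable', 'dog', 'horse', 'person',
               'pottedplant', 'sheep', 'train', 'tvmonitor')
assert len(BASE_SPLIT1) == 15  # support_way covers all base classes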
@@ -0,0 +1,128 @@
ALL_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa')
}

NOVEL_CLASSES = {
    1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}

BASE_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor')
}

# dataset settings
data_root = 'data/VOCdevkit/'

split = 1
all_classes = ALL_CLASSES[split]
base_classes = BASE_CLASSES[split]
novel_classes = NOVEL_CLASSES[split]
num_base_shot = 1
num_novel_shot = 1
# load few shot data:
# each ann file corresponds to one class
# all files should use the same image prefix
ann_file_root = 'data/few_shot_voc_split/'
ann_file_per_class = []  # file path
img_prefix_per_class = []  # image prefix
ann_shot_filter_per_class = []  # ann filter for each ann file

for class_name in base_classes:
    ann_file_per_class.append(
        ann_file_root +
        f'{num_base_shot}shot/box_{num_base_shot}shot_{class_name}_train.txt')
    img_prefix_per_class.append(data_root)
    ann_shot_filter_per_class.append({class_name: num_base_shot})

for class_name in novel_classes:
    ann_file_per_class.append(
        ann_file_root +
        f'{num_novel_shot}shot/box_{num_novel_shot}shot_{class_name}_train.txt'
    )
    img_prefix_per_class.append(data_root)
    ann_shot_filter_per_class.append({class_name: num_novel_shot})

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
    query=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ],
    support=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ])
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1000, 600),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='QueryAwareDataset',
        support_way=2,
        support_shot=1,
        dataset=dict(
            type='FewShotVOCDataset',
            ann_file=ann_file_per_class,
            img_prefix=img_prefix_per_class,
            ann_masks=ann_shot_filter_per_class,
            pipeline=train_pipeline,
            classes=all_classes,
            merge_dataset=True)),
    val=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=novel_classes),
    test=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=novel_classes))
evaluation = dict(interval=1, metric='mAP')
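Unlike the N-way K-shot wrapper, a query-aware dataset conditions the support set on each query image. An illustrative sketch of that sampling, assuming an index map from class name to dataset indices; the real logic lives in `QueryAwareDataset` and these names are hypothetical:

import random


def sample_query_aware_support(query_classes, indices_by_class,
                               support_way, support_shot):
    # at least one support class is present in the query image,
    # the remaining ways are sampled negatives
    pos = random.choice(query_classes)
    negatives = [c for c in indices_by_class if c not in query_classes]
    ways = [pos] + random.sample(negatives, support_way - 1)
    return {c: random.sample(indices_by_class[c], support_shot) for c in ways}


# Example: 2-way 1-shot support for a query containing 'bird'.
toy = {'bird': [0, 3], 'bus': [1], 'cow': [2, 4]}
print(sample_query_aware_support(['bird'], toy, support_way=2, support_shot=1))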
@@ -0,0 +1,164 @@
VOC_FEW_SHOT_SPLIT_ALL_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor', 'bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor', 'aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor', 'boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
VOC_FEW_SHOT_SPLIT_NOVEL_CLASSES = {
    1: ('bird', 'bus', 'cow', 'motorbike', 'sofa'),
    2: ('aeroplane', 'bottle', 'cow', 'horse', 'sofa'),
    3: ('boat', 'cat', 'motorbike', 'sheep', 'sofa'),
}
VOC_FEW_SHOT_SPLIT_BASE_CLASSES = {
    1: ('aeroplane', 'bicycle', 'boat', 'bottle', 'car', 'cat', 'chair',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'sheep',
        'train', 'tvmonitor'),
    2: ('bicycle', 'bird', 'boat', 'bus', 'car', 'cat', 'chair', 'diningtable',
        'dog', 'motorbike', 'person', 'pottedplant', 'sheep', 'train',
        'tvmonitor'),
    3: ('aeroplane', 'bicycle', 'bird', 'bottle', 'bus', 'car', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'train',
        'tvmonitor'),
}

# dataset settings
data_root = 'data/VOCdevkit/'

# few shot setting
split = 1
base_classes = VOC_FEW_SHOT_SPLIT_BASE_CLASSES[split]

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = dict(
    query=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ],
    support=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ])
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1000, 600),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='QueryAwareDataset',
        support_way=2,
        support_shot=5,
        dataset=dict(
            type='FewShotVOCDataset',
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                data_root + 'VOC2012/ImageSets/Main/trainval.txt'
            ],
            img_prefix=[data_root, data_root],
            pipeline=train_pipeline,
            classes=base_classes,
            merge_dataset=True)),
    val=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=base_classes),
    test=dict(
        type='VOCDataset',
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline,
        classes=base_classes))
evaluation = dict(interval=1, metric='mAP')
@@ -1,55 +0,0 @@
# dataset settings
dataset_type = 'VOCDataset'
data_root = 'data/VOCdevkit/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1000, 600),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type=dataset_type,
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                data_root + 'VOC2012/ImageSets/Main/trainval.txt'
            ],
            img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='mAP')
@@ -3,7 +3,8 @@ import random
 import numpy as np
 import torch
 from mmcls.apis.train import train_model as train_classifier
-from mmdet.apis.train import train_detector
+from mmfewshot.detection.apis.train import train_detector


 def set_random_seed(seed, deterministic=False):
@@ -1,7 +1,11 @@
+# this file is only for unittests
 from mmcls.datasets.builder import build_dataloader as build_cls_dataloader
 from mmcls.datasets.builder import build_dataset as build_cls_dataset
-from mmdet.datasets.builder import build_dataloader as build_det_dataloader
-from mmdet.datasets.builder import build_dataset as build_det_dataset
+from mmfewshot.detection.datasets.builder import \
+    build_dataloader as build_det_dataloader
+from mmfewshot.detection.datasets.builder import \
+    build_dataset as build_det_dataset


 def build_dataloader(dataset=None, task_type='mmdet', round_up=True, **kwargs):
@@ -1,3 +1,4 @@
+# this file is only for unittests
 from mmcls.models.builder import build_classifier as build_cls_model
 from mmdet.models.builder import build_detector as build_det_model
@@ -0,0 +1,3 @@
from .train import get_root_logger, set_random_seed, train_detector

__all__ = ['get_root_logger', 'set_random_seed', 'train_detector']
@@ -0,0 +1,170 @@
import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import replace_ImageToTensor
from mmdet.utils import get_root_logger

from mmfewshot.detection.datasets import build_dataloader, build_dataset


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiment')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiment')
            cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if 'runner' not in cfg:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline)
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
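A hedged end-to-end sketch of driving `train_detector` from one of the configs above. The config path is hypothetical, and the config is assumed to define the usual mmdet fields (`model`, `optimizer`, `optimizer_config`, `lr_config`, `log_config`, `checkpoint_config`, `log_level`, `workflow`, `resume_from`, `load_from`):

from mmcv import Config
from mmdet.models import build_detector

from mmfewshot.detection.apis import set_random_seed, train_detector
from mmfewshot.detection.datasets import build_dataset

cfg = Config.fromfile(
    'configs/detection/nway_kshot_voc_split1.py')  # hypothetical path
cfg.work_dir = './work_dirs/demo'
cfg.gpu_ids = [0]
cfg.seed = 42
set_random_seed(cfg.seed, deterministic=True)

datasets = [build_dataset(cfg.data.train)]
model = build_detector(cfg.model)
train_detector(model, datasets, cfg, distributed=False, validate=True)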
@@ -1,3 +1,18 @@
-from .base_meta_learning_dataset import BaseMetaLearingDataset
+from .builder import build_dataloader, build_dataset
+from .dataloader_wrappers import NwayKshotDataloader
+from .dataset_wrappers import MergeDataset, NwayKshotDataset, QueryAwareDataset
+from .few_shot_custom import FewShotCustomDataset
+from .utils import query_support_collate_fn
+from .voc import FewShotVOCDataset

-__all__ = ['BaseMetaLearingDataset']
+__all__ = [
+    'build_dataloader',
+    'build_dataset',
+    'MergeDataset',
+    'QueryAwareDataset',
+    'NwayKshotDataset',
+    'NwayKshotDataloader',
+    'query_support_collate_fn',
+    'FewShotCustomDataset',
+    'FewShotVOCDataset',
+]
@@ -1,8 +0,0 @@
# just an example
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset


@DATASETS.register_module()
class BaseMetaLearingDataset(CustomDataset):
    pass
@@ -0,0 +1,226 @@
import copy
from functools import partial

from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import build_from_cfg
from mmdet.datasets.builder import DATASETS, worker_init_fn
from mmdet.datasets.dataset_wrappers import (ClassBalancedDataset,
                                             ConcatDataset, RepeatDataset)
from mmdet.datasets.samplers import (DistributedGroupSampler,
                                     DistributedSampler, GroupSampler)
from torch.utils.data import DataLoader

from .dataset_wrappers import MergeDataset, NwayKshotDataset, QueryAwareDataset


def _concat_dataset(cfg, default_args=None):
    ann_files = cfg['ann_file']
    img_prefixes = cfg.get('img_prefix', None)
    seg_prefixes = cfg.get('seg_prefix', None)
    proposal_files = cfg.get('proposal_file', None)
    separate_eval = cfg.get('separate_eval', True)
    merge_dataset = cfg.get('merge_dataset', False)
    ann_shot_filter = cfg.get('ann_shot_filter', None)

    if ann_shot_filter is not None:
        assert merge_dataset, 'when using an ann shot filter to load ann ' \
            'files in FewShotDataset, merge_dataset should be set to True.'

    datasets = []
    num_dset = len(ann_files)
    for i in range(num_dset):
        data_cfg = copy.deepcopy(cfg)
        # pop 'separate_eval' since it is not a valid key for common datasets.
        if 'separate_eval' in data_cfg:
            data_cfg.pop('separate_eval')
        if 'merge_dataset' in data_cfg:
            data_cfg.pop('merge_dataset')
        data_cfg['ann_file'] = ann_files[i]
        if isinstance(img_prefixes, (list, tuple)):
            data_cfg['img_prefix'] = img_prefixes[i]
        if isinstance(seg_prefixes, (list, tuple)):
            data_cfg['seg_prefix'] = seg_prefixes[i]
        if isinstance(proposal_files, (list, tuple)):
            data_cfg['proposal_file'] = proposal_files[i]
        if isinstance(ann_shot_filter, (list, tuple)):
            data_cfg['ann_shot_filter'] = ann_shot_filter[i]

        datasets.append(build_dataset(data_cfg, default_args))
    if merge_dataset:
        return MergeDataset(datasets)
    else:
        return ConcatDataset(datasets, separate_eval)


def build_dataset(cfg, default_args=None):
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']],
            cfg.get('separate_eval', True))
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif cfg['type'] == 'QueryAwareDataset':
        dataset = QueryAwareDataset(
            build_dataset(cfg['dataset'], default_args), cfg['support_way'],
            cfg['support_shot'])
    elif cfg['type'] == 'NwayKshotDataset':
        dataset = NwayKshotDataset(
            build_dataset(cfg['dataset'], default_args), cfg['support_way'],
            cfg['support_shot'])
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset


def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int): Random seed.
        kwargs: any keyword argument to be used to initialize DataLoader

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    sampler, batch_size, num_workers = build_sampler(
        dist=dist,
        shuffle=shuffle,
        dataset=dataset,
        num_gpus=num_gpus,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        seed=seed)
    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None
    if isinstance(dataset, QueryAwareDataset):
        from .utils import query_support_collate_fn
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=partial(
                query_support_collate_fn, samples_per_gpu=samples_per_gpu),
            pin_memory=False,
            worker_init_fn=init_fn,
            **kwargs)
    elif isinstance(dataset, NwayKshotDataset):
        from .dataloader_wrappers import NwayKshotDataloader
        from .utils import query_support_collate_fn

        # init query dataloader
        query_data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=partial(
                query_support_collate_fn, samples_per_gpu=samples_per_gpu),
            pin_memory=False,
            worker_init_fn=init_fn,
            **kwargs)
        # create the support dataset from the query dataset and pre-sample
        # batch indices with the same length as the query dataloader
        support_dataset = copy.deepcopy(dataset)
        support_dataset.convert_query_to_support(len(query_data_loader))

        support_sampler, _, _ = build_sampler(
            dist=dist,
            shuffle=shuffle,
            dataset=support_dataset,
            num_gpus=num_gpus,
            samples_per_gpu=1,
            workers_per_gpu=workers_per_gpu,
            seed=seed)

        data_loader = NwayKshotDataloader(
            query_data_loader=query_data_loader,
            support_dataset=support_dataset,
            support_sampler=support_sampler,
            num_workers=num_workers,
            support_collate_fn=partial(
                query_support_collate_fn, samples_per_gpu=1),
            pin_memory=False,
            worker_init_fn=init_fn,
            **kwargs)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
            pin_memory=False,
            worker_init_fn=init_fn,
            **kwargs)

    return data_loader


def build_sampler(dist, shuffle, dataset, num_gpus, samples_per_gpu,
                  workers_per_gpu, seed):
    """Build a PyTorch sampler for DataLoader.

    Args:
        dist (bool): Distributed training/test or not.
        shuffle (bool): Whether to shuffle the data at every epoch.
        dataset (Dataset): A PyTorch dataset.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        seed (int): Random seed.

    Returns:
        tuple: (sampler, batch_size, num_workers) for the DataLoader.
    """
    rank, world_size = get_dist_info()
    if dist:
        # DistributedGroupSampler will definitely shuffle the data to satisfy
        # that images on each GPU are in the same group
        if shuffle:
            sampler = DistributedGroupSampler(
                dataset, samples_per_gpu, world_size, rank, seed=seed)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False, seed=seed)
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    return sampler, batch_size, num_workers
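For reference, a hedged sketch of how the two builders compose for the query-aware setting, reusing the variables from the split-1 config above (`ann_file_per_class`, `img_prefix_per_class`, `ann_shot_filter_per_class`, `train_pipeline` and `all_classes` are assumed to be in scope):

train_cfg = dict(
    type='QueryAwareDataset',
    support_way=2,
    support_shot=1,
    dataset=dict(
        type='FewShotVOCDataset',
        ann_file=ann_file_per_class,
        img_prefix=img_prefix_per_class,
        ann_masks=ann_shot_filter_per_class,
        pipeline=train_pipeline,
        classes=all_classes,
        merge_dataset=True))
# build_dataset routes the per-class ann_file list through _concat_dataset,
# merges the pieces into a MergeDataset, and wraps the result in
# QueryAwareDataset; build_dataloader then picks query_support_collate_fn.
dataset = build_dataset(train_cfg)
data_loader = build_dataloader(
    dataset, samples_per_gpu=2, workers_per_gpu=2, num_gpus=1, dist=False)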
@ -0,0 +1,536 @@
import itertools
import logging
import os.path as osp
import tempfile
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from mmdet.core import eval_recalls
from mmdet.datasets.api_wrappers import COCO, COCOeval
from mmdet.datasets.builder import DATASETS
from terminaltables import AsciiTable

from .few_shot_custom import FewShotCustomDataset


@DATASETS.register_module()
class FewShotCocoDataset(FewShotCustomDataset):

    def __init__(self, **kwargs):
        assert self.CLASSES or kwargs.get('classes', None), \
            'CLASSES in `FewShotCocoDataset` can not be None.'
        super(FewShotCocoDataset, self).__init__(**kwargs)

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """
        self.coco = COCO(ann_file)
        self.cat_ids = []
        self.cat2label = {}
        # keep the label order equal to the order in CLASSES
        for i, class_name in enumerate(self.CLASSES):
            cat_id = self.coco.get_cat_ids(cat_names=[class_name])[0]
            self.cat_ids.append(cat_id)
            self.cat2label[cat_id] = i
        self.img_ids = self.coco.get_img_ids()

        data_infos = []
        total_ann_ids = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            info['ann'] = self._get_ann_info(info)
            data_infos.append(info)
            ann_ids = self.coco.get_ann_ids(img_ids=[i])
            total_ann_ids.extend(ann_ids)
        assert len(set(total_ann_ids)) == len(
            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        return data_infos
    def _get_ann_info(self, data_info):
        """Get COCO annotation of one image.

        Args:
            data_info (dict): Data info.

        Returns:
            dict: Annotation info of the specified image.
        """

        img_id = data_info['id']
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
        ann_info = self.coco.load_anns(ann_ids)
        return self._parse_ann_info(data_info, ann_info)

    def _filter_imgs(self, min_size=32):
        """Filter out images that are too small or have no ground truths."""
        valid_inds = []
        # obtain images that contain annotation
        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for i, class_id in enumerate(self.cat_ids):
            ids_in_cat |= set(self.coco.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if self.filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_img_ids = []
        for i, img_info in enumerate(self.data_infos):
            img_id = self.img_ids[i]
            if self.filter_empty_gt and img_id not in ids_in_cat:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
                valid_img_ids.append(img_id)
        self.img_ids = valid_img_ids
        return valid_inds

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotations.

        Args:
            img_info (dict): Image info.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes,
                bboxes_ignore, labels, masks, seg_map. "masks" are raw
                annotations and not decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann

    def xyxy2xywh(self, bbox):
        """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
        evaluation.

        Args:
            bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
                ``xyxy`` order.

        Returns:
            list[float]: The converted bounding boxes, in ``xywh`` order.
        """

        _bbox = bbox.tolist()
        return [
            _bbox[0],
            _bbox[1],
            _bbox[2] - _bbox[0],
            _bbox[3] - _bbox[1],
        ]
    def _proposal2json(self, results):
        """Convert proposal results to COCO json style."""
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            bboxes = results[idx]
            for i in range(bboxes.shape[0]):
                data = dict()
                data['image_id'] = img_id
                data['bbox'] = self.xyxy2xywh(bboxes[i])
                data['score'] = float(bboxes[i][4])
                data['category_id'] = 1
                json_results.append(data)
        return json_results

    def _det2json(self, results):
        """Convert detection results to COCO json style."""
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            result = results[idx]
            for label in range(len(result)):
                bboxes = result[label]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(bboxes[i][4])
                    data['category_id'] = self.cat_ids[label]
                    json_results.append(data)
        return json_results

    def _segm2json(self, results):
        """Convert instance segmentation results to COCO json style."""
        bbox_json_results = []
        segm_json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            det, seg = results[idx]
            for label in range(len(det)):
                # bbox results
                bboxes = det[label]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(bboxes[i][4])
                    data['category_id'] = self.cat_ids[label]
                    bbox_json_results.append(data)

                # segm results
                # some detectors use different scores for bbox and mask
                if isinstance(seg, tuple):
                    segms = seg[0][label]
                    mask_score = seg[1][label]
                else:
                    segms = seg[label]
                    mask_score = [bbox[4] for bbox in bboxes]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(mask_score[i])
                    data['category_id'] = self.cat_ids[label]
                    if isinstance(segms[i]['counts'], bytes):
                        segms[i]['counts'] = segms[i]['counts'].decode()
                    data['segmentation'] = segms[i]
                    segm_json_results.append(data)
        return bbox_json_results, segm_json_results

    def results2json(self, results, outfile_prefix):
        """Dump the detection results to a COCO style json file.

        There are 3 types of results: proposals, bbox predictions, mask
        predictions, and they have different data types. This method will
        automatically recognize the type, and dump them to json files.

        Args:
            results (list[list | tuple | ndarray]): Testing results of the
                dataset.
            outfile_prefix (str): The filename prefix of the json files. If
                the prefix is "somepath/xxx", the json files will be named
                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
                "somepath/xxx.proposal.json".

        Returns:
            dict[str: str]: Possible keys are "bbox", "segm", "proposal",
                and values are corresponding filenames.
        """
        result_files = dict()
        if isinstance(results[0], list):
            json_results = self._det2json(results)
            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
            mmcv.dump(json_results, result_files['bbox'])
        elif isinstance(results[0], tuple):
            json_results = self._segm2json(results)
            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
            result_files['segm'] = f'{outfile_prefix}.segm.json'
            mmcv.dump(json_results[0], result_files['bbox'])
            mmcv.dump(json_results[1], result_files['segm'])
        elif isinstance(results[0], np.ndarray):
            json_results = self._proposal2json(results)
            result_files['proposal'] = f'{outfile_prefix}.proposal.json'
            mmcv.dump(json_results, result_files['proposal'])
        else:
            raise TypeError('invalid type of results')
        return result_files
    def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
        gt_bboxes = []
        for i in range(len(self.img_ids)):
            ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
            ann_info = self.coco.load_anns(ann_ids)
            if len(ann_info) == 0:
                gt_bboxes.append(np.zeros((0, 4)))
                continue
            bboxes = []
            for ann in ann_info:
                if ann.get('ignore', False) or ann['iscrowd']:
                    continue
                x1, y1, w, h = ann['bbox']
                bboxes.append([x1, y1, x1 + w, y1 + h])
            bboxes = np.array(bboxes, dtype=np.float32)
            if bboxes.shape[0] == 0:
                bboxes = np.zeros((0, 4))
            gt_bboxes.append(bboxes)

        recalls = eval_recalls(
            gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
        ar = recalls.mean(axis=1)
        return ar

    def format_results(self, results, jsonfile_prefix=None, **kwargs):
        """Format the results to json (standard format for COCO evaluation).

        Args:
            results (list[tuple | numpy.ndarray]): Testing results of the
                dataset.
            jsonfile_prefix (str | None): The prefix of json files. It
                includes the file path and the prefix of filename, e.g.,
                "a/b/prefix". If not specified, a temp file will be created.
                Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a dict containing
                the json filepaths, tmp_dir is the temporary directory
                created for saving json files when jsonfile_prefix is not
                specified.
        """
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)
        return result_files, tmp_dir

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=None,
                 metric_items=None):
        """Evaluation in COCO protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of json files. It
                includes the file path and the prefix of filename, e.g.,
                "a/b/prefix". If not specified, a temp file will be created.
                Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for
                evaluating recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float], optional): IoU threshold used for
                evaluating recalls/mAPs. If set to a list, the average of all
                IoUs will also be computed. If not specified, [0.50, 0.55,
                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
                Default: None.
            metric_items (list[str] | str, optional): Metric items that will
                be returned. If not specified, ``['AR@100', 'AR@300',
                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000']`` will be
                used when ``metric=='proposal'``, ``['mAP', 'mAP_50',
                'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
                ``metric=='bbox' or metric=='segm'``.

        Returns:
            dict[str, float]: COCO style evaluation metric.
        """

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')
        if iou_thrs is None:
            iou_thrs = np.linspace(
                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
        if metric_items is not None:
            if not isinstance(metric_items, list):
                metric_items = [metric_items]

        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)

        eval_results = OrderedDict()
        cocoGt = self.coco
        for metric in metrics:
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(
                    results, proposal_nums, iou_thrs, logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            iou_type = 'bbox' if metric == 'proposal' else metric
            if metric not in result_files:
                raise KeyError(f'{metric} is not in results')
            try:
                predictions = mmcv.load(result_files[metric])
                if iou_type == 'segm':
                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
                    # When evaluating mask AP, if the results contain bbox,
                    # cocoapi will use the box area instead of the mask area
                    # for calculating the instance area. Though the overall AP
                    # is not affected, this leads to different
                    # small/medium/large mask AP results.
                    for x in predictions:
                        x.pop('bbox')
                    warnings.simplefilter('once')
                    warnings.warn(
                        'The key "bbox" is deleted for more accurate mask AP '
                        'of small/medium/large instances since v2.12.0. This '
                        'does not change the overall mAP calculation.',
                        UserWarning)
                cocoDt = cocoGt.loadRes(predictions)
            except IndexError:
                print_log(
                    'The testing results of the whole dataset are empty.',
                    logger=logger,
                    level=logging.ERROR)
                break

            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
            cocoEval.params.catIds = self.cat_ids
            cocoEval.params.imgIds = self.img_ids
            cocoEval.params.maxDets = list(proposal_nums)
            cocoEval.params.iouThrs = iou_thrs
            # mapping of cocoEval.stats
            coco_metric_names = {
                'mAP': 0,
                'mAP_50': 1,
                'mAP_75': 2,
                'mAP_s': 3,
                'mAP_m': 4,
                'mAP_l': 5,
                'AR@100': 6,
                'AR@300': 7,
                'AR@1000': 8,
                'AR_s@1000': 9,
                'AR_m@1000': 10,
                'AR_l@1000': 11
            }
            if metric_items is not None:
                for metric_item in metric_items:
                    if metric_item not in coco_metric_names:
                        raise KeyError(
                            f'metric item {metric_item} is not supported')

            if metric == 'proposal':
                cocoEval.params.useCats = 0
                cocoEval.evaluate()
                cocoEval.accumulate()
                cocoEval.summarize()
                if metric_items is None:
                    metric_items = [
                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
                        'AR_m@1000', 'AR_l@1000'
                    ]

                for item in metric_items:
                    val = float(
                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
                    eval_results[item] = val
            else:
                cocoEval.evaluate()
                cocoEval.accumulate()
                cocoEval.summarize()
                if classwise:  # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = cocoEval.eval['precision']
                    # precision: (iou, recall, cls, area range, max dets)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.loadCats(catId)[0]
                        precision = precisions[:, :, idx, 0, -1]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                if metric_items is None:
                    metric_items = [
                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
                    ]

                for metric_item in metric_items:
                    key = f'{metric}_{metric_item}'
                    val = float(
                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
                    )
                    eval_results[key] = val
                ap = cocoEval.stats[:6]
                eval_results[f'{metric}_mAP_copypaste'] = (
                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
                    f'{ap[4]:.3f} {ap[5]:.3f}')
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results
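A hedged sketch of the evaluation entry point above (`few_shot_coco_dataset` and `detector_results` are placeholder names; the per-image result layout follows the usual mmdet convention):

# detector_results: list with one entry per image; for 'bbox' evaluation each
# entry is a list of per-class (k, 5) arrays in (x1, y1, x2, y2, score) order.
eval_results = few_shot_coco_dataset.evaluate(
    detector_results, metric='bbox', classwise=True)
print(eval_results['bbox_mAP'], eval_results['bbox_mAP_50'])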
@ -0,0 +1,64 @@
from torch.utils.data import DataLoader


class NwayKshotDataloader(object):
    """A dataloader wrapper for NwayKshotDataset.

    It creates an iterator that yields a query batch and a support batch
    simultaneously: each step returns a batch of query data (batch_size
    samples) and a batch of support data (support_way * support_shot
    samples).

    Args:
        query_data_loader (DataLoader): DataLoader of the query dataset.
        support_dataset (:obj:`NwayKshotDataset`): Support dataset with
            pre-sampled batch indexes.
        support_sampler (Sampler): Sampler for the support dataloader only.
        num_workers (int): Num workers for both support and query dataloader.
        support_collate_fn (callable): Collate function for the support
            dataloader.
        pin_memory (bool): Pin memory for both support and query dataloader.
        worker_init_fn (callable): Worker init function for both
            support and query dataloader.
        kwargs: Any keyword argument to be used to initialize DataLoader.
    """

    def __init__(self, query_data_loader, support_dataset, support_sampler,
                 num_workers, support_collate_fn, pin_memory, worker_init_fn,
                 **kwargs):
        self.dataset = query_data_loader.dataset
        self.query_data_loader = query_data_loader
        self.support_dataset = support_dataset
        self.support_sampler = support_sampler
        self.num_workers = num_workers
        self.support_collate_fn = support_collate_fn
        self.pin_memory = pin_memory
        self.worker_init_fn = worker_init_fn
        self.kwargs = kwargs

    def __iter__(self):
        # generate different support batch indexes for each epoch
        self.support_dataset.shuffle_support()
        # init the support dataloader with batch_size 1: each batch is
        # pre-sampled in the dataset, and the collate function turns it
        # into a batch of support_way * support_shot samples
        self.support_data_loader = DataLoader(
            self.support_dataset,
            batch_size=1,
            sampler=self.support_sampler,
            num_workers=self.num_workers,
            collate_fn=self.support_collate_fn,
            pin_memory=self.pin_memory,
            worker_init_fn=self.worker_init_fn,
            **self.kwargs)

        # init iterators for query and support
        self.query_iter = iter(self.query_data_loader)
        self.support_iter = iter(self.support_data_loader)
        return self

    def __next__(self):
        # step the query and support iterators together
        query_data = next(self.query_iter)
        support_data = next(self.support_iter)
        return {'query_data': query_data, 'support_data': support_data}

    def __len__(self):
        return len(self.query_data_loader)
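A minimal sketch of driving the wrapper in a training loop (assuming the query dataloader, support dataset, sampler, and collate function are built elsewhere; all names are placeholders):

nway_kshot_loader = NwayKshotDataloader(
    query_data_loader=query_data_loader,
    support_dataset=support_dataset,
    support_sampler=support_sampler,
    num_workers=2,
    support_collate_fn=support_collate_fn,
    pin_memory=False,
    worker_init_fn=None)

for step, data_batch in enumerate(nway_kshot_loader):
    query_data = data_batch['query_data']      # batch_size query samples
    support_data = data_batch['support_data']  # support_way * support_shot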
@ -0,0 +1,453 @@
import copy
import warnings

import numpy as np
from mmdet.datasets.builder import DATASETS


@DATASETS.register_module()
class MergeDataset(object):
    """A wrapper that merges datasets built from multiple annotation files.

    This dataset wrapper is used when multiple annotation files are passed
    to NwayKshotDataset, QueryAwareDataset, or FewShotCustomDataset. It
    merges the data infos of the input datasets, because different
    annotations of the same image may be spread across different datasets.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
    """

    def __init__(self, datasets):
        self.dataset = copy.deepcopy(datasets[0])
        self.CLASSES = self.dataset.CLASSES
        for dataset in datasets:
            assert dataset.img_prefix == self.dataset.img_prefix, \
                'when using MergeDataset all img_prefix should be the same'

        self.img_prefix = self.dataset.img_prefix

        # merge the data infos of all datasets
        concat_data_infos = sum([dataset.data_infos for dataset in datasets],
                                [])
        merge_data_dict = {}
        for i, data_info in enumerate(concat_data_infos):
            if merge_data_dict.get(data_info['id'], None) is None:
                merge_data_dict[data_info['id']] = data_info
            else:
                merge_data_dict[data_info['id']]['ann'] = \
                    self.merge_ann(merge_data_dict[data_info['id']]['ann'],
                                   data_info['ann'])

        self.dataset.data_infos = [
            merge_data_dict[key] for key in merge_data_dict.keys()
        ]

        # Disable the group sampler, because in the few shot setting
        # one group may only have two or three images.
        if hasattr(datasets[0], 'flag'):
            self.flag = np.zeros(len(self.dataset), dtype=np.uint8)

    def get_cat_ids(self, idx):
        """Get category ids of the merged dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        return self.dataset.get_cat_ids(idx)

    def prepare_train_img(self, idx, pipeline_key=None, gt_idx=None):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.
            pipeline_key (str): Name of pipeline.
            gt_idx (list[int]): Indexes of used annotations.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        return self.dataset.prepare_train_img(idx, pipeline_key, gt_idx)

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        return self.dataset.get_ann_info(idx)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __len__(self):
        """Dataset length after merge."""
        return len(self.dataset)

    def __repr__(self):
        return self.dataset.__repr__()

    def evaluate(self, results, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]: AP results of the total dataset or each
                separate dataset if `self.separate_eval=True`.
        """
        eval_results = self.dataset.evaluate(results, logger=logger, **kwargs)
        return eval_results

    @staticmethod
    def merge_ann(ann_a, ann_b):
        """Merge two annotations.

        Args:
            ann_a (dict): Dict of annotation.
            ann_b (dict): Dict of annotation.

        Returns:
            dict: Merged annotation.
        """
        assert sorted(ann_a.keys()) == sorted(ann_b.keys()), \
            'can not merge different types of annotations'
        return {
            'bboxes': np.concatenate((ann_a['bboxes'], ann_b['bboxes'])),
            'labels': np.concatenate((ann_a['labels'], ann_b['labels'])),
            'bboxes_ignore': ann_a['bboxes_ignore'],
            'labels_ignore': ann_a['labels_ignore']
        }

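To make the merge rule concrete, here is a small worked example of `merge_ann` with illustrative values: bboxes and labels are concatenated, while the ignore fields are kept from the first annotation.

import numpy as np

ann_a = dict(
    bboxes=np.array([[0., 0., 10., 10.]], dtype=np.float32),
    labels=np.array([1], dtype=np.int64),
    bboxes_ignore=np.zeros((0, 4), dtype=np.float32),
    labels_ignore=np.array([], dtype=np.int64))
ann_b = dict(
    bboxes=np.array([[5., 5., 20., 20.]], dtype=np.float32),
    labels=np.array([3], dtype=np.int64),
    bboxes_ignore=np.zeros((0, 4), dtype=np.float32),
    labels_ignore=np.array([], dtype=np.int64))
merged = MergeDataset.merge_ann(ann_a, ann_b)
# merged['bboxes'].shape == (2, 4) and merged['labels'].tolist() == [1, 3]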
@DATASETS.register_module()
class QueryAwareDataset(object):
    """A wrapper of query aware dataset.

    For each item in the dataset, there will be one query image and
    (num_support_way * num_support_shot) support images. The support images
    are sampled according to the selected query image, and include one
    positive class (a random class in the query image) and negative classes
    (classes that do not appear in the query image).

    Args:
        dataset (:obj:`FewShotDataset` | :obj:`MergeDataset`):
            The dataset to be wrapped.
        num_support_way (int): The number of classes for support data;
            the first one is always the positive class.
        num_support_shot (int): The number of shots for each support class.
    """

    def __init__(self, dataset, num_support_way, num_support_shot):
        self.dataset = dataset
        self.num_support_way = num_support_way
        self.num_support_shot = num_support_shot
        self.CLASSES = dataset.CLASSES
        assert self.num_support_way <= len(self.CLASSES), \
            'Please set the num_support_way smaller than the ' \
            'number of classes.'
        # build data index (idx, gt_idx) by class.
        self.data_infos_by_class = {i: [] for i in range(len(self.CLASSES))}
        # count the max number of anns of one class in a single image,
        # which decides whether repeated instances need to be sampled.
        self.max_anns_per_image_by_class = [
            0 for _ in range(len(self.CLASSES))
        ]
        # count images for each class; when a novel class only has one
        # image, the positive support is allowed to be sampled from the
        # query image itself.
        self.num_image_by_class = [0 for _ in range(len(self.CLASSES))]

        for idx in range(len(self.dataset)):
            labels = self.dataset.get_ann_info(idx)['labels']
            class_count = [0 for _ in range(len(self.CLASSES))]
            for gt_idx, gt in enumerate(labels):
                self.data_infos_by_class[gt].append((idx, gt_idx))
                class_count[gt] += 1
            for i in range(len(self.CLASSES)):
                # number of images for each class
                if class_count[i] > 0:
                    self.num_image_by_class[i] += 1
                # max number of one class annotations in one image
                if class_count[i] > self.max_anns_per_image_by_class[i]:
                    self.max_anns_per_image_by_class[i] = class_count[i]

        for i in range(len(self.CLASSES)):
            assert len(self.data_infos_by_class[i]) > 0, \
                f'Class {self.CLASSES[i]} has zero annotation'
            if len(self.data_infos_by_class[i]) <= self.num_support_shot - \
                    self.max_anns_per_image_by_class[i]:
                warnings.warn(
                    f'During training, the number of instances of class '
                    f'{self.CLASSES[i]} may be smaller than the number of '
                    f'support shots, which causes some instances to be '
                    f'sampled multiple times.')
            if self.num_image_by_class[i] == 1:
                warnings.warn(f'Class {self.CLASSES[i]} only has one '
                              f'image; query and support will be sampled '
                              f'from instances of the same image.')

        # Disable the group sampler, because in the few shot setting
        # one group may only have two or three images.
        if hasattr(dataset, 'flag'):
            self.flag = np.zeros(len(self.dataset), dtype=np.uint8)

    def __getitem__(self, idx):
        # sample query data
        try_time = 0
        while True:
            try_time += 1
            cat_ids = self.dataset.get_cat_ids(idx)
            # if the query image has too many classes, not enough
            # negative support classes can be found.
            if len(self.CLASSES) - len(cat_ids) >= self.num_support_way - 1:
                break
            else:
                idx = self._rand_another(idx)
            assert try_time < 100, \
                'Not enough negative support classes for query image,' \
                ' please try a smaller support way.'

        query_class = np.random.choice(cat_ids)
        query_gt_idx = [
            i for i in range(len(cat_ids)) if cat_ids[i] == query_class
        ]
        query_data = self.dataset.prepare_train_img(idx, 'query', query_gt_idx)
        query_data['query_class'] = [query_class]

        # sample negative support classes, which do not appear in the
        # query image
        support_class = [
            i for i in range(len(self.CLASSES)) if i not in cat_ids
        ]
        support_class = np.random.choice(
            support_class,
            min(self.num_support_way - 1, len(support_class)),
            replace=False)
        support_idxes = self.generate_support(idx, query_class, support_class)
        support_data = [
            self.dataset.prepare_train_img(idx, 'support', [gt_idx])
            for (idx, gt_idx) in support_idxes
        ]
        return {'query_data': query_data, 'support_data': support_data}

    def __len__(self):
        """Length after repetition."""
        return len(self.dataset)

    def _rand_another(self, idx):
        """Get another random index from the same group as the given index."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def generate_support(self, idx, query_class, support_classes):
        """Generate support indexes for a query image.

        Args:
            idx (int): Index of query data.
            query_class (int): Query class.
            support_classes (list[int]): Classes of support data.

        Returns:
            list[(int, int)]: A batch (num_support_way * num_support_shot)
                of support data (idx, gt_idx).
        """
        support_idxes = []
        if self.num_image_by_class[query_class] == 1:
            # only one image, so instances are sampled from the same image
            pos_support_idxes = self.sample_support_shots(
                idx, query_class, allow_same_image=True)
        else:
            # instances are sampled from images other than the query image
            pos_support_idxes = self.sample_support_shots(idx, query_class)
        support_idxes.extend(pos_support_idxes)
        for support_class in support_classes:
            neg_support_idxes = self.sample_support_shots(idx, support_class)
            support_idxes.extend(neg_support_idxes)
        return support_idxes

    def sample_support_shots(self, idx, class_id, allow_same_image=False):
        """Generate support indexes of one class.

        Args:
            idx (int): Index of query data.
            class_id (int): Support class.
            allow_same_image (bool): Allow instances to be sampled from the
                same image as the query image. Default: False.

        Returns:
            list[(int, int)]: Support data (num_support_shot)
                of the specified class.
        """
        support_idxes = []
        num_total_shot = len(self.data_infos_by_class[class_id])
        num_ignore_shot = self.count_class_id(idx, class_id)
        # set num_sample_shots for each round of sampling

        if num_total_shot - num_ignore_shot < self.num_support_shot:
            # not enough support data, so repeated instances are allowed
            num_sample_shots = num_total_shot
            allow_repeat = True
        else:
            # enough support data, so repeated instances are not allowed
            num_sample_shots = self.num_support_shot
            allow_repeat = False
        while len(support_idxes) < self.num_support_shot:
            selected_gt_idxes = np.random.choice(
                num_total_shot, num_sample_shots, replace=False)

            selected_gts = [
                self.data_infos_by_class[class_id][selected_gt_idx]
                for selected_gt_idx in selected_gt_idxes
            ]
            for selected_gt in selected_gts:
                # filter out query annotations
                if selected_gt[0] == idx:
                    if not allow_same_image:
                        continue
                if allow_repeat:
                    support_idxes.append(selected_gt)
                elif selected_gt not in support_idxes:
                    support_idxes.append(selected_gt)
                if len(support_idxes) == self.num_support_shot:
                    break
            # update the number of shots to sample in the next round
            num_sample_shots = min(self.num_support_shot - len(support_idxes),
                                   num_sample_shots)
        return support_idxes

    def count_class_id(self, idx, class_id):
        """Count instances of a specific class in the image of given index."""
        cat_ids = self.dataset.get_cat_ids(idx)
        cat_ids_of_class = [
            i for i in range(len(cat_ids)) if cat_ids[i] == class_id
        ]
        return len(cat_ids_of_class)

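A hedged usage sketch of the wrapper (`few_shot_dataset` is a placeholder for any dataset exposing `get_ann_info`, `get_cat_ids` and `prepare_train_img` as above):

query_aware_dataset = QueryAwareDataset(
    few_shot_dataset, num_support_way=2, num_support_shot=5)
item = query_aware_dataset[0]
# item['query_data'] holds one query image with a single positive class;
# item['support_data'] holds 2 * 5 = 10 single-instance support samples.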
@DATASETS.register_module()
class NwayKshotDataset(object):
    """A dataset wrapper for N-way K-shot sampling.

    Based on the incoming dataset, the query dataset samples batch data as
    a regular dataset, while the support dataset pre-samples batch data
    indexes. Each batch index contains (num_support_way * num_support_shot)
    samples. The default mode of NwayKshotDataset is the query dataset; it
    is converted into a support dataset by calling
    convert_query_to_support().

    Args:
        dataset (:obj:`FewShotDataset` | :obj:`MergeDataset`):
            The dataset to be wrapped.
        num_support_way (int):
            The number of classes in a support data batch.
        num_support_shot (int):
            The number of shots for each class in a support data batch.
    """

    def __init__(self, dataset, num_support_way, num_support_shot):
        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        # data_type determines the behavior of fetching data.
        # The default data_type is 'query', which behaves the same as a
        # regular dataset. To convert the dataset into a 'support' dataset,
        # simply call convert_query_to_support().
        self.data_type = 'query'
        self.num_support_way = num_support_way
        assert num_support_way <= len(self.CLASSES), \
            'support way can not be larger than the number of classes'
        self.num_support_shot = num_support_shot
        self.batch_index = []
        self.data_infos_by_class = {i: [] for i in range(len(self.CLASSES))}

        # Disable the group sampler, because in the few shot setting
        # one group may only have two or three images.
        if hasattr(dataset, 'flag'):
            self.flag = np.zeros(len(self.dataset), dtype=np.uint8)

    def __getitem__(self, idx):
        if self.data_type == 'query':
            # load one data item with the query pipeline
            return self.dataset.prepare_train_img(idx, 'query')
        elif self.data_type == 'support':
            # load one batch of data with the support pipeline
            b_idx = self.batch_index[idx]
            batch_data = [
                self.dataset.prepare_train_img(idx, 'support', [gt_idx])
                for (idx, gt_idx) in b_idx
            ]
            return batch_data
        else:
            raise ValueError(f'unsupported data type {self.data_type}')

    def __len__(self):
        """Length of dataset."""
        if self.data_type == 'query':
            return len(self.dataset)
        elif self.data_type == 'support':
            return len(self.batch_index)
        else:
            raise ValueError(f'unsupported data type {self.data_type}')

    def shuffle_support(self):
        """Generate new batch indexes."""
        if self.data_type == 'query':
            raise ValueError('query dataset can not shuffle support batches')
        self.batch_index = self.generate_batch_index(len(self.batch_index))

    def convert_query_to_support(self, support_dataset_len):
        """Convert the query dataset into a support dataset.

        Args:
            support_dataset_len (int): Length of pre-sampled batch indexes.
        """
        # create a lookup table of annotations for each class
        for idx in range(len(self.dataset)):
            labels = self.dataset.get_ann_info(idx)['labels']
            for gt_idx, gt in enumerate(labels):
                self.data_infos_by_class[gt].append((idx, gt_idx))
        # make sure all class index lists have enough
        # instances (length > num_support_shot)
        for i in range(len(self.CLASSES)):
            num_gts = len(self.data_infos_by_class[i])
            if num_gts < self.num_support_shot:
                self.data_infos_by_class[i] = self.data_infos_by_class[i] * \
                    (self.num_support_shot // num_gts + 1)
        self.batch_index = self.generate_batch_index(support_dataset_len)
        self.data_type = 'support'
        if hasattr(self, 'flag'):
            self.flag = np.zeros(support_dataset_len, dtype=np.uint8)

    def generate_batch_index(self, dataset_len):
        """Generate batch indexes of shape
        [dataset_len * [support_way * support_shot]].

        Args:
            dataset_len (int): Number of pre-sampled batch indexes.

        Returns:
            list[list[(data_idx, gt_idx)]]: Pre-sampled batch indexes.
        """
        total_batch_index = []
        for _ in range(dataset_len):
            batch_index = []
            selected_classes = np.random.choice(
                len(self.CLASSES), self.num_support_way, replace=False)
            for cls in selected_classes:
                num_gts = len(self.data_infos_by_class[cls])
                selected_gts_idx = np.random.choice(
                    num_gts, self.num_support_shot, replace=False)
                selected_gts = [
                    self.data_infos_by_class[cls][gt_idx]
                    for gt_idx in selected_gts_idx
                ]
                batch_index.extend(selected_gts)
            total_batch_index.append(batch_index)
        return total_batch_index
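A sketch of the intended query/support split, end to end (illustrative; `few_shot_dataset` is a placeholder, and deriving the support copy via deepcopy is one possible way to share the underlying data):

import copy

query_dataset = NwayKshotDataset(
    few_shot_dataset, num_support_way=3, num_support_shot=2)
support_dataset = copy.deepcopy(query_dataset)
# pre-sample as many support batches as there are query iterations
support_dataset.convert_query_to_support(len(query_dataset))
batch = support_dataset[0]  # list of 3 * 2 single-instance samples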
@ -0,0 +1,264 @@
import copy
import os.path as osp
import warnings

import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmdet.datasets.pipelines import Compose


@DATASETS.register_module()
class FewShotCustomDataset(CustomDataset):
    """Custom dataset for few shot detection.

    It allows a single pipeline (normal dataset of the fully supervised
    setting) or two pipelines (query-support fashion) for data processing.
    When the annotation shot filter is used, it makes sure the accessible
    annotations meet the few shot setting with an exact number of instances.

    The annotation format is shown as follows. The `ann` field
    is optional for testing.

    .. code-block:: none

        [
            {
                'id': '0000001',
                'filename': 'a.jpg',
                'width': 1280,
                'height': 720,
                'ann': {
                    'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
                    'labels': <np.ndarray> (n, ),
                    'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
                    'labels_ignore': <np.ndarray> (k, ) (optional field)
                }
            },
            ...
        ]

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict] | dict): Processing pipeline.
            If it is list[dict], all data will pass through this pipeline.
            If it is dict, query and support data will be processed with
            two different pipelines and the dict should contain two keys:

            - 'query': list[dict]
            - 'support': list[dict]
        classes (str | Sequence[str]): Classes for model training, providing
            a fixed label for each class.
        data_root (str, optional): Data root for ``ann_file``,
            ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
        test_mode (bool, optional): If set True, annotations will not be
            loaded.
        filter_empty_gt (bool, optional): If set True, images without
            bounding boxes of the dataset's classes will be filtered out.
            This option only works when `test_mode=False`, i.e., we never
            filter images during tests.
        ann_shot_filter (dict, optional): If set to None, all annotations
            from the ann file will be loaded. If not None, the annotation
            shot filter specifies which classes to load and the maximum
            number of instances of each class, for example:
            {'dog': 10, 'person': 5}. Default: None.
    """

    CLASSES = None

    def __init__(
        self,
        ann_file,
        pipeline,
        classes,
        data_root=None,
        img_prefix='',
        seg_prefix=None,
        proposal_file=None,
        test_mode=False,
        filter_empty_gt=True,
        ann_shot_filter=None,
    ):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = self.get_classes(classes)

        self.ann_shot_filter = ann_shot_filter
        if self.ann_shot_filter is not None:
            for class_name in list(self.ann_shot_filter.keys()):
                assert class_name in self.CLASSES, \
                    f'class {class_name} from ' \
                    f'ann_shot_filter not in CLASSES.'

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)
        # load annotations (and proposals)
        self.data_infos = self.load_annotations(self.ann_file)

        # filter annotations according to ann_shot_filter
        if self.ann_shot_filter is not None:
            self.data_infos = self._filter_annotations(self.data_infos,
                                                       self.ann_shot_filter)

        if self.proposal_file is not None:
            self.proposals = self.load_proposals(self.proposal_file)
        else:
            self.proposals = None

        # filter images that are too small or contain no annotations
        if not test_mode:
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]
            if self.proposals is not None:
                self.proposals = [self.proposals[i] for i in valid_inds]
            # set group flag for the sampler
            self._set_group_flag()

        # processing pipeline; if there are two pipelines, the pipeline
        # will be selected by the key name, 'query' or 'support'
        if isinstance(pipeline, dict):
            self.pipeline = {}
            for key in pipeline.keys():
                self.pipeline[key] = Compose(pipeline[key])
        else:
            self.pipeline = Compose(pipeline)
    def get_ann_info(self, idx):
        """Get annotation by index.

        When overriding this function, please make sure the same annotations
        are used during the whole training.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        return copy.deepcopy(self.data_infos[idx]['ann'])

    def prepare_train_img(self, idx, pipeline_key=None, gt_idx=None):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.
            pipeline_key (str): Name of pipeline.
            gt_idx (list[int]): Indexes of used annotations.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """

        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)

        # annotation filter
        if gt_idx is not None:
            selected_ann_info = {
                'bboxes': ann_info['bboxes'][gt_idx],
                'labels': ann_info['labels'][gt_idx]
            }
            # keep pace with the new annotations
            new_img_info = copy.deepcopy(img_info)
            new_img_info['ann'] = selected_ann_info
            results = dict(img_info=new_img_info, ann_info=selected_ann_info)
        else:
            results = dict(img_info=copy.deepcopy(img_info), ann_info=ann_info)

        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]

        self.pre_pipeline(results)
        if pipeline_key is None:
            return self.pipeline(results)
        else:
            return self.pipeline[pipeline_key](results)

    def _filter_annotations(self, data_infos, ann_shot_filter):
        """Filter out annotations whose classes are not in ann_shot_filter
        and excess annotations of the specified classes, so that each
        specified class keeps at most the predefined number of instances.

        Args:
            data_infos (list[dict]): Annotation infos.
            ann_shot_filter (dict): Specifies which classes and how many
                instances of each class to load from the annotation file.
                For example: {'dog': 10, 'cat': 10, 'person': 5}.
                Default: None.

        Returns:
            list[dict]: Annotation infos where the number of shots of each
                specified class is less than or equal to the predefined
                number.
        """
        # build instance indexes of (img_id, gt_idx)
        total_instance_dict = {key: [] for key in ann_shot_filter.keys()}

        for data_info in data_infos:
            img_id = data_info['id']
            ann = data_info['ann']
            for i in range(ann['labels'].shape[0]):
                instance_class_name = self.CLASSES[ann['labels'][i]]
                if instance_class_name in ann_shot_filter.keys():
                    total_instance_dict[instance_class_name].append(
                        (img_id, i))

        total_instance_indexes = []
        for class_name in ann_shot_filter.keys():
            num_shot = ann_shot_filter[class_name]
            instance_indexes = total_instance_dict[class_name]
            # randomly sample from all instances to get the exact
            # number of instances
            if len(instance_indexes) > num_shot:
                random_select = np.random.choice(
                    len(instance_indexes), num_shot, replace=False)
                total_instance_indexes += \
                    [instance_indexes[i] for i in random_select]
            # the number of shots is less than the predefined number,
            # which may cause performance degradation
            elif len(instance_indexes) < num_shot:
                warning = f'number of {class_name} instances ' \
                          f'is {len(instance_indexes)}, which is ' \
                          f'less than predefined shots {num_shot}.'
                warnings.warn(warning)
                total_instance_indexes += instance_indexes
            else:
                total_instance_indexes += instance_indexes

        new_data_infos = []
        for data_info in data_infos:
            img_id = data_info['id']
            selected_instance_index = \
                sorted([instance[1] for instance in total_instance_indexes
                        if instance[0] == img_id])
            ann = data_info['ann']
            if len(selected_instance_index) == 0:
                continue
            selected_ann = dict(
                bboxes=ann['bboxes'][selected_instance_index],
                labels=ann['labels'][selected_instance_index],
            )
            if ann.get('bboxes_ignore') is not None:
                selected_ann['bboxes_ignore'] = ann['bboxes_ignore']
            if ann.get('labels_ignore') is not None:
                selected_ann['labels_ignore'] = ann['labels_ignore']
            new_data_infos.append(
                dict(
                    id=img_id,
                    filename=data_info['filename'],
                    width=data_info['width'],
                    height=data_info['height'],
                    ann=selected_ann))
        return new_data_infos
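A short sketch of configuring the shot filter described above (paths, pipelines, and class names are placeholders):

few_shot_dataset = FewShotCustomDataset(
    ann_file='data/few_shot_ann.json',
    pipeline={'query': query_pipeline, 'support': support_pipeline},
    classes=('dog', 'person'),
    ann_shot_filter={'dog': 10, 'person': 5})
# after loading, the dataset keeps at most 10 'dog' and 5 'person' instances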
@ -0,0 +1,92 @@
# Copyright (c) Open-MMLab. All rights reserved.
from collections.abc import Mapping, Sequence

import torch
import torch.nn.functional as F
from mmcv.parallel.data_container import DataContainer
from torch.utils.data.dataloader import default_collate


def query_support_collate_fn(batch, samples_per_gpu=1):
    """Put each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for
    :type:`~mmcv.parallel.DataContainer`. There are 3 cases:

    1. cpu_only = True, e.g., meta data
    2. cpu_only = False, stack = True, e.g., images tensors
    3. cpu_only = False, stack = False, e.g., gt bboxes
    """

    if not isinstance(batch, Sequence):
        raise TypeError(f'{type(batch)} is not supported.')

    # flatten a support batch of type list[list[DataContainer]]
    # into a single list of DataContainer
    if isinstance(batch[0], Sequence):
        samples_per_gpu = len(batch[0]) * samples_per_gpu
        batch = sum(batch, [])

    if isinstance(batch[0], DataContainer):
        stacked = []
        if batch[0].cpu_only:
            for i in range(0, len(batch), samples_per_gpu):
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
            return DataContainer(
                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
        elif batch[0].stack:
            for i in range(0, len(batch), samples_per_gpu):
                assert isinstance(batch[i].data, torch.Tensor)

                if batch[i].pad_dims is not None:
                    ndim = batch[i].dim()
                    assert ndim > batch[i].pad_dims
                    max_shape = [0 for _ in range(batch[i].pad_dims)]
                    for dim in range(1, batch[i].pad_dims + 1):
                        max_shape[dim - 1] = batch[i].size(-dim)
                    for sample in batch[i:i + samples_per_gpu]:
                        for dim in range(0, ndim - batch[i].pad_dims):
                            assert batch[i].size(dim) == sample.size(dim)
                        for dim in range(1, batch[i].pad_dims + 1):
                            max_shape[dim - 1] = max(max_shape[dim - 1],
                                                     sample.size(-dim))
                    padded_samples = []
                    for sample in batch[i:i + samples_per_gpu]:
                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
                        for dim in range(1, batch[i].pad_dims + 1):
                            pad[2 * dim -
                                1] = max_shape[dim - 1] - sample.size(-dim)
                        padded_samples.append(
                            F.pad(
                                sample.data, pad, value=sample.padding_value))
                    stacked.append(default_collate(padded_samples))
                elif batch[i].pad_dims is None:
                    stacked.append(
                        default_collate([
                            sample.data
                            for sample in batch[i:i + samples_per_gpu]
                        ]))
                else:
                    raise ValueError(
                        'pad_dims should be either None or integers (1-3)')

        else:
            for i in range(0, len(batch), samples_per_gpu):
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
    elif isinstance(batch[0], Sequence):
        transposed = zip(*batch)
        return [
            query_support_collate_fn(samples, samples_per_gpu)
            for samples in transposed
        ]
    elif isinstance(batch[0], Mapping):
        return {
            key: query_support_collate_fn([d[key] for d in batch],
                                          samples_per_gpu)
            for key in batch[0]
        }
    else:
        return default_collate(batch)
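A sketch of where this collate function plugs in, mirroring the support branch of NwayKshotDataloader above (names are placeholders):

from functools import partial

from torch.utils.data import DataLoader

support_loader = DataLoader(
    support_dataset,
    batch_size=1,
    sampler=support_sampler,
    # each fetched item is already a list of support samples, so the
    # collate function flattens list[list[DataContainer]] into one batch
    collate_fn=partial(query_support_collate_fn, samples_per_gpu=1))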
@ -0,0 +1,183 @@
import os.path as osp
import xml.etree.ElementTree as ET

import mmcv
import numpy as np
from mmdet.datasets.builder import DATASETS

from .few_shot_custom import FewShotCustomDataset


@DATASETS.register_module()
class FewShotVOCDataset(FewShotCustomDataset):
    """VOC dataset for few shot detection.

    FewShotVOCDataset supports masking out annotations while they are loaded.
    The annotations can be loaded from image ids or image paths. For example:

    .. code-block:: none

        ann_image_id.txt:
        000001
        000002

        ann_image_path.txt:
        VOC2007/JPEGImages/000001.jpg
        VOC2007/JPEGImages/000002.jpg

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the size of a bounding box is less than
            ``min_size``, it would be added to the ignored field.
            Default: None.
    """

    def __init__(self, min_size=None, **kwargs):
        assert self.CLASSES or kwargs.get(
            'classes', None), 'CLASSES in `XMLDataset` can not be None.'
        self.min_size = min_size
        super(FewShotVOCDataset, self).__init__(**kwargs)

    def load_annotations(self, ann_file):
        """Load annotations from an XML style ann_file.

        Args:
            ann_file (str): Path of XML file.

        Returns:
            list[dict]: Annotation info from XML file.
        """
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        data_infos = []
        img_names = mmcv.list_from_file(ann_file)
        for img_name in img_names:
            # ann file in image path format
            if 'VOC2007' in img_name:
                dataset_year = 'VOC2007'
                img_id = img_name.split('/')[-1].split('.')[0]
                filename = img_name
            # ann file in image path format
            elif 'VOC2012' in img_name:
                dataset_year = 'VOC2012'
                img_id = img_name.split('/')[-1].split('.')[0]
                filename = img_name
            # ann file in image id format
            elif 'VOC2007' in ann_file:
                dataset_year = 'VOC2007'
                img_id = img_name
                filename = f'VOC2007/JPEGImages/{img_name}.jpg'
            # ann file in image id format
            elif 'VOC2012' in ann_file:
                dataset_year = 'VOC2012'
                img_id = img_name
                filename = f'VOC2012/JPEGImages/{img_name}.jpg'
            else:
                raise ValueError('Cannot infer dataset year from img_prefix')

            xml_path = osp.join(self.img_prefix, dataset_year, 'Annotations',
                                f'{img_id}.xml')
            tree = ET.parse(xml_path)
            root = tree.getroot()
            size = root.find('size')
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                img_path = osp.join(self.img_prefix, dataset_year,
                                    'JPEGImages', f'{img_id}.jpg')
                # mmcv.imread returns a numpy array, so the spatial size
                # comes from its shape
                img = mmcv.imread(img_path)
                height, width = img.shape[:2]
            # save annotation infos into data infos, because not all the
            # annotations will be used for training and the used annotations
            # should stay the same anytime during training.
            ann_info = self._get_ann_info(dataset_year, img_id)
            data_infos.append(
                dict(
                    id=img_id,
                    filename=filename,
                    width=width,
                    height=height,
                    ann=ann_info))
        return data_infos

    def _get_ann_info(self, dataset_year, img_id):
        """Get annotation from XML file by img_id.

        Args:
            dataset_year (str): Year of voc dataset. Options are
                'VOC2007', 'VOC2012'.
            img_id (str): Id of image.

        Returns:
            dict: Annotation info of specified id with specified class.
        """

        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []

        xml_path = osp.join(self.img_prefix, dataset_year, 'Annotations',
                            f'{img_id}.xml')
        tree = ET.parse(xml_path)
        root = tree.getroot()
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)
            bnd_box = obj.find('bndbox')
            # TODO: check whether it is necessary to use int
            # Coordinates may be float type
            bbox = [
                int(float(bnd_box.find('xmin').text)),
                int(float(bnd_box.find('ymin').text)),
                int(float(bnd_box.find('xmax').text)),
                int(float(bnd_box.find('ymax').text))
            ]
            ignore = False
            if self.min_size:
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                if w < self.min_size or h < self.min_size:
                    ignore = True
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        if not bboxes:
            bboxes = np.zeros((0, 4))
            labels = np.zeros((0, ))
        else:
            bboxes = np.array(bboxes, ndmin=2) - 1
            labels = np.array(labels)
        if not bboxes_ignore:
            bboxes_ignore = np.zeros((0, 4))
            labels_ignore = np.zeros((0, ))
        else:
            bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
            labels_ignore = np.array(labels_ignore)
        ann_info = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann_info

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation."""
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if self.filter_empty_gt:
                cat_ids = img_info['ann']['labels'].astype(np.int64).tolist()
                if len(cat_ids) > 0:
                    valid_inds.append(i)
            else:
                valid_inds.append(i)
        return valid_inds
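For reference, a sketch of how the two ann_file formats above resolve to the same annotation path (illustrative only; the prefix and ids here are assumptions, not from the PR):

    img_prefix = 'data/VOCdevkit/'
    for img_name in ['000001', 'VOC2007/JPEGImages/000001.jpg']:
        if 'VOC2007' in img_name:  # image path format
            img_id = img_name.split('/')[-1].split('.')[0]
        else:  # image id format (year inferred from the ann_file path)
            img_id = img_name
        print(img_prefix + 'VOC2007/Annotations/' + img_id + '.xml')
        # -> data/VOCdevkit/VOC2007/Annotations/000001.xml in both cases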
@ -3,7 +3,8 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmfewshot
-known_third_party = mmcls,mmcv,mmdet,numpy,pytest,torch
+known_third_party = mmcls,mmcv,mmdet,numpy,pytest,terminaltables,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
@ -0,0 +1,44 @@
<annotation>
    <folder>VOC2007</folder>
    <filename>000001.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>341012865</flickrid>
    </source>
    <owner>
        <flickrid>Fried Camels</flickrid>
        <name>Jinky the Fruit Bat</name>
    </owner>
    <size>
        <width>353</width>
        <height>500</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>dog</name>
        <pose>Left</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>48</xmin>
            <ymin>240</ymin>
            <xmax>195</xmax>
            <ymax>371</ymax>
        </bndbox>
    </object>
    <object>
        <name>person</name>
        <pose>Left</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>8</xmin>
            <ymin>12</ymin>
            <xmax>352</xmax>
            <ymax>498</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,32 @@
<annotation>
    <folder>VOC2007</folder>
    <filename>000002.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>329145082</flickrid>
    </source>
    <owner>
        <flickrid>hiromori2</flickrid>
        <name>Hiroyuki Mori</name>
    </owner>
    <size>
        <width>335</width>
        <height>500</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>train</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>139</xmin>
            <ymin>200</ymin>
            <xmax>207</xmax>
            <ymax>301</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,44 @@
<annotation>
    <folder>VOC2007</folder>
    <filename>000003.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>138563409</flickrid>
    </source>
    <owner>
        <flickrid>RandomEvent101</flickrid>
        <name>?</name>
    </owner>
    <size>
        <width>500</width>
        <height>375</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>sofa</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>123</xmin>
            <ymin>155</ymin>
            <xmax>215</xmax>
            <ymax>195</ymax>
        </bndbox>
    </object>
    <object>
        <name>chair</name>
        <pose>Left</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>239</xmin>
            <ymin>156</ymin>
            <xmax>307</xmax>
            <ymax>205</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,104 @@
<annotation>
    <folder>VOC2007</folder>
    <filename>000004.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>322032655</flickrid>
    </source>
    <owner>
        <flickrid>paytonc</flickrid>
        <name>Payton Chung</name>
    </owner>
    <size>
        <width>500</width>
        <height>406</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>car</name>
        <pose>Frontal</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>13</xmin>
            <ymin>311</ymin>
            <xmax>84</xmax>
            <ymax>362</ymax>
        </bndbox>
    </object>
    <object>
        <name>car</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>362</xmin>
            <ymin>330</ymin>
            <xmax>500</xmax>
            <ymax>389</ymax>
        </bndbox>
    </object>
    <object>
        <name>car</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>235</xmin>
            <ymin>328</ymin>
            <xmax>334</xmax>
            <ymax>375</ymax>
        </bndbox>
    </object>
    <object>
        <name>car</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>175</xmin>
            <ymin>327</ymin>
            <xmax>252</xmax>
            <ymax>364</ymax>
        </bndbox>
    </object>
    <object>
        <name>car</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>139</xmin>
            <ymin>320</ymin>
            <xmax>189</xmax>
            <ymax>359</ymax>
        </bndbox>
    </object>
    <object>
        <name>car</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>108</xmin>
            <ymin>325</ymin>
            <xmax>150</xmax>
            <ymax>353</ymax>
        </bndbox>
    </object>
    <object>
        <name>car</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>84</xmin>
            <ymin>323</ymin>
            <xmax>121</xmax>
            <ymax>350</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,80 @@
<annotation>
    <folder>VOC2007</folder>
    <filename>000005.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>325991873</flickrid>
    </source>
    <owner>
        <flickrid>archintent louisville</flickrid>
        <name>?</name>
    </owner>
    <size>
        <width>500</width>
        <height>375</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>chair</name>
        <pose>Rear</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>263</xmin>
            <ymin>211</ymin>
            <xmax>324</xmax>
            <ymax>339</ymax>
        </bndbox>
    </object>
    <object>
        <name>chair</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>165</xmin>
            <ymin>264</ymin>
            <xmax>253</xmax>
            <ymax>372</ymax>
        </bndbox>
    </object>
    <object>
        <name>chair</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>1</difficult>
        <bndbox>
            <xmin>5</xmin>
            <ymin>244</ymin>
            <xmax>67</xmax>
            <ymax>374</ymax>
        </bndbox>
    </object>
    <object>
        <name>chair</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>241</xmin>
            <ymin>194</ymin>
            <xmax>295</xmax>
            <ymax>299</ymax>
        </bndbox>
    </object>
    <object>
        <name>chair</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>1</difficult>
        <bndbox>
            <xmin>277</xmin>
            <ymin>186</ymin>
            <xmax>312</xmax>
            <ymax>220</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,5 @@
000001
000002
000003
000004
000005
@ -0,0 +1,5 @@
000001
000002
000003
000004
000005
(binary diff: five JPEG test images added, 10 KiB - 120 KiB)
@ -0,0 +1,63 @@
<annotation>
    <folder>VOC2012</folder>
    <filename>2007_000027.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
    </source>
    <size>
        <width>486</width>
        <height>500</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>person</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>174</xmin>
            <ymin>101</ymin>
            <xmax>349</xmax>
            <ymax>351</ymax>
        </bndbox>
        <part>
            <name>head</name>
            <bndbox>
                <xmin>169</xmin>
                <ymin>104</ymin>
                <xmax>209</xmax>
                <ymax>146</ymax>
            </bndbox>
        </part>
        <part>
            <name>hand</name>
            <bndbox>
                <xmin>278</xmin>
                <ymin>210</ymin>
                <xmax>297</xmax>
                <ymax>233</ymax>
            </bndbox>
        </part>
        <part>
            <name>foot</name>
            <bndbox>
                <xmin>273</xmin>
                <ymin>333</ymin>
                <xmax>297</xmax>
                <ymax>354</ymax>
            </bndbox>
        </part>
        <part>
            <name>foot</name>
            <bndbox>
                <xmin>319</xmin>
                <ymin>307</ymin>
                <xmax>340</xmax>
                <ymax>326</ymax>
            </bndbox>
        </part>
    </object>
</annotation>
@ -0,0 +1,63 @@
<annotation>
    <folder>VOC2012</folder>
    <filename>2007_000032.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
    </source>
    <size>
        <width>500</width>
        <height>281</height>
        <depth>3</depth>
    </size>
    <segmented>1</segmented>
    <object>
        <name>aeroplane</name>
        <pose>Frontal</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>104</xmin>
            <ymin>78</ymin>
            <xmax>375</xmax>
            <ymax>183</ymax>
        </bndbox>
    </object>
    <object>
        <name>aeroplane</name>
        <pose>Left</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>133</xmin>
            <ymin>88</ymin>
            <xmax>197</xmax>
            <ymax>123</ymax>
        </bndbox>
    </object>
    <object>
        <name>person</name>
        <pose>Rear</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>195</xmin>
            <ymin>180</ymin>
            <xmax>213</xmax>
            <ymax>229</ymax>
        </bndbox>
    </object>
    <object>
        <name>person</name>
        <pose>Rear</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>26</xmin>
            <ymin>189</ymin>
            <xmax>44</xmax>
            <ymax>238</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,51 @@
<annotation>
    <folder>VOC2012</folder>
    <filename>2007_000033.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
    </source>
    <size>
        <width>500</width>
        <height>366</height>
        <depth>3</depth>
    </size>
    <segmented>1</segmented>
    <object>
        <name>aeroplane</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>9</xmin>
            <ymin>107</ymin>
            <xmax>499</xmax>
            <ymax>263</ymax>
        </bndbox>
    </object>
    <object>
        <name>aeroplane</name>
        <pose>Left</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>421</xmin>
            <ymin>200</ymin>
            <xmax>482</xmax>
            <ymax>226</ymax>
        </bndbox>
    </object>
    <object>
        <name>aeroplane</name>
        <pose>Left</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>325</xmin>
            <ymin>188</ymin>
            <xmax>411</xmax>
            <ymax>223</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,27 @@
<annotation>
    <folder>VOC2012</folder>
    <filename>2007_000039.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
    </source>
    <size>
        <width>500</width>
        <height>375</height>
        <depth>3</depth>
    </size>
    <segmented>1</segmented>
    <object>
        <name>tvmonitor</name>
        <pose>Frontal</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>156</xmin>
            <ymin>89</ymin>
            <xmax>344</xmax>
            <ymax>279</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,39 @@
<annotation>
    <folder>VOC2012</folder>
    <filename>2007_000042.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
    </source>
    <size>
        <width>500</width>
        <height>335</height>
        <depth>3</depth>
    </size>
    <segmented>1</segmented>
    <object>
        <name>train</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>263</xmin>
            <ymin>32</ymin>
            <xmax>500</xmax>
            <ymax>295</ymax>
        </bndbox>
    </object>
    <object>
        <name>train</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>1</xmin>
            <ymin>36</ymin>
            <xmax>235</xmax>
            <ymax>299</ymax>
        </bndbox>
    </object>
</annotation>
@ -0,0 +1,5 @@
2007_000027
2007_000032
2007_000033
2007_000039
2007_000042
@ -0,0 +1,5 @@
2007_000027
2007_000032
2007_000033
2007_000039
2007_000042
(binary diff: five JPEG test images added, 54 KiB - 142 KiB)
@ -0,0 +1,77 @@
{
    "images": [
        {
            "file_name": "fake1.jpg",
            "height": 800,
            "width": 800,
            "id": 0
        },
        {
            "file_name": "fake2.jpg",
            "height": 800,
            "width": 800,
            "id": 1
        },
        {
            "file_name": "fake3.jpg",
            "height": 800,
            "width": 800,
            "id": 2
        }
    ],
    "annotations": [
        {
            "bbox": [
                0,
                0,
                20,
                20
            ],
            "area": 400.00,
            "score": 1.0,
            "category_id": 1,
            "id": 1,
            "image_id": 0
        },
        {
            "bbox": [
                0,
                0,
                20,
                20
            ],
            "area": 400.00,
            "score": 1.0,
            "category_id": 2,
            "id": 2,
            "image_id": 0
        },
        {
            "bbox": [
                0,
                0,
                20,
                20
            ],
            "area": 400.00,
            "score": 1.0,
            "category_id": 1,
            "id": 3,
            "image_id": 1
        }
    ],
    "categories": [
        {
            "id": 1,
            "name": "bus",
            "supercategory": "none"
        },
        {
            "id": 2,
            "name": "car",
            "supercategory": "none"
        }
    ],
    "licenses": [],
    "info": null
}
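As used by the tests later in this PR, an ann_shot_filter caps how many annotations of a class survive loading. A standalone sketch of that filtering idea over a COCO-style annotation list like the one above (hypothetical helper, not the PR's implementation):

    def filter_shots(annotations, cat_name_by_id, shots):
        """Keep at most shots[class_name] annotations per class."""
        kept, seen = [], {}
        for ann in annotations:
            name = cat_name_by_id[ann['category_id']]
            if name not in shots:
                continue  # classes outside the filter are dropped
            if seen.get(name, 0) < shots[name]:
                kept.append(ann)
                seen[name] = seen.get(name, 0) + 1
        return kept

    # e.g. filter_shots(anns, {1: 'bus', 2: 'car'}, {'bus': 5})
    # keeps at most five 'bus' annotations and no 'car' annotations.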
(binary diff: one JPEG test image added, 35 KiB)
@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg
@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg
@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg
@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg
@ -0,0 +1,10 @@
VOC2007/JPEGImages/000001.jpg
VOC2007/JPEGImages/000002.jpg
VOC2007/JPEGImages/000003.jpg
VOC2007/JPEGImages/000004.jpg
VOC2007/JPEGImages/000005.jpg
VOC2012/JPEGImages/2007_000027.jpg
VOC2012/JPEGImages/2007_000032.jpg
VOC2012/JPEGImages/2007_000033.jpg
VOC2012/JPEGImages/2007_000039.jpg
VOC2012/JPEGImages/2007_000042.jpg
(binary diff: one JPEG test image added, 38 KiB)
@ -0,0 +1,247 @@
import torch

from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import (build_dataloader,
                                                  build_dataset)


def test_dataloader():
    set_random_seed(2021)

    # test regular and few shot annotations
    dataconfigs = [{
        'type': 'NwayKshotDataset',
        'support_way': 5,
        'support_shot': 1,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
                'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'pipeline': {
                'query': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ],
                'support': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ]
            },
            'classes': ('person', 'dog', 'chair', 'car', 'aeroplane', 'train'),
            'merge_dataset':
            True
        }
    }, {
        'type': 'NwayKshotDataset',
        'support_way': 5,
        'support_shot': 1,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/few_shot_voc_split/1.txt',
                'tests/data/few_shot_voc_split/2.txt',
                'tests/data/few_shot_voc_split/3.txt',
                'tests/data/few_shot_voc_split/4.txt',
                'tests/data/few_shot_voc_split/5.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'ann_shot_filter': [{
                'person': 2
            }, {
                'dog': 2
            }, {
                'chair': 3
            }, {
                'car': 3
            }, {
                'aeroplane': 3
            }],
            'pipeline': {
                'query': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ],
                'support': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ]
            },
            'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
            'merge_dataset':
            True
        }
    }]

    for dataconfig in dataconfigs:

        nway_kshot_dataset = build_dataset(cfg=dataconfig)
        nway_kshot_dataloader = build_dataloader(
            nway_kshot_dataset,
            samples_per_gpu=2,
            workers_per_gpu=0,
            num_gpus=1,
            dist=False,
            shuffle=True,
            seed=2021)

        for i, data_batch in enumerate(nway_kshot_dataloader):
            assert len(data_batch['query_data']['img_metas'].data[0]) == 2
            assert len(nway_kshot_dataloader.query_data_loader) == \
                len(nway_kshot_dataloader.support_data_loader)
            support_labels = data_batch['support_data']['gt_labels'].data[0]
            assert len(set(torch.cat(
                support_labels).tolist())) == dataconfig['support_way']
            assert len(torch.cat(support_labels).tolist()) == \
                dataconfig['support_way'] * dataconfig['support_shot']

    dataconfigs = [{
        'type': 'QueryAwareDataset',
        'support_way': 3,
        'support_shot': 5,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
                'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'pipeline': {
                'query': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ],
                'support': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ]
            },
            'classes': ('dog', 'chair', 'car'),
            'merge_dataset':
            True
        }
    }, {
        'type': 'QueryAwareDataset',
        'support_way': 3,
        'support_shot': 2,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/few_shot_voc_split/1.txt',
                'tests/data/few_shot_voc_split/2.txt',
                'tests/data/few_shot_voc_split/3.txt',
                'tests/data/few_shot_voc_split/4.txt',
                'tests/data/few_shot_voc_split/5.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'ann_shot_filter': [{
                'person': 1
            }, {
                'dog': 1
            }, {
                'chair': 2
            }, {
                'car': 2
            }, {
                'aeroplane': 2
            }],
            'pipeline': {
                'query': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ],
                'support': [
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations', with_bbox=True),
                    dict(type='RandomFlip', flip_ratio=0.5),
                    dict(type='Pad', size_divisor=32),
                    dict(type='DefaultFormatBundle'),
                    dict(
                        type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
                ]
            },
            'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
            'merge_dataset':
            True
        }
    }]

    for dataconfig in dataconfigs:
        query_aware_dataset = build_dataset(cfg=dataconfig)
        query_aware_dataloader = build_dataloader(
            query_aware_dataset,
            samples_per_gpu=2,
            workers_per_gpu=0,
            num_gpus=1,
            dist=False,
            shuffle=True,
            seed=2021)

        for i, data_batch in enumerate(query_aware_dataloader):
            assert len(data_batch['query_data']['img_metas'].data[0]) == 2
            assert len(data_batch['query_data']['query_class'].tolist()) == 2
            support_labels = data_batch['support_data']['gt_labels'].data[0]
            half_batch = len(support_labels) // 2
            assert len(set(torch.cat(support_labels[:half_batch]).tolist())) \
                == dataconfig['support_way']
            assert len(set(torch.cat(support_labels[half_batch:]).tolist())) \
                == dataconfig['support_way']
@ -0,0 +1,49 @@
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.coco import FewShotCocoDataset


def test_few_shot_coco_dataset():
    set_random_seed(2021)
    # test regular annotation loading
    dataconfig = {
        'ann_file': 'tests/data/coco_sample.json',
        'img_prefix': '',
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('bus', 'car')
    }
    few_shot_custom_dataset = FewShotCocoDataset(**dataconfig)

    # images without labels are filtered out
    assert len(few_shot_custom_dataset.data_infos) == 2
    assert few_shot_custom_dataset.CLASSES == ('bus', 'car')
    # test loading annotations of a specific class
    dataconfig = {
        'ann_file': 'tests/data/few_shot_coco_split/bus.json',
        'img_prefix': '',
        'ann_shot_filter': {
            'bus': 5
        },
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('bus', 'dog', 'car'),
    }
    few_shot_custom_dataset = FewShotCocoDataset(**dataconfig)
    count = 0
    for datainfo in few_shot_custom_dataset.data_infos:
        count += len(datainfo['ann']['labels'])
        for i in range(len(datainfo['ann']['labels'])):
            assert datainfo['ann']['labels'][i] == 0
    assert count == 5
@ -0,0 +1,112 @@
import copy
from unittest.mock import MagicMock, patch

import numpy as np

from mmfewshot.detection.datasets import FewShotCustomDataset

data_infos = [
    {
        'id': '1',
        'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000001.jpg',
        'width': 800,
        'height': 720,
        'ann': {
            'bboxes': np.array([[10, 10, 100, 100], [20, 20, 200, 200]]),
            'labels': np.array([0, 1])
        }
    },
    {
        'id': '2',
        'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000002.jpg',
        'width': 800,
        'height': 720,
        'ann': {
            'bboxes': np.array([[11, 11, 100, 100], [20, 20, 200, 200]]),
            'labels': np.array([1, 1])
        }
    },
    {
        'id': '3',
        'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000003.jpg',
        'width': 800,
        'height': 720,
        'ann': {
            'bboxes':
            np.array([[11, 11, 100, 100], [20, 20, 200, 200],
                      [20, 20, 200, 200]]),
            'labels':
            np.array([2, 3, 3, 4])
        }
    },
    {
        'id': '4',
        'filename': 'tests/data/VOCdevkit/VOC2007/JPEGImages/000004.jpg',
        'width': 800,
        'height': 720,
        'ann': {
            'bboxes':
            np.array([[11, 11, 100, 100], [20, 20, 200, 200],
                      [20, 20, 200, 200], [20, 20, 200, 200]]),
            'labels':
            np.array([2, 2, 4, 4])
        }
    },
]


@patch('mmfewshot.detection.datasets.FewShotCustomDataset.load_annotations',
       MagicMock(return_value=data_infos))
def test_few_shot_custom_dataset():
    dataconfig = {
        'ann_file': '',
        'img_prefix': '',
        'ann_shot_filter': {
            'cat': 10,
            'dog': 10,
            'person': 2,
            'car': 2,
        },
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('cat', 'dog', 'person', 'car', 'bird')
    }

    few_shot_custom_dataset = FewShotCustomDataset(**dataconfig)
    original_data_infos = copy.deepcopy(few_shot_custom_dataset.data_infos)

    # test prepare_train_img()
    data = few_shot_custom_dataset.prepare_train_img(0, 'query')
    assert (data['img_info']['ann']['bboxes'] == np.array(
        [[10, 10, 100, 100], [20, 20, 200, 200]])).all()
    assert (data['img_info']['ann']['labels'] == np.array([0, 1])).all()

    data = few_shot_custom_dataset.prepare_train_img(1, 'support')
    assert (data['img_info']['ann']['bboxes'] == np.array(
        [[11, 11, 100, 100], [20, 20, 200, 200]])).all()
    assert (data['img_info']['ann']['labels'] == np.array([1, 1])).all()

    data = few_shot_custom_dataset.prepare_train_img(0, 'query', [0])
    assert (data['img_info']['ann']['bboxes'] == np.array(
        [[10, 10, 100, 100]])).all()
    assert (data['img_info']['ann']['labels'] == np.array([0])).all()

    data = few_shot_custom_dataset.prepare_train_img(0, 'support', [1])
    assert (data['img_info']['ann']['bboxes'] == np.array(
        [[20, 20, 200, 200]])).all()
    assert (data['img_info']['ann']['labels'] == np.array([1])).all()

    # test whether data_infos have been accidentally changed or not
    for i in range(len(few_shot_custom_dataset)):
        assert (original_data_infos[i]['ann']['bboxes'] ==
                few_shot_custom_dataset.data_infos[i]['ann']['bboxes']).all()
        assert (original_data_infos[i]['ann']['labels'] ==
                few_shot_custom_dataset.data_infos[i]['ann']['labels']).all()
@ -0,0 +1,70 @@
from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.voc import FewShotVOCDataset


def test_few_shot_voc_dataset():
    set_random_seed(2021)
    # test regular annotation loading
    dataconfig = {
        'ann_file': 'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
        'img_prefix': 'tests/data/VOCdevkit/',
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('car', 'dog', 'chair')
    }
    few_shot_custom_dataset = FewShotVOCDataset(**dataconfig)

    # images without labels are filtered out
    assert len(few_shot_custom_dataset.data_infos) == 4
    assert few_shot_custom_dataset.CLASSES == ('car', 'dog', 'chair')
    # test loading annotations of a specific class
    dataconfig = {
        'ann_file': 'tests/data/few_shot_voc_split/1.txt',
        'img_prefix': 'tests/data/VOCdevkit/',
        'ann_shot_filter': {
            'aeroplane': 10
        },
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('car', 'dog', 'chair', 'aeroplane'),
    }
    few_shot_custom_dataset = FewShotVOCDataset(**dataconfig)
    count = 0
    for datainfo in few_shot_custom_dataset.data_infos:
        count += len(datainfo['ann']['bboxes'])
    assert count == 5

    # test loading annotations of a specific class with a specific shot
    dataconfig = {
        'ann_file': 'tests/data/few_shot_voc_split/1.txt',
        'img_prefix': 'tests/data/VOCdevkit/',
        'ann_shot_filter': {
            'aeroplane': 2
        },
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('car', 'dog', 'chair', 'aeroplane'),
    }
    few_shot_custom_dataset = FewShotVOCDataset(**dataconfig)
    count = 0
    for datainfo in few_shot_custom_dataset.data_infos:
        count += len(datainfo['ann']['bboxes'])
    assert count == 2
@ -0,0 +1,137 @@
import numpy as np

from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import build_dataset


def test_merge_dataset():
    set_random_seed(2023)
    # test merge dataset loading regular annotations
    dataconfig = {
        'type':
        'FewShotVOCDataset',
        'ann_file': [
            'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
            'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
        ],
        'img_prefix': [
            'tests/data/VOCdevkit/',
            'tests/data/VOCdevkit/',
        ],
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
        'merge_dataset':
        True
    }
    merge_dataset = build_dataset(cfg=dataconfig)
    count = [0 for _ in range(5)]
    for data_info in merge_dataset.dataset.data_infos:
        # test label merge
        if data_info['id'] == '000001':
            assert (np.sort(data_info['ann']['labels']) ==
                    np.array([0, 1])).all()
        for label in data_info['ann']['labels']:
            count[label] += 1
    assert count == [4, 1, 4, 7, 5]

    # test merge dataset loading annotations by class
    dataconfig = {
        'type':
        'FewShotVOCDataset',
        'ann_file': [
            'tests/data/few_shot_voc_split/1.txt',
            'tests/data/few_shot_voc_split/2.txt',
            'tests/data/few_shot_voc_split/3.txt',
            'tests/data/few_shot_voc_split/4.txt',
            'tests/data/few_shot_voc_split/5.txt'
        ],
        'img_prefix': [
            'tests/data/VOCdevkit/',
            'tests/data/VOCdevkit/',
            'tests/data/VOCdevkit/',
            'tests/data/VOCdevkit/',
            'tests/data/VOCdevkit/',
        ],
        'ann_shot_filter': [{
            'person': 2
        }, {
            'dog': 2
        }, {
            'chair': 3
        }, {
            'car': 3
        }, {
            'aeroplane': 3
        }],
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
        'merge_dataset':
        True
    }
    merge_dataset = build_dataset(cfg=dataconfig)
    count = [0 for _ in range(5)]
    for data_info in merge_dataset.dataset.data_infos:
        # test label merge
        if data_info['id'] == '000001':
            assert (np.sort(data_info['ann']['labels']) ==
                    np.array([0, 1])).all()
        for label in data_info['ann']['labels']:
            count[label] += 1
    assert count == [2, 1, 3, 3, 3]

    # test loading annotations of a specific class with a specific shot
    dataconfig = {
        'type':
        'FewShotCocoDataset',
        'ann_file': [
            'tests/data/few_shot_coco_split/bus.json',
            'tests/data/few_shot_coco_split/car.json',
            'tests/data/few_shot_coco_split/cat.json',
            'tests/data/few_shot_coco_split/dog.json',
            'tests/data/few_shot_coco_split/person.json',
        ],
        'img_prefix': ['', '', '', '', ''],
        'ann_shot_filter': [{
            'bus': 2
        }, {
            'car': 2
        }, {
            'cat': 3
        }, {
            'dog': 3
        }, {
            'person': 3
        }],
        'pipeline': {
            'query': [{
                'type': 'LoadImageFromFile'
            }],
            'support': [{
                'type': 'LoadImageFromFile'
            }]
        },
        'classes': ('bus', 'car', 'cat', 'dog', 'person'),
        'merge_dataset':
        True
    }
    merge_dataset = build_dataset(cfg=dataconfig)
    count = [0 for _ in range(5)]
    for data_info in merge_dataset.dataset.data_infos:
        # test label merge
        for label in data_info['ann']['labels']:
            count[label] += 1
    assert count == [2, 2, 3, 3, 3]
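The "test label merge" checks above rely on each per-class split being remapped into the shared classes tuple when datasets are merged. A standalone sketch of that remapping idea (hypothetical, not the PR's implementation):

    classes = ('person', 'dog', 'chair', 'car', 'aeroplane')
    # each single-class sub-dataset labels its own class as 0
    sub_labels = {'person': [0, 0], 'dog': [0], 'chair': [0, 0, 0]}
    merged = {
        name: [classes.index(name) for _ in labels]
        for name, labels in sub_labels.items()
    }
    assert merged['chair'] == [2, 2, 2]  # remapped into the joint label space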
@ -0,0 +1,145 @@
import numpy as np

from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import build_dataset


def test_nway_kshot_dataset():
    set_random_seed(2021)
    # test regular and few shot annotations
    dataconfigs = [{
        'type': 'NwayKshotDataset',
        'support_way': 5,
        'support_shot': 1,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
                'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'pipeline': {
                'query': [{
                    'type': 'LoadImageFromFile'
                }],
                'support': [{
                    'type': 'LoadImageFromFile'
                }]
            },
            'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
            'merge_dataset':
            True
        }
    }, {
        'type': 'NwayKshotDataset',
        'support_way': 5,
        'support_shot': 1,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/few_shot_voc_split/1.txt',
                'tests/data/few_shot_voc_split/2.txt',
                'tests/data/few_shot_voc_split/3.txt',
                'tests/data/few_shot_voc_split/4.txt',
                'tests/data/few_shot_voc_split/5.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'ann_shot_filter': [{
                'person': 2
            }, {
                'dog': 2
            }, {
                'chair': 3
            }, {
                'car': 3
            }, {
                'aeroplane': 3
            }],
            'pipeline': {
                'query': [{
                    'type': 'LoadImageFromFile'
                }],
                'support': [{
                    'type': 'LoadImageFromFile'
                }]
            },
            'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
            'merge_dataset':
            True
        }
    }]
    for dataconfig in dataconfigs:
        # test query dataset with 5 way 1 shot
        nway_kshot_dataset = build_dataset(cfg=dataconfig)
        assert nway_kshot_dataset.data_type == 'query'
        assert np.sum(nway_kshot_dataset.flag) == 0
        assert isinstance(nway_kshot_dataset[0], dict)
        # test support dataset with 5 way 1 shot
        nway_kshot_dataset.convert_query_to_support(support_dataset_len=2)
        batch_index = nway_kshot_dataset.batch_index
        assert nway_kshot_dataset.data_type == 'support'
        assert nway_kshot_dataset.flag.shape[0] == 2
        assert len(batch_index) == 2
        assert len(batch_index[0]) == 5
        assert len(batch_index[0][0]) == 2
        # test batch of support dataset with 5 way 1 shot
        support_batch = nway_kshot_dataset[0]
        assert isinstance(support_batch, list)
        count_classes = [0 for _ in range(5)]
        for item in support_batch:
            count_classes[item['ann_info']['labels'][0]] += 1
        for count in count_classes:
            assert count == 1
        # test support dataset with 4 way 2 shot
        dataconfig['support_way'] = 4
        dataconfig['support_shot'] = 2
        nway_kshot_dataset = build_dataset(cfg=dataconfig)
        assert nway_kshot_dataset.data_type == 'query'
        assert np.sum(nway_kshot_dataset.flag) == 0
        assert isinstance(nway_kshot_dataset[0], dict)
        # test support dataset with 4 way 2 shot
        nway_kshot_dataset.convert_query_to_support(support_dataset_len=3)
        batch_index = nway_kshot_dataset.batch_index
        assert nway_kshot_dataset.data_type == 'support'
        assert nway_kshot_dataset.flag.shape[0] == 3
        assert len(batch_index) == 3
        assert len(batch_index[0]) == 4 * 2
        assert len(batch_index[0][0]) == 2
        for i in range(len(nway_kshot_dataset.CLASSES)):
            assert len(nway_kshot_dataset.data_infos_by_class[i]) >= 2
        # test batch of support dataset with 4 way 2 shot
        for idx in range(3):
            support_batch = nway_kshot_dataset[idx]
            assert isinstance(support_batch, list)
            count_classes = [0 for _ in range(5)]
            dog_ann = None
            for item in support_batch:
                label = item['ann_info']['labels'][0]
                count_classes[label] += 1
                # test whether the dog annotation is repeated or not
                # (there is only one dog instance)
                if label == 1:
                    if dog_ann is None:
                        dog_ann = item['ann_info']['bboxes']
                    else:
                        assert (dog_ann == item['ann_info']['bboxes']).all()
            # test the number of classes sampled:
            # 4 classes have 2 shots each, 1 class has 0 shots
            is_skip = False
            for count in count_classes:
                if count == 0:
                    assert not is_skip
                    is_skip = True
                else:
                    assert count == 2
@ -0,0 +1,136 @@
import numpy as np

from mmfewshot.apis.train import set_random_seed
from mmfewshot.detection.datasets.builder import build_dataset


def test_query_aware_dataset():
    set_random_seed(2023)
    # test regular annotations
    dataconfig = {
        'type': 'QueryAwareDataset',
        'support_way': 3,
        'support_shot': 5,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
                'tests/data/VOCdevkit/VOC2012/ImageSets/Main/trainval.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'pipeline': {
                'query': [{
                    'type': 'LoadImageFromFile'
                }],
                'support': [{
                    'type': 'LoadImageFromFile'
                }]
            },
            'classes': ('dog', 'chair', 'car'),
            'merge_dataset':
            True
        }
    }
    # test query-aware dataset with 3 way 5 shot
    query_aware_dataset = build_dataset(cfg=dataconfig)

    assert np.sum(query_aware_dataset.flag) == 0
    # print(query_aware_dataset.data_infos_by_class)
    # self.data_infos_by_class = {
    #     0: [(0, 0)],
    #     1: [(1, 0), (3, 0), (3, 1), (3, 2)],
    #     2: [(2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6)]
    # }
    assert query_aware_dataset.sample_support_shots(0, 0, True) == \
        [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]
    support = query_aware_dataset.sample_support_shots(0, 1, False)
    assert len(set(support)) == 4
    support = query_aware_dataset.sample_support_shots(1, 1, False)
    assert len(set(support)) == 3
    support = query_aware_dataset.sample_support_shots(3, 1, False)
    assert len(set(support)) == 1
    support = query_aware_dataset.sample_support_shots(3, 2)
    assert len(set(support)) == 5
    support = query_aware_dataset.sample_support_shots(3, 0)
    assert len(set(support)) == 1

    dataconfig = {
        'type': 'QueryAwareDataset',
        'support_way': 3,
        'support_shot': 2,
        'dataset': {
            'type':
            'FewShotVOCDataset',
            'ann_file': [
                'tests/data/few_shot_voc_split/1.txt',
                'tests/data/few_shot_voc_split/2.txt',
                'tests/data/few_shot_voc_split/3.txt',
                'tests/data/few_shot_voc_split/4.txt',
                'tests/data/few_shot_voc_split/5.txt'
            ],
            'img_prefix': [
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
                'tests/data/VOCdevkit/',
            ],
            'ann_shot_filter': [{
                'person': 1
            }, {
                'dog': 1
            }, {
                'chair': 2
            }, {
                'car': 2
            }, {
                'aeroplane': 2
            }],
            'pipeline': {
                'query': [{
                    'type': 'LoadImageFromFile'
                }],
                'support': [{
                    'type': 'LoadImageFromFile'
                }]
            },
            'classes': ('person', 'dog', 'chair', 'car', 'aeroplane'),
            'merge_dataset':
            True
        }
    }

    query_aware_dataset = build_dataset(cfg=dataconfig)

    assert np.sum(query_aware_dataset.flag) == 0
    # print(query_aware_dataset.data_infos_by_class)
    # self.data_infos_by_class = {
    #     0: [(0, 0)],
    #     1: [(1, 0)],
    #     2: [(2, 0), (2, 1)],
    #     3: [(3, 0), (3, 1)],
    #     4: [(4, 0), (5, 0)]}
    assert query_aware_dataset.sample_support_shots(0, 0, True) == \
        [(0, 0), (0, 0)]
    support = query_aware_dataset.sample_support_shots(0, 1, False)
    assert len(set(support)) == 1
    support = query_aware_dataset.sample_support_shots(3, 0)
    assert len(set(support)) == 1
    assert len(support) == 2
    support = query_aware_dataset.sample_support_shots(3, 2)
    assert len(set(support)) == 2

    batch = query_aware_dataset[0]
    assert len(batch['support_data']) == 6
    assert batch['query_data']['ann_info']['labels'][0] == \
        batch['support_data'][0]['ann_info']['labels'][0]
    assert batch['query_data']['ann_info']['labels'][0] == \
        batch['support_data'][1]['ann_info']['labels'][0]
    assert batch['support_data'][2]['ann_info']['labels'][0] == \
        batch['support_data'][3]['ann_info']['labels'][0]
    assert batch['support_data'][4]['ann_info']['labels'][0] == \
        batch['support_data'][5]['ann_info']['labels'][0]
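The pairing asserted at the end follows from the QueryAwareDataset batch layout: with support_way=3 and support_shot=2, support_data holds way * shot = 6 items grouped shot-by-shot per class, and the first group shares the query's class. A standalone sketch of that layout check (the labels are made up, not taken from the fixture data):

    support_way, support_shot = 3, 2
    support_labels = [0, 0, 2, 2, 4, 4]  # one group of `shot` items per way
    groups = [
        support_labels[i * support_shot:(i + 1) * support_shot]
        for i in range(support_way)
    ]
    assert all(len(set(g)) == 1 for g in groups)  # one class per group
    assert groups[0][0] == 0  # first group matches the query class (here 0)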