commit 4f2c81c539
parent e35f571c12

    update code
@@ -67,10 +67,5 @@ data = dict(
         img_prefix=data_root,
         pipeline=test_pipeline,
         test_mode=True,
-        classes='BASE_CLASSES'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        classes='BASE_CLASSES'))
 evaluation = dict(interval=5000, metric='bbox', classwise=True)
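Note: every config hunk in this commit follows the same pattern, so one sketch covers them all: the per-config train_dataloader/val_dataloader/test_dataloader entries are dropped from `data`, and the loader defaults are supplied by the train/test entry points instead (see the compat_cfg and *_default_args hunks further down). A rough, hypothetical example of the trimmed config shape, with placeholder dataset fields that are not taken from the repository's configs:

# Hypothetical trimmed config after this change (illustration only).
# Top-level loader keys are relocated into data.train_dataloader / val_dataloader /
# test_dataloader by compat_cfg at runtime.
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(type='FewShotVOCDataset', classes='BASE_CLASSES'),
    val=dict(type='FewShotVOCDataset', test_mode=True, classes='BASE_CLASSES'),
    test=dict(type='FewShotVOCDataset', test_mode=True, classes='BASE_CLASSES'))
evaluation = dict(interval=5000, metric='bbox', classwise=True)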
@@ -75,10 +75,5 @@ data = dict(
         pipeline=test_pipeline,
         test_mode=True,
         classes=None,
-    ),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+    ))
 evaluation = dict(interval=3000, metric='mAP')
@@ -70,12 +70,7 @@ data = dict(
         img_prefix=data_root,
         pipeline=test_pipeline,
         test_mode=True,
-        classes='ALL_CLASSES'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        classes='ALL_CLASSES'))
 evaluation = dict(
     interval=4000,
     metric='bbox',
@@ -77,10 +77,5 @@ data = dict(
         img_prefix=data_root,
         pipeline=test_pipeline,
         test_mode=True,
-        classes=None),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        classes=None))
 evaluation = dict(interval=2000, metric='mAP', class_splits=None)
@@ -89,10 +89,5 @@ data = dict(
         pipeline=train_multi_pipelines['support'],
         instance_wise=True,
         classes='BASE_CLASSES',
-        dataset_name='model_init_dataset'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init_dataset'))
 evaluation = dict(interval=20000, metric='bbox', classwise=True)
@@ -112,10 +112,5 @@ data = dict(
         use_difficult=False,
         instance_wise=True,
         classes=None,
-        dataset_name='model_init_dataset'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init_dataset'))
 evaluation = dict(interval=5000, metric='mAP')
@@ -90,12 +90,7 @@ data = dict(
         pipeline=train_multi_pipelines['support'],
         instance_wise=True,
         classes='ALL_CLASSES',
-        dataset_name='model_init_dataset'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init_dataset'))
 evaluation = dict(
     interval=3000,
     metric='bbox',
@@ -97,10 +97,5 @@ data = dict(
         instance_wise=True,
         num_novel_shots=None,
         classes=None,
-        dataset_name='model_init_dataset'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init_dataset'))
 evaluation = dict(interval=3000, metric='mAP', class_splits=None)
@@ -103,10 +103,5 @@ data = dict(
         num_base_shots=10,
         instance_wise=True,
         min_bbox_area=32 * 32,
-        dataset_name='model_init_dataset'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init_dataset'))
 evaluation = dict(interval=20000, metric='bbox', classwise=True)
@@ -114,10 +114,5 @@ data = dict(
         instance_wise=True,
         classes=None,
         min_bbox_area=32 * 32,
-        dataset_name='model_init'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init'))
 evaluation = dict(interval=20000, metric='mAP')
@@ -101,11 +101,6 @@ data = dict(
         num_base_shots=None,
         instance_wise=True,
         min_bbox_area=32 * 32,
-        dataset_name='model_init_dataset'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init_dataset'))
 evaluation = dict(
     interval=3000, metric='bbox', classwise=True, class_splits=None)
@@ -104,10 +104,5 @@ data = dict(
         num_novel_shots=None,
         classes=None,
         min_bbox_area=32 * 32,
-        dataset_name='model_init_dataset'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        dataset_name='model_init_dataset'))
 evaluation = dict(interval=3000, metric='mAP', class_splits=None)
@@ -95,10 +95,5 @@ data = dict(
         img_prefix=data_root,
         pipeline=test_pipeline,
         test_mode=True,
-        classes='BASE_CLASSES'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        classes='BASE_CLASSES'))
 evaluation = dict(interval=20000, metric='bbox', classwise=True)
@@ -102,10 +102,5 @@ data = dict(
         pipeline=test_pipeline,
         coordinate_offset=[-1, -1, -1, -1],
         test_mode=True,
-        classes=None),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        classes=None))
 evaluation = dict(interval=5000, metric='mAP')
@@ -95,12 +95,7 @@ data = dict(
         img_prefix=data_root,
         pipeline=test_pipeline,
         test_mode=True,
-        classes='ALL_CLASSES'),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        classes='ALL_CLASSES'))
 evaluation = dict(
     interval=3000,
     metric='bbox',
@@ -102,10 +102,5 @@ data = dict(
         pipeline=test_pipeline,
         coordinate_offset=[-1, -1, -1, -1],
         test_mode=True,
-        classes=None),
-    train_dataloader=dict(persistent_workers=False),
-    val_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2),
-    test_dataloader=dict(
-        persistent_workers=False, samples_per_gpu=1, workers_per_gpu=2))
+        classes=None))
 evaluation = dict(interval=3000, metric='mAP', class_splits=None)
@@ -14,7 +14,7 @@ from mmfewshot.classification.core.evaluation import (DistMetaTestEvalHook,
                                                        MetaTestEvalHook)
 from mmfewshot.classification.datasets.builder import (
     build_dataloader, build_dataset, build_meta_test_dataloader)
-from mmfewshot.utils import get_root_logger
+from mmfewshot.utils import compat_cfg, get_root_logger


 def train_model(model: Union[MMDataParallel, MMDistributedDataParallel],
@@ -25,12 +25,13 @@ def train_model(model: Union[MMDataParallel, MMDistributedDataParallel],
                 timestamp: str = None,
                 device: str = None,
                 meta: Dict = None) -> None:
+    cfg = compat_cfg(cfg)
     logger = get_root_logger(log_level=cfg.log_level)

     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

-    loader_cfg = dict(
+    train_dataloader_default_args = dict(
         # cfg.gpus will be ignored if distributed
         num_gpus=len(cfg.gpu_ids),
         dist=distributed,
@@ -38,16 +39,11 @@ def train_model(model: Union[MMDataParallel, MMDistributedDataParallel],
         seed=cfg.get('seed'),
         pin_memory=cfg.get('pin_memory', False),
         use_infinite_sampler=cfg.use_infinite_sampler)
-    # The overall dataloader settings
-    loader_cfg.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader'
-        ]
-    })
-    # The specific dataloader settings
-    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
+    train_loader_cfg = {
+        **train_dataloader_default_args,
+        **cfg.data.get('train_dataloader', {})
+    }

     data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

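For reference, the merge order in train_loader_cfg matters: keys supplied in cfg.data.train_dataloader override the defaults because the later `**` unpacking wins on duplicate keys. A standalone sketch of just that mechanism (hypothetical values, not mmfewshot code):

# Later **-unpacking wins on duplicate keys, so per-config settings
# override the hard-coded defaults.
train_dataloader_default_args = dict(samples_per_gpu=2, persistent_workers=False)
cfg_train_dataloader = dict(samples_per_gpu=4)

train_loader_cfg = {
    **train_dataloader_default_args,
    **cfg_train_dataloader,
}
assert train_loader_cfg == {'samples_per_gpu': 4, 'persistent_workers': False}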
@@ -16,7 +16,7 @@ from mmfewshot.detection.core import (QuerySupportDistEvalHook,
                                       QuerySupportEvalHook)
 from mmfewshot.detection.datasets import (build_dataloader, build_dataset,
                                           get_copy_dataset_type)
-from mmfewshot.utils import get_root_logger
+from mmfewshot.utils import compat_cfg, get_root_logger


 def train_detector(model: nn.Module,
@@ -26,6 +26,7 @@ def train_detector(model: nn.Module,
                    validate: bool = False,
                    timestamp: Optional[str] = None,
                    meta: Optional[Dict] = None) -> None:
+    cfg = compat_cfg(cfg)
     logger = get_root_logger(log_level=cfg.log_level)

     # prepare data loaders
@@ -40,13 +41,6 @@ def train_detector(model: nn.Module,
         data_cfg=copy.deepcopy(cfg.data),
         use_infinite_sampler=cfg.use_infinite_sampler,
         persistent_workers=False)
-    train_dataloader_default_args.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader', 'model_init'
-        ]
-    })
     train_loader_cfg = {
         **train_dataloader_default_args,
         **cfg.data.get('train_dataloader', {})
@@ -115,22 +109,6 @@ def train_detector(model: nn.Module,
         dist=distributed,
         shuffle=False,
         persistent_workers=False)
-
-    # update overall dataloader(for train, val and test) setting
-    val_dataloader_default_args.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader', 'samples_per_gpu', 'model_init'
-        ]
-    })
-    if 'samples_per_gpu' in cfg.data.val:
-        logger.warning('`samples_per_gpu` in `val` field of '
-                       'data will be deprecated, you should'
-                       ' move it to `val_dataloader` field')
-        # keep default value of `sample_per_gpu` is 1
-        val_dataloader_default_args['samples_per_gpu'] = \
-            cfg.data.val.pop('samples_per_gpu')
     val_dataloader_args = {
         **val_dataloader_default_args,
         **cfg.data.get('val_dataloader', {})
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collate import multi_pipeline_collate_fn
+from .compat_config import compat_cfg
 from .dist_utils import check_dist_init, sync_random_seed
 from .infinite_sampler import (DistributedInfiniteGroupSampler,
                                DistributedInfiniteSampler,
@@ -12,5 +13,5 @@ __all__ = [
     'multi_pipeline_collate_fn', 'local_numpy_seed',
     'InfiniteEpochBasedRunner', 'InfiniteSampler', 'InfiniteGroupSampler',
     'DistributedInfiniteSampler', 'DistributedInfiniteGroupSampler',
-    'get_root_logger', 'check_dist_init', 'sync_random_seed'
+    'get_root_logger', 'check_dist_init', 'sync_random_seed', 'compat_cfg'
 ]
@@ -0,0 +1,139 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+from mmcv import ConfigDict
+
+
+def compat_cfg(cfg):
+    """This function would modify some filed to keep the compatibility of
+    config.
+
+    For example, it will move some args which will be deprecated to the correct
+    fields.
+    """
+    cfg = copy.deepcopy(cfg)
+    cfg = compat_imgs_per_gpu(cfg)
+    cfg = compat_loader_args(cfg)
+    cfg = compat_runner_args(cfg)
+    return cfg
+
+
+def compat_runner_args(cfg):
+    if 'runner' not in cfg:
+        cfg.runner = ConfigDict({
+            'type': 'EpochBasedRunner',
+            'max_epochs': cfg.total_epochs
+        })
+        warnings.warn(
+            'config is now expected to have a `runner` section, '
+            'please set `runner` in your config.', UserWarning)
+    else:
+        if 'total_epochs' in cfg:
+            assert cfg.total_epochs == cfg.runner.max_epochs
+    return cfg
+
+
+def compat_imgs_per_gpu(cfg):
+    cfg = copy.deepcopy(cfg)
+    if 'imgs_per_gpu' in cfg.data:
+        warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. '
+                      'Please use "samples_per_gpu" instead')
+        if 'samples_per_gpu' in cfg.data:
+            warnings.warn(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiments')
+        else:
+            warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                          f'{cfg.data.imgs_per_gpu} in this experiments')
+        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+    return cfg
+
+
+def compat_loader_args(cfg):
+    """Deprecated sample_per_gpu in cfg.data."""
+
+    cfg = copy.deepcopy(cfg)
+    if 'train_dataloader' not in cfg.data:
+        cfg.data['train_dataloader'] = ConfigDict()
+    if 'val_dataloader' not in cfg.data:
+        cfg.data['val_dataloader'] = ConfigDict()
+    if 'test_dataloader' not in cfg.data:
+        cfg.data['test_dataloader'] = ConfigDict()
+
+    # special process for train_dataloader
+    if 'samples_per_gpu' in cfg.data:
+
+        samples_per_gpu = cfg.data.pop('samples_per_gpu')
+        assert 'samples_per_gpu' not in \
+            cfg.data.train_dataloader, ('`samples_per_gpu` are set '
+                                        'in `data` field and ` '
+                                        'data.train_dataloader` '
+                                        'at the same time. '
+                                        'Please only set it in '
+                                        '`data.train_dataloader`. ')
+        cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu
+
+    if 'persistent_workers' in cfg.data:
+
+        persistent_workers = cfg.data.pop('persistent_workers')
+        assert 'persistent_workers' not in \
+            cfg.data.train_dataloader, ('`persistent_workers` are set '
+                                        'in `data` field and ` '
+                                        'data.train_dataloader` '
+                                        'at the same time. '
+                                        'Please only set it in '
+                                        '`data.train_dataloader`. ')
+        cfg.data.train_dataloader['persistent_workers'] = persistent_workers
+
+    if 'workers_per_gpu' in cfg.data:
+
+        workers_per_gpu = cfg.data.pop('workers_per_gpu')
+        cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu
+        cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu
+        cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu
+
+    # special process for val_dataloader
+    if 'samples_per_gpu' in cfg.data.val:
+        # keep default value of `sample_per_gpu` is 1
+        assert 'samples_per_gpu' not in \
+            cfg.data.val_dataloader, ('`samples_per_gpu` are set '
+                                      'in `data.val` field and ` '
+                                      'data.val_dataloader` at '
+                                      'the same time. '
+                                      'Please only set it in '
+                                      '`data.val_dataloader`. ')
+        cfg.data.val_dataloader['samples_per_gpu'] = \
+            cfg.data.val.pop('samples_per_gpu')
+    # special process for val_dataloader
+
+    # in case the test dataset is concatenated
+    if isinstance(cfg.data.test, dict):
+        if 'samples_per_gpu' in cfg.data.test:
+            assert 'samples_per_gpu' not in \
+                cfg.data.test_dataloader, ('`samples_per_gpu` are set '
+                                           'in `data.test` field and ` '
+                                           'data.test_dataloader` '
+                                           'at the same time. '
+                                           'Please only set it in '
+                                           '`data.test_dataloader`. ')
+
+            cfg.data.test_dataloader['samples_per_gpu'] = \
+                cfg.data.test.pop('samples_per_gpu')
+
+    elif isinstance(cfg.data.test, list):
+        for ds_cfg in cfg.data.test:
+            if 'samples_per_gpu' in ds_cfg:
+                assert 'samples_per_gpu' not in \
+                    cfg.data.test_dataloader, ('`samples_per_gpu` are set '
+                                               'in `data.test` field and ` '
+                                               'data.test_dataloader` at'
+                                               ' the same time. '
+                                               'Please only set it in '
+                                               '`data.test_dataloader`. ')
+        samples_per_gpu = max(
+            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
+        cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu
+
+    return cfg
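A minimal usage sketch of compat_cfg, assuming this commit is applied (the utils __init__ hunk above re-exports it); the values below are made up for illustration and roughly match what the unit tests in the next hunk assert:

# Hypothetical usage sketch, not part of the commit.
from mmcv import ConfigDict

from mmfewshot.utils import compat_cfg

old_style = ConfigDict(
    dict(
        runner=dict(type='IterBasedRunner', max_iters=1000),
        data=dict(
            samples_per_gpu=2,   # deprecated top-level loader keys ...
            workers_per_gpu=2,
            train=dict(),
            val=dict(),
            test=dict())))
new_style = compat_cfg(old_style)
# ... are relocated into the per-loader fields:
print(new_style.data.train_dataloader)  # roughly {'samples_per_gpu': 2, 'workers_per_gpu': 2}
print(new_style.data.val_dataloader)    # roughly {'workers_per_gpu': 2}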
@@ -0,0 +1,116 @@
+import pytest
+from mmcv import ConfigDict
+
+from mmfewshot.utils.compat_config import (compat_imgs_per_gpu,
+                                           compat_loader_args,
+                                           compat_runner_args)
+
+
+def test_compat_runner_args():
+    cfg = ConfigDict(dict(total_epochs=12))
+    with pytest.warns(None) as record:
+        cfg = compat_runner_args(cfg)
+    assert len(record) == 1
+    assert 'runner' in record.list[0].message.args[0]
+    assert 'runner' in cfg
+    assert cfg.runner.type == 'EpochBasedRunner'
+    assert cfg.runner.max_epochs == cfg.total_epochs
+
+
+def test_compat_loader_args():
+    cfg = ConfigDict(dict(data=dict(val=dict(), test=dict(), train=dict())))
+    cfg = compat_loader_args(cfg)
+    # auto fill loader args
+    assert 'val_dataloader' in cfg.data
+    assert 'train_dataloader' in cfg.data
+    assert 'test_dataloader' in cfg.data
+    cfg = ConfigDict(
+        dict(
+            data=dict(
+                samples_per_gpu=1,
+                persistent_workers=True,
+                workers_per_gpu=1,
+                val=dict(samples_per_gpu=3),
+                test=dict(samples_per_gpu=2),
+                train=dict())))
+
+    cfg = compat_loader_args(cfg)
+
+    assert cfg.data.train_dataloader.workers_per_gpu == 1
+    assert cfg.data.train_dataloader.samples_per_gpu == 1
+    assert cfg.data.train_dataloader.persistent_workers
+    assert cfg.data.val_dataloader.workers_per_gpu == 1
+    assert cfg.data.val_dataloader.samples_per_gpu == 3
+    assert cfg.data.test_dataloader.workers_per_gpu == 1
+    assert cfg.data.test_dataloader.samples_per_gpu == 2
+
+    # test test is a list
+    cfg = ConfigDict(
+        dict(
+            data=dict(
+                samples_per_gpu=1,
+                persistent_workers=True,
+                workers_per_gpu=1,
+                val=dict(samples_per_gpu=3),
+                test=[dict(samples_per_gpu=2),
+                      dict(samples_per_gpu=3)],
+                train=dict())))
+
+    cfg = compat_loader_args(cfg)
+    assert cfg.data.test_dataloader.samples_per_gpu == 3
+
+    # assert can not set args at the same time
+    cfg = ConfigDict(
+        dict(
+            data=dict(
+                samples_per_gpu=1,
+                persistent_workers=True,
+                workers_per_gpu=1,
+                val=dict(samples_per_gpu=3),
+                test=dict(samples_per_gpu=2),
+                train=dict(),
+                train_dataloader=dict(samples_per_gpu=2))))
+    # samples_per_gpu can not be set in `train_dataloader`
+    # and data field at the same time
+    with pytest.raises(AssertionError):
+        compat_loader_args(cfg)
+    cfg = ConfigDict(
+        dict(
+            data=dict(
+                samples_per_gpu=1,
+                persistent_workers=True,
+                workers_per_gpu=1,
+                val=dict(samples_per_gpu=3),
+                test=dict(samples_per_gpu=2),
+                train=dict(),
+                val_dataloader=dict(samples_per_gpu=2))))
+    # samples_per_gpu can not be set in `val_dataloader`
+    # and data field at the same time
+    with pytest.raises(AssertionError):
+        compat_loader_args(cfg)
+    cfg = ConfigDict(
+        dict(
+            data=dict(
+                samples_per_gpu=1,
+                persistent_workers=True,
+                workers_per_gpu=1,
+                val=dict(samples_per_gpu=3),
+                test=dict(samples_per_gpu=2),
+                test_dataloader=dict(samples_per_gpu=2))))
+    # samples_per_gpu can not be set in `test_dataloader`
+    # and data field at the same time
+    with pytest.raises(AssertionError):
+        compat_loader_args(cfg)
+
+
+def test_compat_imgs_per_gpu():
+    cfg = ConfigDict(
+        dict(
+            data=dict(
+                imgs_per_gpu=1,
+                samples_per_gpu=2,
+                val=dict(),
+                test=dict(),
+                train=dict())))
+    cfg = compat_imgs_per_gpu(cfg)
+    assert cfg.data.samples_per_gpu == cfg.data.imgs_per_gpu
@@ -13,6 +13,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
 from mmfewshot.detection.datasets import (build_dataloader, build_dataset,
                                           get_copy_dataset_type)
 from mmfewshot.detection.models import build_detector
+from mmfewshot.utils import compat_cfg


 def parse_args():
@@ -112,6 +113,7 @@ def main():
         raise ValueError('The output file must be a pkl file.')

     cfg = Config.fromfile(args.config)
+    cfg = compat_cfg(cfg)

     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
@@ -153,10 +155,15 @@ def main():
             'test_dataloader', 'samples_per_gpu', 'model_init'
         ]
     })
+    test_loader_cfg = {
+        **test_dataloader_default_args,
+        **cfg.data.get('test_dataloader', {})
+    }
+
     # currently only support single images testing
-    assert test_dataloader_default_args['samples_per_gpu'] == 1, \
+    assert test_loader_cfg['samples_per_gpu'] == 1, \
         'currently only support single images testing'
-    data_loader = build_dataloader(dataset, **test_dataloader_default_args)
+    data_loader = build_dataloader(dataset, **test_loader_cfg)

     # pop frozen_parameters
     cfg.model.pop('frozen_parameters', None)