[Enhancement] More customizable fields in dataloaders (#933)

* [Enhancement] More customizable fields in val and test dataloaders

* update default_loader_cfg
pull/1018/head
Tong Gao 2022-04-18 09:24:07 +08:00 committed by GitHub
parent 20fc909fc4
commit b4a9a87eee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 25 deletions

View File

@ -30,30 +30,30 @@ def train_detector(model,
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
# step 1: give default values and override (if exist) from cfg.data
loader_cfg = {
default_loader_cfg = {
**dict(
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.get('seed'),
drop_last=False,
dist=distributed,
num_gpus=len(cfg.gpu_ids)),
persistent_workers=False),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
)),
**dict((k, cfg.data[k]) for k in [
'samples_per_gpu',
'workers_per_gpu',
'shuffle',
'seed',
'drop_last',
'prefetch_num',
'pin_memory',
'persistent_workers',
] if k in cfg.data)
}
# update overall dataloader(for train, val and test) setting
default_loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
# step 2: cfg.data.train_dataloader has highest priority
train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
train_loader_cfg = dict(default_loader_cfg,
**cfg.data.get('train_dataloader', {}))
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
@ -135,7 +135,7 @@ def train_detector(model,
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_loader_cfg = {
**loader_cfg,
**default_loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('val_dataloader', {}),
**dict(samples_per_gpu=val_samples_per_gpu)

View File

@ -164,22 +164,22 @@ def main():
# build the dataloader
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
# step 1: give default values and override (if exist) from cfg.data
loader_cfg = {
default_loader_cfg = {
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
)),
**dict((k, cfg.data[k]) for k in [
'workers_per_gpu',
'seed',
'prefetch_num',
'pin_memory',
'persistent_workers',
] if k in cfg.data)
))
}
default_loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
test_loader_cfg = {
**loader_cfg,
**default_loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('test_dataloader', {}),
**dict(samples_per_gpu=samples_per_gpu)