[Imporve] Update tools to enable `pin_memory` and `persistent_workers` by default. (#1024)

pull/1046/head
Ma Zerun 2022-09-14 11:57:32 +08:00 committed by GitHub
parent 96a1a34415
commit 9999da646c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 112 additions and 17 deletions

View File

@ -4,11 +4,13 @@ import copy
import os
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.dist import sync_random_seed
from mmengine.fileio import dump, load
from mmengine.hooks import Hook
from mmengine.runner import Runner, find_latest_checkpoint
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmcls.utils import register_all_modules
@ -58,6 +60,15 @@ def parse_args():
action='store_true',
help='whether to auto scale the learning rate according to the '
'actual batch size and the original batch size.')
parser.add_argument(
'--no-pin-memory',
action='store_true',
help='whether to disable the pin_memory option in dataloaders.')
parser.add_argument(
'--no-persistent-workers',
action='store_true',
help='whether to disable the persistent_workers option in dataloaders.'
)
parser.add_argument(
'--cfg-options',
nargs='+',
@ -112,6 +123,30 @@ def merge_args(cfg, args):
if args.auto_scale_lr:
cfg.auto_scale_lr.enable = True
# set dataloader args
default_dataloader_cfg = ConfigDict(
pin_memory=True,
persistent_workers=True,
collate_fn=dict(type='default_collate'),
)
if digit_version(TORCH_VERSION) < digit_version('1.8.0'):
default_dataloader_cfg.persistent_workers = False
def set_default_dataloader_cfg(cfg, field):
if cfg.get(field, None) is None:
return
dataloader_cfg = copy.deepcopy(default_dataloader_cfg)
dataloader_cfg.update(cfg[field])
cfg[field] = dataloader_cfg
if args.no_pin_memory:
cfg[field]['pin_memory'] = False
if args.no_persistent_workers:
cfg[field]['persistent_workers'] = False
set_default_dataloader_cfg(cfg, 'train_dataloader')
set_default_dataloader_cfg(cfg, 'val_dataloader')
set_default_dataloader_cfg(cfg, 'test_dataloader')
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

View File

@ -2,9 +2,10 @@
import argparse
import os
import os.path as osp
from copy import deepcopy
import mmengine
from mmengine.config import Config, DictAction
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.hooks import Hook
from mmengine.runner import Runner
@ -51,6 +52,10 @@ def parse_args():
type=float,
default=2,
help='display time of every window. (second)')
parser.add_argument(
'--no-pin-memory',
action='store_true',
help='whether to disable the pin_memory option in dataloaders.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
@ -65,6 +70,19 @@ def parse_args():
def merge_args(cfg, args):
"""Merge CLI arguments to config."""
cfg.launcher = args.launcher
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
cfg.load_from = args.checkpoint
# -------------------- visualization --------------------
if args.show or (args.show_dir is not None):
assert 'visualization' in cfg.default_hooks, \
@ -87,6 +105,26 @@ def merge_args(cfg, args):
else:
cfg.test_evaluator = [cfg.test_evaluator, dump_metric]
# set dataloader args
default_dataloader_cfg = ConfigDict(
pin_memory=True,
collate_fn=dict(type='default_collate'),
)
def set_default_dataloader_cfg(cfg, field):
if cfg.get(field, None) is None:
return
dataloader_cfg = deepcopy(default_dataloader_cfg)
dataloader_cfg.update(cfg[field])
cfg[field] = dataloader_cfg
if args.no_pin_memory:
cfg[field]['pin_memory'] = False
set_default_dataloader_cfg(cfg, 'test_dataloader')
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
return cfg
@ -100,20 +138,6 @@ def main():
# load config
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
cfg.load_from = args.checkpoint
# build the runner from config
runner = Runner.from_cfg(cfg)

View File

@ -2,9 +2,12 @@
import argparse
import os
import os.path as osp
from copy import deepcopy
from mmengine.config import Config, DictAction
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.runner import Runner
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmcls.utils import register_all_modules
@ -34,6 +37,15 @@ def parse_args():
action='store_true',
help='whether to auto scale the learning rate according to the '
'actual batch size and the original batch size.')
parser.add_argument(
'--no-pin-memory',
action='store_true',
help='whether to disable the pin_memory option in dataloaders.')
parser.add_argument(
'--no-persistent-workers',
action='store_true',
help='whether to disable the persistent_workers option in dataloaders.'
)
parser.add_argument(
'--cfg-options',
nargs='+',
@ -96,6 +108,30 @@ def merge_args(cfg, args):
if args.auto_scale_lr:
cfg.auto_scale_lr.enable = True
# set dataloader args
default_dataloader_cfg = ConfigDict(
pin_memory=True,
persistent_workers=True,
collate_fn=dict(type='default_collate'),
)
if digit_version(TORCH_VERSION) < digit_version('1.8.0'):
default_dataloader_cfg.persistent_workers = False
def set_default_dataloader_cfg(cfg, field):
if cfg.get(field, None) is None:
return
dataloader_cfg = deepcopy(default_dataloader_cfg)
dataloader_cfg.update(cfg[field])
cfg[field] = dataloader_cfg
if args.no_pin_memory:
cfg[field]['pin_memory'] = False
if args.no_persistent_workers:
cfg[field]['persistent_workers'] = False
set_default_dataloader_cfg(cfg, 'train_dataloader')
set_default_dataloader_cfg(cfg, 'val_dataloader')
set_default_dataloader_cfg(cfg, 'test_dataloader')
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)