diff --git a/mmcls/datasets/__init__.py b/mmcls/datasets/__init__.py
index dfe28bf90..2014692d3 100644
--- a/mmcls/datasets/__init__.py
+++ b/mmcls/datasets/__init__.py
@@ -10,6 +10,7 @@ from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
 from .imagenet import ImageNet, ImageNet21k
 from .mnist import MNIST, FashionMNIST
 from .multi_label import MultiLabelDataset
+from .pipelines import *  # noqa: F401,F403
 from .samplers import DistributedSampler, RepeatAugSampler
 from .voc import VOC
 
diff --git a/mmcls/utils/__init__.py b/mmcls/utils/__init__.py
index 4afaf6ff4..f2533fea4 100644
--- a/mmcls/utils/__init__.py
+++ b/mmcls/utils/__init__.py
@@ -1,8 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
 from .logger import get_root_logger, load_json_log
-from .setup_env import setup_multi_processes
+from .setup_env import register_all_modules, setup_multi_processes
 
 __all__ = [
-    'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes'
+    'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
+    'register_all_modules'
 ]
diff --git a/mmcls/utils/setup_env.py b/mmcls/utils/setup_env.py
index 21def2f08..4cbd44401 100644
--- a/mmcls/utils/setup_env.py
+++ b/mmcls/utils/setup_env.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import datetime
 import os
 import platform
 import warnings
 
 import cv2
 import torch.multiprocessing as mp
+from mmengine import DefaultScope
 
 
 def setup_multi_processes(cfg):
@@ -45,3 +47,35 @@ def setup_multi_processes(cfg):
             f'overloaded, please further tune the variable for optimal '
             f'performance in your application as needed.')
         os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+
+
+def register_all_modules(init_default_scope: bool = True) -> None:
+    """Register all modules in mmcls into the registries.
+
+    Args:
+        init_default_scope (bool): Whether to initialize the mmcls default
+            scope. If True, the global default scope will be set to `mmcls`,
+            and all registries will build modules from mmcls's registry node.
+            To understand more about the registry, please refer to
+            https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
+            Defaults to True.
+    """  # noqa
+    import mmcls.core  # noqa: F401,F403
+    import mmcls.datasets  # noqa: F401,F403
+    import mmcls.metrics  # noqa: F401,F403
+    import mmcls.models  # noqa: F401,F403
+
+    if not init_default_scope:
+        return
+
+    current_scope = DefaultScope.get_current_instance()
+    if current_scope is None:
+        DefaultScope.get_instance('mmcls', scope_name='mmcls')
+    elif current_scope.scope_name != 'mmcls':
+        warnings.warn(f'The current default scope "{current_scope.scope_name}"'
+                      ' is not "mmcls", `register_all_modules` will force the '
+                      'current default scope to be "mmcls". If this is not '
+                      'expected, please set `init_default_scope=False`.')
+        # avoid name conflict
+        new_instance_name = f'mmcls-{datetime.datetime.now()}'
+        DefaultScope.get_instance(new_instance_name, scope_name='mmcls')
diff --git a/tests/test_utils/test_setup_env.py b/tests/test_utils/test_setup_env.py
index 2679dbbf5..050069bc4 100644
--- a/tests/test_utils/test_setup_env.py
+++ b/tests/test_utils/test_setup_env.py
@@ -1,12 +1,47 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import datetime
 import multiprocessing as mp
 import os
 import platform
+import sys
+from unittest import TestCase
 
 import cv2
 from mmcv import Config
+from mmengine import DefaultScope
 
-from mmcls.utils import setup_multi_processes
+from mmcls.utils import register_all_modules, setup_multi_processes
+
+
+class TestSetupEnv(TestCase):
+
+    def test_register_all_modules(self):
+        from mmcls.registry import DATASETS
+
+        # not init default scope
+        sys.modules.pop('mmcls.datasets', None)
+        sys.modules.pop('mmcls.datasets.custom', None)
+        DATASETS._module_dict.pop('CustomDataset', None)
+        self.assertFalse('CustomDataset' in DATASETS.module_dict)
+        register_all_modules(init_default_scope=False)
+        self.assertTrue('CustomDataset' in DATASETS.module_dict)
+
+        # init default scope
+        sys.modules.pop('mmcls.datasets')
+        sys.modules.pop('mmcls.datasets.custom')
+        DATASETS._module_dict.pop('CustomDataset', None)
+        self.assertFalse('CustomDataset' in DATASETS.module_dict)
+        register_all_modules(init_default_scope=True)
+        self.assertTrue('CustomDataset' in DATASETS.module_dict)
+        self.assertEqual(DefaultScope.get_current_instance().scope_name,
+                         'mmcls')
+
+        # init default scope when another scope is init
+        name = f'test-{datetime.datetime.now()}'
+        DefaultScope.get_instance(name, scope_name='test')
+        with self.assertWarnsRegex(
+                Warning, 'The current default scope "test" is not "mmcls"'):
+            register_all_modules(init_default_scope=True)
 
 
 def test_setup_multi_processes():
diff --git a/tools/test.py b/tools/test.py
index ddac9d72d..5b7d484c7 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -1,55 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import os
-import warnings
-from numbers import Number
+import os.path as osp
 
-import mmcv
-import numpy as np
-import torch
-from mmcv import DictAction
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
-                         wrap_fp16_model)
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
 
-from mmcls.apis import multi_gpu_test, single_gpu_test
-from mmcls.datasets import build_dataloader, build_dataset
-from mmcls.models import build_classifier
-from mmcls.utils import get_root_logger, setup_multi_processes
+from mmcls.utils import register_all_modules
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='mmcls test model')
+    parser = argparse.ArgumentParser(
+        description='MMCLS test (and eval) a model')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument('--out', help='output result file')
-    out_options = ['class_scores', 'pred_score', 'pred_label', 'pred_class']
     parser.add_argument(
-        '--out-items',
-        nargs='+',
-        default=['all'],
-        choices=out_options + ['none', 'all'],
-        help='Besides metrics, what items will be included in the output '
-        f'result file. You can choose some of ({", ".join(out_options)}), '
-        'or use "all" to include all above, or use "none" to disable all of '
-        'above. Defaults to output all.',
-        metavar='')
-    parser.add_argument(
-        '--metrics',
-        type=str,
-        nargs='+',
-        help='evaluation metrics, which depends on the dataset, e.g., '
-        '"accuracy", "precision", "recall", "f1_score", "support" for single '
-        'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
-        'multi-label dataset')
-    parser.add_argument('--show', action='store_true', help='show results')
-    parser.add_argument(
-        '--show-dir', help='directory where painted images will be saved')
-    parser.add_argument(
-        '--gpu-collect',
-        action='store_true',
-        help='whether to use gpu to collect results')
-    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
+        '--work-dir',
+        help='the directory to save the file containing evaluation metrics')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -60,193 +27,47 @@ def parse_args():
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
-    parser.add_argument(
-        '--metric-options',
-        nargs='+',
-        action=DictAction,
-        default={},
-        help='custom options for evaluation, the key-value pair in xxx=yyy '
-        'format will be parsed as a dict metric_options for dataset.evaluate()'
-        ' function.')
-    parser.add_argument(
-        '--show-options',
-        nargs='+',
-        action=DictAction,
-        help='custom options for show_result. key-value pair in xxx=yyy.'
-        'Check available options in `model.show_result`.')
-    parser.add_argument(
-        '--gpu-ids',
-        type=int,
-        nargs='+',
-        help='(Deprecated, please use --gpu-id) ids of gpus to use '
-        '(only applicable to non-distributed testing)')
-    parser.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='id of gpu to use '
-        '(only applicable to non-distributed testing)')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
-    parser.add_argument(
-        '--device',
-        choices=['cpu', 'cuda', 'ipu'],
-        default='cuda',
-        help='device used for testing')
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
-
-    assert args.metrics or args.out, \
-        'Please specify at least one of output path and evaluation metrics.'
-
     return args
 
 
 def main():
     args = parse_args()
 
-    cfg = mmcv.Config.fromfile(args.config)
+    # register all modules in mmcls into the registries
+    # do not init the default scope here because it will be init in the runner
+    register_all_modules(init_default_scope=False)
+
+    # load config
+    cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
    if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    # set multi-process settings
-    setup_multi_processes(cfg)
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
 
-    # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
-        torch.backends.cudnn.benchmark = True
-    cfg.model.pretrained = None
+    cfg.load_from = args.checkpoint
 
-    if args.gpu_ids is not None:
-        cfg.gpu_ids = args.gpu_ids[0:1]
-        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
-                      'Because we only support single GPU mode in '
-                      'non-distributed testing. Use the first GPU '
-                      'in `gpu_ids` now.')
-    else:
-        cfg.gpu_ids = [args.gpu_id]
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
 
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
-
-    dataset = build_dataset(cfg.data.test, default_args=dict(test_mode=True))
-
-    # build the dataloader
-    # The default loader config
-    loader_cfg = dict(
-        # cfg.gpus will be ignored if distributed
-        num_gpus=1 if args.device == 'ipu' else len(cfg.gpu_ids),
-        dist=distributed,
-        round_up=True,
-    )
-    # The overall dataloader settings
-    loader_cfg.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader'
-        ]
-    })
-    test_loader_cfg = {
-        **loader_cfg,
-        'shuffle': False,  # Not shuffle by default
-        'sampler_cfg': None,  # Not use sampler by default
-        **cfg.data.get('test_dataloader', {}),
-    }
-    # the extra round_up data will be removed during gpu/cpu collect
-    data_loader = build_dataloader(dataset, **test_loader_cfg)
-
-    # build the model and load checkpoint
-    model = build_classifier(cfg.model)
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        wrap_fp16_model(model)
-    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
-
-    if 'CLASSES' in checkpoint.get('meta', {}):
-        CLASSES = checkpoint['meta']['CLASSES']
-    else:
-        from mmcls.datasets import ImageNet
-        warnings.simplefilter('once')
-        warnings.warn('Class names are not saved in the checkpoint\'s '
-                      'meta data, use imagenet by default.')
-        CLASSES = ImageNet.CLASSES
-
-    if not distributed:
-        if args.device == 'cpu':
-            model = model.cpu()
-        elif args.device == 'ipu':
-            from mmcv.device.ipu import cfg2options, ipu_model_wrapper
-            opts = cfg2options(cfg.runner.get('options_cfg', {}))
-            if fp16_cfg is not None:
-                model.half()
-            model = ipu_model_wrapper(model, opts, fp16_cfg=fp16_cfg)
-            data_loader.init(opts['inference'])
-        else:
-            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
-            if not model.device_ids:
-                assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
-                    'To test with CPU, please confirm your mmcv version ' \
-                    'is not lower than v1.4.4'
-        model.CLASSES = CLASSES
-        show_kwargs = {} if args.show_options is None else args.show_options
-        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
-                                  **show_kwargs)
-    else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
-        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
-                                 args.gpu_collect)
-
-    rank, _ = get_dist_info()
-    if rank == 0:
-        results = {}
-        logger = get_root_logger()
-        if args.metrics:
-            eval_results = dataset.evaluate(
-                results=outputs,
-                metric=args.metrics,
-                metric_options=args.metric_options,
-                logger=logger)
-            results.update(eval_results)
-            for k, v in eval_results.items():
-                if isinstance(v, np.ndarray):
-                    v = [round(out, 2) for out in v.tolist()]
-                elif isinstance(v, Number):
-                    v = round(v, 2)
-                else:
-                    raise ValueError(f'Unsupport metric type: {type(v)}')
-                print(f'\n{k} : {v}')
-        if args.out:
-            if 'none' not in args.out_items:
-                scores = np.vstack(outputs)
-                pred_score = np.max(scores, axis=1)
-                pred_label = np.argmax(scores, axis=1)
-                pred_class = [CLASSES[lb] for lb in pred_label]
-                res_items = {
-                    'class_scores': scores,
-                    'pred_score': pred_score,
-                    'pred_label': pred_label,
-                    'pred_class': pred_class
-                }
-                if 'all' in args.out_items:
-                    results.update(res_items)
-                else:
-                    for key in args.out_items:
-                        results[key] = res_items[key]
-            print(f'\ndumping results to {args.out}')
-            mmcv.dump(results, args.out)
+    # start testing
+    runner.test()
 
 
 if __name__ == '__main__':
diff --git a/tools/train.py b/tools/train.py
index bd6a0c75c..1f82a5f08 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,26 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
-import copy
 import os
 import os.path as osp
-import time
-import warnings
 
-import mmcv
-import torch
-import torch.distributed as dist
-from mmcv import Config, DictAction
-from mmcv.runner import get_dist_info, init_dist
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
 
-from mmcls import __version__
-from mmcls.apis import init_random_seed, set_random_seed, train_model
-from mmcls.datasets import build_dataset
-from mmcls.models import build_classifier
-from mmcls.utils import collect_env, get_root_logger, setup_multi_processes
+from mmcls.utils import register_all_modules
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Train a model')
+    parser = argparse.ArgumentParser(description='Train a classifier')
     parser.add_argument('config', help='train config file path')
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
@@ -29,40 +19,6 @@ def parse_args():
         '--no-validate',
         action='store_true',
         help='whether not to evaluate the checkpoint during training')
-    group_gpus = parser.add_mutually_exclusive_group()
-    group_gpus.add_argument(
-        '--device', help='device used for training. (Deprecated)')
-    group_gpus.add_argument(
-        '--gpus',
-        type=int,
-        help='(Deprecated, please use --gpu-id) number of gpus to use '
-        '(only applicable to non-distributed training)')
-    group_gpus.add_argument(
-        '--gpu-ids',
-        type=int,
-        nargs='+',
-        help='(Deprecated, please use --gpu-id) ids of gpus to use '
-        '(only applicable to non-distributed training)')
-    group_gpus.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='id of gpu to use '
-        '(only applicable to non-distributed training)')
-    parser.add_argument(
-        '--ipu-replicas',
-        type=int,
-        default=None,
-        help='num of ipu replicas to use')
-    parser.add_argument('--seed', type=int, default=None, help='random seed')
-    parser.add_argument(
-        '--diff-seed',
-        action='store_true',
-        help='Whether or not set different seeds for different ranks')
-    parser.add_argument(
-        '--deterministic',
-        action='store_true',
-        help='whether to set deterministic options for CUDNN backend.')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -86,20 +42,34 @@ def parse_args():
     return args
 
 
+def merge_args(cfg, args):
+    """Merge CLI arguments to config."""
+    if args.resume_from is not None:
+        cfg.resume = True
+        cfg.load_from = args.resume_from
+
+    if args.no_validate:
+        cfg.val_cfg = None
+        cfg.val_dataloader = None
+        cfg.val_evaluator = None
+
+    return cfg
+
+
 def main():
     args = parse_args()
 
+    # register all modules in mmcls into the registries
+    # do not init the default scope here because it will be init in the runner
+    register_all_modules(init_default_scope=False)
+
+    # load config
     cfg = Config.fromfile(args.config)
+    cfg = merge_args(cfg, args)
+    cfg.launcher = args.launcher
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    # set multi-process settings
-    setup_multi_processes(cfg)
-
-    # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
-        torch.backends.cudnn.benchmark = True
-
     # work_dir is determined in this priority: CLI > segment in file > filename
     if args.work_dir is not None:
         # update configs according to CLI args if args.work_dir is not None
@@ -108,95 +78,12 @@ def main():
         # use config filename as default work_dir if cfg.work_dir is None
         cfg.work_dir = osp.join('./work_dirs',
                                 osp.splitext(osp.basename(args.config))[0])
-    if args.resume_from is not None:
-        cfg.resume_from = args.resume_from
-    if args.gpus is not None:
-        cfg.gpu_ids = range(1)
-        warnings.warn('`--gpus` is deprecated because we only support '
-                      'single GPU mode in non-distributed training. '
-                      'Use `gpus=1` now.')
-    if args.gpu_ids is not None:
-        cfg.gpu_ids = args.gpu_ids[0:1]
-        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
-                      'Because we only support single GPU mode in '
-                      'non-distributed training. Use the first GPU '
-                      'in `gpu_ids` now.')
-    if args.gpus is None and args.gpu_ids is None:
-        cfg.gpu_ids = [args.gpu_id]
-    if args.ipu_replicas is not None:
-        cfg.ipu_replicas = args.ipu_replicas
-        args.device = 'ipu'
 
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
 
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
-        _, world_size = get_dist_info()
-        cfg.gpu_ids = range(world_size)
-
-    # create work_dir
-    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
-    # dump config
-    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
-    # init the logger before other steps
-    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
-    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
-    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
-
-    # init the meta dict to record some important information such as
-    # environment info and seed, which will be logged
-    meta = dict()
-    # log env info
-    env_info_dict = collect_env()
-    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
-    dash_line = '-' * 60 + '\n'
-    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
-                dash_line)
-    meta['env_info'] = env_info
-
-    # log some basic info
-    logger.info(f'Distributed training: {distributed}')
-    logger.info(f'Config:\n{cfg.pretty_text}')
-
-    # set random seeds
-    seed = init_random_seed(args.seed)
-    seed = seed + dist.get_rank() if args.diff_seed else seed
-    logger.info(f'Set random seed to {seed}, '
-                f'deterministic: {args.deterministic}')
-    set_random_seed(seed, deterministic=args.deterministic)
-    cfg.seed = seed
-    meta['seed'] = seed
-
-    model = build_classifier(cfg.model)
-    model.init_weights()
-
-    datasets = [build_dataset(cfg.data.train)]
-    if len(cfg.workflow) == 2:
-        val_dataset = copy.deepcopy(cfg.data.val)
-        val_dataset.pipeline = cfg.data.train.pipeline
-        datasets.append(build_dataset(val_dataset))
-
-    # save mmcls version, config file content and class names in
-    # runner as meta data
-    meta.update(
-        dict(
-            mmcls_version=__version__,
-            config=cfg.pretty_text,
-            CLASSES=datasets[0].CLASSES))
-
-    # add an attribute for visualization convenience
-    train_model(
-        model,
-        datasets,
-        cfg,
-        distributed=distributed,
-        validate=(not args.no_validate),
-        timestamp=timestamp,
-        device=args.device,
-        meta=meta)
+    # start training
+    runner.train()
 
 
 if __name__ == '__main__':
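
Usage note (not part of the patch itself): the sketch below shows how `register_all_modules` is expected to behave after this change, mirroring the new unit test; it assumes an environment where mmengine and a patched mmcls are importable.

    from mmengine import DefaultScope

    from mmcls.registry import DATASETS
    from mmcls.utils import register_all_modules

    # Importing mmcls.core/datasets/metrics/models inside register_all_modules
    # populates the registries; the default scope is then set to "mmcls".
    register_all_modules()

    assert 'CustomDataset' in DATASETS.module_dict
    assert DefaultScope.get_current_instance().scope_name == 'mmcls'

If another default scope is already active, the function warns and registers a fresh `DefaultScope` instance named `mmcls-<timestamp>` so that registry lookups still resolve to the "mmcls" scope without a name conflict.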
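
For reference, and likewise outside the patch, both refactored entry points reduce to the same mmengine pattern; the config path below is a placeholder rather than a file referenced by this diff.

    from mmengine.config import Config
    from mmengine.runner import Runner

    from mmcls.utils import register_all_modules

    # The Runner initializes the default scope itself, so skip it here,
    # exactly as tools/train.py and tools/test.py now do.
    register_all_modules(init_default_scope=False)

    cfg = Config.fromfile('configs/example_config.py')  # placeholder path
    cfg.work_dir = './work_dirs/example'
    # cfg.load_from = 'example_checkpoint.pth'  # needed before runner.test()

    runner = Runner.from_cfg(cfg)
    runner.train()  # or runner.test() for evaluation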