diff --git a/mmocr/apis/test.py b/mmocr/apis/test.py
deleted file mode 100644
index 489f6e92..00000000
--- a/mmocr/apis/test.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-
-import mmcv
-import numpy as np
-import torch
-from mmcv.image import tensor2imgs
-from mmcv.parallel import DataContainer
-from mmdet.core import encode_mask_results
-
-from .utils import tensor2grayimgs
-
-
-def retrieve_img_tensor_and_meta(data):
-    """Retrieval img_tensor, img_metas and img_norm_cfg.
-
-    Args:
-        data (dict): One batch data from data_loader.
-
-    Returns:
-        tuple: Returns (img_tensor, img_metas, img_norm_cfg).
-
-            - | img_tensor (Tensor): Input image tensor with shape
-                :math:`(N, C, H, W)`.
-            - | img_metas (list[dict]): The metadata of images.
-            - | img_norm_cfg (dict): Config for image normalization.
-    """
-
-    if isinstance(data['img'], torch.Tensor):
-        # for textrecog with batch_size > 1
-        # and not use 'DefaultFormatBundle' in pipeline
-        img_tensor = data['img']
-        img_metas = data['img_metas'].data[0]
-    elif isinstance(data['img'], list):
-        if isinstance(data['img'][0], torch.Tensor):
-            # for textrecog with aug_test and batch_size = 1
-            img_tensor = data['img'][0]
-        elif isinstance(data['img'][0], DataContainer):
-            # for textdet with 'MultiScaleFlipAug'
-            # and 'DefaultFormatBundle' in pipeline
-            img_tensor = data['img'][0].data[0]
-        img_metas = data['img_metas'][0].data[0]
-    elif isinstance(data['img'], DataContainer):
-        # for textrecog with 'DefaultFormatBundle' in pipeline
-        img_tensor = data['img'].data[0]
-        img_metas = data['img_metas'].data[0]
-
-    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
-    for key in must_keys:
-        if key not in img_metas[0]:
-            raise KeyError(
-                f'Please add {key} to the "meta_keys" in the pipeline')
-
-    img_norm_cfg = img_metas[0]['img_norm_cfg']
-    if max(img_norm_cfg['mean']) <= 1:
-        img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
-        img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]
-
-    return img_tensor, img_metas, img_norm_cfg
-
-
-def single_gpu_test(model,
-                    data_loader,
-                    show=False,
-                    out_dir=None,
-                    is_kie=False,
-                    show_score_thr=0.3):
-    model.eval()
-    results = []
-    dataset = data_loader.dataset
-    prog_bar = mmcv.ProgressBar(len(dataset))
-    for data in data_loader:
-        with torch.no_grad():
-            result = model(return_loss=False, rescale=True, **data)
-
-        batch_size = len(result)
-        if show or out_dir:
-            if is_kie:
-                img_tensor = data['img'].data[0]
-                if img_tensor.shape[0] != 1:
-                    raise KeyError('Visualizing KIE outputs in batches is'
-                                   'currently not supported.')
-                gt_bboxes = data['gt_bboxes'].data[0]
-                img_metas = data['img_metas'].data[0]
-                must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
-                for key in must_keys:
-                    if key not in img_metas[0]:
-                        raise KeyError(
-                            f'Please add {key} to the "meta_keys" in config.')
-                # for no visual model
-                if np.prod(img_tensor.shape) == 0:
-                    imgs = []
-                    for img_meta in img_metas:
-                        try:
-                            img = mmcv.imread(img_meta['filename'])
-                        except Exception as e:
-                            print(f'Load image with error: {e}, '
-                                  'use empty image instead.')
-                            img = np.ones(
-                                img_meta['img_shape'], dtype=np.uint8)
-                        imgs.append(img)
-                else:
-                    imgs = tensor2imgs(img_tensor,
-                                       **img_metas[0]['img_norm_cfg'])
-                for i, img in enumerate(imgs):
-                    h, w, _ = img_metas[i]['img_shape']
-                    img_show = img[:h, :w, :]
-                    if out_dir:
-                        out_file = osp.join(out_dir,
-                                            img_metas[i]['ori_filename'])
-                    else:
-                        out_file = None
-
-                    model.module.show_result(
-                        img_show,
-                        result[i],
-                        gt_bboxes[i],
-                        show=show,
-                        out_file=out_file)
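Migration note: the hand-rolled single_gpu_test loop deleted above (result collection, progress bar, show_result visualization) is taken over by the MMEngine Runner's test loop. The sketch below is illustrative only and not part of this patch; the config and checkpoint paths are hypothetical placeholders.

    # Sketch: what replaces single_gpu_test under MMEngine. Paths are
    # hypothetical; any mmocr dev-1.x test config would do.
    from mmengine.config import Config
    from mmengine.runner import Runner

    from mmocr.utils import register_all_modules

    register_all_modules(init_default_scope=False)
    cfg = Config.fromfile('configs/textdet/dbnet/example.py')  # hypothetical path
    cfg.work_dir = 'work_dirs/example'
    cfg.load_from = 'example.pth'  # hypothetical checkpoint
    Runner.from_cfg(cfg).test()  # batching, progress bar and metrics live in the Runner
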
diff --git a/mmocr/apis/train.py b/mmocr/apis/train.py
deleted file mode 100644
index 56bcc4b7..00000000
--- a/mmocr/apis/train.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-
-import mmcv
-import numpy as np
-import torch
-import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
-                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
-                         build_runner, get_dist_info)
-from mmdet.core import DistEvalHook, EvalHook
-from mmdet.datasets import build_dataloader
-
-from mmocr import digit_version
-from mmocr.apis.utils import (disable_text_recog_aug_test,
-                              replace_image_to_tensor)
-from mmocr.registry import DATASETS
-from mmocr.utils import get_root_logger
-
-
-def train_detector(model,
-                   dataset,
-                   cfg,
-                   distributed=False,
-                   validate=False,
-                   timestamp=None,
-                   meta=None):
-    logger = get_root_logger(cfg.log_level)
-
-    # prepare data loaders
-    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    # step 1: give default values and override (if exist) from cfg.data
-    default_loader_cfg = {
-        **dict(
-            num_gpus=len(cfg.gpu_ids),
-            dist=distributed,
-            seed=cfg.get('seed'),
-            drop_last=False,
-            persistent_workers=False),
-        **({} if torch.__version__ != 'parrots' else dict(
-            prefetch_num=2,
-            pin_memory=False,
-        )),
-    }
-    # update overall dataloader(for train, val and test) setting
-    default_loader_cfg.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader'
-        ]
-    })
-
-    # step 2: cfg.data.train_dataloader has highest priority
-    train_loader_cfg = dict(default_loader_cfg,
-                            **cfg.data.get('train_dataloader', {}))
-
-    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
-
-    # put model on gpus
-    if distributed:
-        find_unused_parameters = cfg.get('find_unused_parameters', False)
-        # Sets the `find_unused_parameters` parameter in
-        # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False,
-            find_unused_parameters=find_unused_parameters)
-    else:
-        if not torch.cuda.is_available():
-            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
-                'Please use MMCV >= 1.4.4 for CPU training!'
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
-
-    # build runner
-    optimizer = build_optimizer(model, cfg.optimizer)
-
-    if 'runner' not in cfg:
-        cfg.runner = {
-            'type': 'EpochBasedRunner',
-            'max_epochs': cfg.total_epochs
-        }
-        warnings.warn(
-            'config is now expected to have a `runner` section, '
-            'please set `runner` in your config.', UserWarning)
-    else:
-        if 'total_epochs' in cfg:
-            assert cfg.total_epochs == cfg.runner.max_epochs
-
-    runner = build_runner(
-        cfg.runner,
-        default_args=dict(
-            model=model,
-            optimizer=optimizer,
-            work_dir=cfg.work_dir,
-            logger=logger,
-            meta=meta))
-
-    # an ugly workaround to make .log and .log.json filenames the same
-    runner.timestamp = timestamp
-
-    # fp16 setting
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        optimizer_config = Fp16OptimizerHook(
-            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
-    elif distributed and 'type' not in cfg.optimizer_config:
-        optimizer_config = OptimizerHook(**cfg.optimizer_config)
-    else:
-        optimizer_config = cfg.optimizer_config
-
-    # register hooks
-    runner.register_training_hooks(
-        cfg.lr_config,
-        optimizer_config,
-        cfg.checkpoint_config,
-        cfg.log_config,
-        cfg.get('momentum_config', None),
-        custom_hooks_config=cfg.get('custom_hooks', None))
-    if distributed:
-        if isinstance(runner, EpochBasedRunner):
-            runner.register_hook(DistSamplerSeedHook())
-
-    # register eval hooks
-    if validate:
-        val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
-            'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
-        if val_samples_per_gpu > 1:
-            # Support batch_size > 1 in test for text recognition
-            # by disable MultiRotateAugOCR since it is useless for most case
-            cfg = disable_text_recog_aug_test(cfg)
-            cfg = replace_image_to_tensor(cfg)
-
-        val_dataset = DATASETS.build(cfg.data.val, dict(test_mode=True))
-
-        val_loader_cfg = {
-            **default_loader_cfg,
-            **dict(shuffle=False, drop_last=False),
-            **cfg.data.get('val_dataloader', {}),
-            **dict(samples_per_gpu=val_samples_per_gpu)
-        }
-
-        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
-
-        eval_cfg = cfg.get('evaluation', {})
-        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
-        eval_hook = DistEvalHook if distributed else EvalHook
-        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
-
-    if cfg.resume_from:
-        runner.resume(cfg.resume_from)
-    elif cfg.load_from:
-        runner.load_checkpoint(cfg.load_from)
-    runner.run(data_loaders, cfg.workflow)
-
-
-def init_random_seed(seed=None, device='cuda'):
-    """Initialize random seed. If the seed is None, it will be replaced by a
-    random number, and then broadcasted to all processes.
-
-    Args:
-        seed (int, Optional): The seed.
-        device (str): The device where the seed will be put on.
-
-    Returns:
-        int: Seed to be used.
-    """
-    if seed is not None:
-        return seed
-
-    # Make sure all ranks share the same random seed to prevent
-    # some potential bugs. Please refer to
-    # https://github.com/open-mmlab/mmdetection/issues/6339
-    rank, world_size = get_dist_info()
-    seed = np.random.randint(2**31)
-    if world_size == 1:
-        return seed
-
-    if rank == 0:
-        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
-    else:
-        random_num = torch.tensor(0, dtype=torch.int32, device=device)
-    dist.broadcast(random_num, src=0)
-    return random_num.item()
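Likewise, everything train_detector and init_random_seed did by hand (DDP wrapping, optimizer and hook registration, seed broadcast across ranks) is owned by the MMEngine Runner. A hedged sketch of the equivalent knobs, reusing `cfg` from the previous sketch; `randomness` is MMEngine's seeding entry and the values shown are examples only:

    # Sketch: seed handling replacing init_random_seed. diff_rank_seed mirrors
    # the old --diff-seed flag, deterministic the old --deterministic flag.
    cfg.randomness = dict(seed=42, diff_rank_seed=True, deterministic=False)
    cfg.launcher = 'pytorch'  # replaces init_dist(args.launcher, **cfg.dist_params)
    Runner.from_cfg(cfg).train()
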
diff --git a/mmocr/utils/__init__.py b/mmocr/utils/__init__.py
index 1fe21ecd..76e159d8 100644
--- a/mmocr/utils/__init__.py
+++ b/mmocr/utils/__init__.py
@@ -18,7 +18,7 @@ from .polygon_utils import (crop_polygon, is_poly_outside_rect, poly2bbox,
                             poly2shapely, poly_intersection, poly_iou,
                             poly_make_valid, poly_union, polys2shapely,
                             rescale_polygon, rescale_polygons)
-from .setup_env import setup_multi_processes
+from .setup_env import register_all_modules
 from .string_util import StringStrip
 
 __all__ = [
@@ -27,9 +27,9 @@ __all__ = [
     'valid_boundary', 'drop_orientation', 'convert_annotations', 'is_not_png',
     'list_to_file', 'list_from_file', 'is_on_same_line',
     'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
-    'bezier_to_polygon', 'sort_points', 'setup_multi_processes', 'recog2lmdb',
-    'dump_ocr_data', 'recog_anno_to_imginfo', 'rescale_polygons',
-    'rescale_polygon', 'rescale_bboxes', 'bbox2poly', 'crop_polygon',
-    'is_poly_outside_rect', 'poly2bbox', 'poly_intersection', 'poly_iou',
-    'poly_make_valid', 'poly_union', 'poly2shapely', 'polys2shapely'
+    'bezier_to_polygon', 'sort_points', 'recog2lmdb', 'dump_ocr_data',
+    'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
+    'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_outside_rect',
+    'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
+    'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules'
 ]
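Downstream code importing the removed helper needs a one-line change:

    # from mmocr.utils import setup_multi_processes  # removed by this patch
    from mmocr.utils import register_all_modules  # its replacement
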
diff --git a/mmocr/utils/setup_env.py b/mmocr/utils/setup_env.py
index 21def2f0..f4703a01 100644
--- a/mmocr/utils/setup_env.py
+++ b/mmocr/utils/setup_env.py
@@ -1,47 +1,39 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
-import platform
+import datetime
 import warnings
 
-import cv2
-import torch.multiprocessing as mp
+from mmengine import DefaultScope
 
 
-def setup_multi_processes(cfg):
-    """Setup multi-processing environment variables."""
-    # set multi-process start method as `fork` to speed up the training
-    if platform.system() != 'Windows':
-        mp_start_method = cfg.get('mp_start_method', 'fork')
-        current_method = mp.get_start_method(allow_none=True)
-        if current_method is not None and current_method != mp_start_method:
-            warnings.warn(
-                f'Multi-processing start method `{mp_start_method}` is '
-                f'different from the previous setting `{current_method}`.'
-                f'It will be force set to `{mp_start_method}`. You can change '
-                f'this behavior by changing `mp_start_method` in your config.')
-        mp.set_start_method(mp_start_method, force=True)
+def register_all_modules(init_default_scope: bool = True) -> None:
+    """Register all modules in mmocr into the registries.
 
-    # disable opencv multithreading to avoid system being overloaded
-    opencv_num_threads = cfg.get('opencv_num_threads', 0)
-    cv2.setNumThreads(opencv_num_threads)
+    Args:
+        init_default_scope (bool): Whether to initialize the mmocr default
+            scope. When `init_default_scope=True`, the global default scope
+            will be set to `mmocr`, and all registries will build modules
+            from mmocr's registry node. To understand more about the
+            registry, please refer to
+            https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
+            Defaults to True.
+    """  # noqa
+    import mmocr.core  # noqa: F401,F403
+    import mmocr.datasets  # noqa: F401,F403
+    import mmocr.metrics  # noqa: F401,F403
+    import mmocr.models  # noqa: F401,F403
 
-    # setup OMP threads
-    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
-    if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
-        omp_num_threads = 1
-        warnings.warn(
-            f'Setting OMP_NUM_THREADS environment variable for each process '
-            f'to be {omp_num_threads} in default, to avoid your system being '
-            f'overloaded, please further tune the variable for optimal '
-            f'performance in your application as needed.')
-        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
-
-    # setup MKL threads
-    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
-        mkl_num_threads = 1
-        warnings.warn(
-            f'Setting MKL_NUM_THREADS environment variable for each process '
-            f'to be {mkl_num_threads} in default, to avoid your system being '
-            f'overloaded, please further tune the variable for optimal '
-            f'performance in your application as needed.')
-        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+    if init_default_scope:
+        never_created = DefaultScope.get_current_instance() is None \
+            or not DefaultScope.check_instance_created('mmocr')
+        if never_created:
+            DefaultScope.get_instance('mmocr', scope_name='mmocr')
+            return
+        current_scope = DefaultScope.get_current_instance()
+        if current_scope.scope_name != 'mmocr':
+            warnings.warn('The current default scope '
+                          f'"{current_scope.scope_name}" is not "mmocr", '
+                          '`register_all_modules` will force the current '
+                          'default scope to be "mmocr". If this is not '
+                          'expected, please set `init_default_scope=False`.')
+            # avoid name conflict
+            new_instance_name = f'mmocr-{datetime.datetime.now()}'
+            DefaultScope.get_instance(new_instance_name, scope_name='mmocr')
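A quick usage sketch of the new helper, assuming mmocr and mmengine are installed: after the call, the global default scope is 'mmocr' and registry lookups resolve against MMOCR's registry tree (both assertions are the behaviors exercised by the new test below).

    from mmengine import DefaultScope

    from mmocr.registry import DATASETS
    from mmocr.utils import register_all_modules

    register_all_modules(init_default_scope=True)
    assert DefaultScope.get_current_instance().scope_name == 'mmocr'
    assert 'OCRDataset' in DATASETS.module_dict  # populated by the imports above
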
diff --git a/tests/test_utils/test_setup_env.py b/tests/test_utils/test_setup_env.py
new file mode 100644
index 00000000..c2dd9811
--- /dev/null
+++ b/tests/test_utils/test_setup_env.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import sys
+from unittest import TestCase
+
+from mmengine import DefaultScope
+
+from mmocr.utils import register_all_modules
+
+
+class TestSetupEnv(TestCase):
+
+    def test_register_all_modules(self):
+        from mmocr.registry import DATASETS
+
+        # not init default scope
+        sys.modules.pop('mmocr.datasets', None)
+        sys.modules.pop('mmocr.datasets.ocr_dataset', None)
+        DATASETS._module_dict.pop('OCRDataset', None)
+        self.assertFalse('OCRDataset' in DATASETS.module_dict)
+        register_all_modules(init_default_scope=False)
+        self.assertTrue('OCRDataset' in DATASETS.module_dict)
+
+        # init default scope
+        sys.modules.pop('mmocr.datasets')
+        sys.modules.pop('mmocr.datasets.ocr_dataset')
+        DATASETS._module_dict.pop('OCRDataset', None)
+        self.assertFalse('OCRDataset' in DATASETS.module_dict)
+        register_all_modules(init_default_scope=True)
+        self.assertTrue('OCRDataset' in DATASETS.module_dict)
+        self.assertEqual(DefaultScope.get_current_instance().scope_name,
+                         'mmocr')
+
+        # init default scope when another scope is already initialized
+        name = f'test-{datetime.datetime.now()}'
+        DefaultScope.get_instance(name, scope_name='test')
+        with self.assertWarnsRegex(
+                Warning, 'The current default scope "test" is not "mmocr"'):
+            register_all_modules(init_default_scope=True)
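The new test is plain unittest, so it runs without extra tooling, e.g. `python -m unittest tests.test_utils.test_setup_env -v`; pytest, if installed, discovers it as well.
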
diff --git a/tools/test.py b/tools/test.py
index 5663c8e9..19d89557 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -1,235 +1,73 @@
-#!/usr/bin/env python
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import os
-import warnings
+import os.path as osp
 
-import mmcv
-import torch
-from mmcv import Config, DictAction
-from mmcv.cnn import fuse_conv_bn
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
-                         wrap_fp16_model)
-from mmdet.apis import multi_gpu_test
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
 
-from mmocr.apis.test import single_gpu_test
-from mmocr.apis.utils import (disable_text_recog_aug_test,
-                              replace_image_to_tensor)
-from mmocr.datasets import build_dataloader
-from mmocr.models import build_detector
-from mmocr.registry import DATASETS
-from mmocr.utils import revert_sync_batchnorm, setup_multi_processes
+from mmocr.utils import register_all_modules
 
 
+# TODO: support fuse_conv_bn, visualization, and format_only
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description='MMOCR test (and eval) a model.')
-    parser.add_argument('config', help='Test config file path.')
-    parser.add_argument('checkpoint', help='Checkpoint file.')
-    parser.add_argument('--out', help='Output result file in pickle format.')
+    parser = argparse.ArgumentParser(description='Test (and eval) a model')
+    parser.add_argument('config', help='Test config file path')
+    parser.add_argument('checkpoint', help='Checkpoint file')
     parser.add_argument(
-        '--fuse-conv-bn',
-        action='store_true',
-        help='Whether to fuse conv and bn, this will slightly increase'
-        'the inference speed.')
-    parser.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='id of gpu to use '
-        '(only applicable to non-distributed testing)')
-    parser.add_argument(
-        '--format-only',
-        action='store_true',
-        help='Format the output results without performing evaluation. It is'
-        'useful when you want to format the results to a specific format and '
-        'submit them to the test server.')
-    parser.add_argument(
-        '--eval',
-        type=str,
-        nargs='+',
-        help='The evaluation metrics. Options: \'hmean-ic13\', \'hmean-iou'
-        '\' for text detection tasks, \'acc\' for text recognition tasks, and '
-        '\'macro-f1\' for key information extraction tasks.')
-    parser.add_argument('--show', action='store_true', help='Show results.')
-    parser.add_argument(
-        '--show-dir', help='Directory where the output images will be saved.')
-    parser.add_argument(
-        '--show-score-thr',
-        type=float,
-        default=0.3,
-        help='Score threshold (default: 0.3).')
-    parser.add_argument(
-        '--gpu-collect',
-        action='store_true',
-        help='Whether to use gpu to collect results.')
-    parser.add_argument(
-        '--tmpdir',
-        help='The tmp directory used for collecting results from multiple '
-        'workers, available when gpu-collect is not specified.')
+        '--work-dir',
+        help='The directory to save the file containing evaluation metrics')
     parser.add_argument(
         '--cfg-options',
        nargs='+',
         action=DictAction,
         help='Override some settings in the used config, the key-value pair '
-        'in xxx=yyy format will be merged into the config file. If the value '
-        'to be overwritten is a list, it should be of the form of either '
-        'key="[a,b]" or key=a,b. The argument also allows nested list/tuple '
-        'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
-        'are necessary and that no white space is allowed.')
-    parser.add_argument(
-        '--options',
-        nargs='+',
-        action=DictAction,
-        help='Custom options for evaluation, the key-value pair in xxx=yyy '
-        'format will be kwargs for dataset.evaluate() function (deprecate), '
-        'change to --eval-options instead.')
-    parser.add_argument(
-        '--eval-options',
-        nargs='+',
-        action=DictAction,
-        help='Custom options for evaluation, the key-value pair in xxx=yyy '
-        'format will be kwargs for dataset.evaluate() function.')
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
-        help='Options for job launcher.')
+        help='Job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
-
-    if args.options and args.eval_options:
-        raise ValueError(
-            '--options and --eval-options cannot be both '
-            'specified, --options is deprecated in favor of --eval-options.')
-    if args.options:
-        warnings.warn('--options is deprecated in favor of --eval-options.')
-        args.eval_options = args.options
     return args
 
 
 def main():
     args = parse_args()
 
-    assert (
-        args.out or args.eval or args.format_only or args.show
-        or args.show_dir), (
-            'Please specify at least one operation (save/eval/format/show the '
-            'results / save the results) with the argument "--out", "--eval"'
-            ', "--format-only", "--show" or "--show-dir".')
-
-    if args.eval and args.format_only:
-        raise ValueError('--eval and --format_only cannot be both specified.')
-
-    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
-        raise ValueError('The output file must be a pkl file.')
+    # register all modules in mmocr into the registries
+    # do not init the default scope here because it will be initialized
+    # in the runner
+    register_all_modules(init_default_scope=False)
 
+    # load config
     cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    setup_multi_processes(cfg)
-
-    # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
-        torch.backends.cudnn.benchmark = True
-
-    if cfg.model.get('pretrained'):
-        cfg.model.pretrained = None
-    if cfg.model.get('neck'):
-        if isinstance(cfg.model.neck, list):
-            for neck_cfg in cfg.model.neck:
-                if neck_cfg.get('rfp_backbone'):
-                    if neck_cfg.rfp_backbone.get('pretrained'):
-                        neck_cfg.rfp_backbone.pretrained = None
-        elif cfg.model.neck.get('rfp_backbone'):
-            if cfg.model.neck.rfp_backbone.get('pretrained'):
-                cfg.model.neck.rfp_backbone.pretrained = None
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
 
-    # in case the test dataset is concatenated
-    samples_per_gpu = (cfg.data.get('test_dataloader', {})).get(
-        'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
-    if samples_per_gpu > 1:
-        cfg = disable_text_recog_aug_test(cfg)
-        cfg = replace_image_to_tensor(cfg)
+    cfg.load_from = args.checkpoint
 
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        cfg.gpu_ids = [args.gpu_id]
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
 
-    # build the dataloader
-    dataset = DATASETS.build(cfg.data.test, dict(test_mode=True))
-    # step 1: give default values and override (if exist) from cfg.data
-    default_loader_cfg = {
-        **dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
-        **({} if torch.__version__ != 'parrots' else dict(
-            prefetch_num=2,
-            pin_memory=False,
-        ))
-    }
-    default_loader_cfg.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader'
-        ]
-    })
-    test_loader_cfg = {
-        **default_loader_cfg,
-        **dict(shuffle=False, drop_last=False),
-        **cfg.data.get('test_dataloader', {}),
-        **dict(samples_per_gpu=samples_per_gpu)
-    }
-
-    data_loader = build_dataloader(dataset, **test_loader_cfg)
-
-    # build the model and load checkpoint
-    cfg.model.train_cfg = None
-    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
-    model = revert_sync_batchnorm(model)
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        wrap_fp16_model(model)
-    load_checkpoint(model, args.checkpoint, map_location='cpu')
-    if args.fuse_conv_bn:
-        model = fuse_conv_bn(model)
-
-    if not distributed:
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
-        is_kie = cfg.model.type in ['SDMGR']
-        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
-                                  is_kie, args.show_score_thr)
-    else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
-        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
-                                 args.gpu_collect)
-
-    rank, _ = get_dist_info()
-    if rank == 0:
-        if args.out:
-            print(f'\nwriting results to {args.out}')
-            mmcv.dump(outputs, args.out)
-        kwargs = {} if args.eval_options is None else args.eval_options
-        if args.format_only:
-            dataset.format_results(outputs, **kwargs)
-        if args.eval:
-            eval_kwargs = cfg.get('evaluation', {}).copy()
-            # hard-code way to remove EvalHook args
-            for key in [
-                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
-                    'rule'
-            ]:
-                eval_kwargs.pop(key, None)
-            eval_kwargs.update(dict(metric=args.eval, **kwargs))
-            print(dataset.evaluate(outputs, **eval_kwargs))
+    # start testing
+    runner.test()
 
 
 if __name__ == '__main__':
     main()
diff --git a/tools/train.py b/tools/train.py
index 901ada85..4e5e7f85 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,118 +1,54 @@
-#!/usr/bin/env python
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
-import copy
 import os
 import os.path as osp
-import time
-import warnings
 
-import mmcv
-import torch
-import torch.distributed as dist
-from mmcv import Config, DictAction
-from mmcv.runner import get_dist_info, init_dist, set_random_seed
-from mmcv.utils import get_git_hash
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
 
-from mmocr import __version__
-from mmocr.apis import init_random_seed, train_detector
-from mmocr.models import build_detector
-from mmocr.registry import DATASETS
-from mmocr.utils import (collect_env, get_root_logger, is_2dlist,
-                         setup_multi_processes)
+from mmocr.utils import register_all_modules
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Train a detector.')
-    parser.add_argument('config', help='Train config file path.')
-    parser.add_argument('--work-dir', help='The dir to save logs and models.')
-    parser.add_argument(
-        '--load-from', help='The checkpoint file to load from.')
-    parser.add_argument(
-        '--resume-from', help='The checkpoint file to resume from.')
-    parser.add_argument(
-        '--no-validate',
-        action='store_true',
-        help='Whether not to evaluate the checkpoint during training.')
-    group_gpus = parser.add_mutually_exclusive_group()
-    group_gpus.add_argument(
-        '--gpus',
-        type=int,
-        help='(Deprecated, please use --gpu-id) number of gpus to use '
-        '(only applicable to non-distributed training).')
-    group_gpus.add_argument(
-        '--gpu-ids',
-        type=int,
-        nargs='+',
-        help='(Deprecated, please use --gpu-id) ids of gpus to use '
-        '(only applicable to non-distributed training)')
-    group_gpus.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='id of gpu to use '
-        '(only applicable to non-distributed training)')
-    parser.add_argument('--seed', type=int, default=None, help='Random seed.')
-    parser.add_argument(
-        '--diff-seed',
-        action='store_true',
-        help='Whether or not set different seeds for different ranks')
-    parser.add_argument(
-        '--deterministic',
-        action='store_true',
-        help='Whether to set deterministic options for CUDNN backend.')
-    parser.add_argument(
-        '--options',
-        nargs='+',
-        action=DictAction,
-        help='Override some settings in the used config, the key-value pair '
-        'in xxx=yyy format will be merged into config file (deprecate), '
-        'change to --cfg-options instead.')
+    parser = argparse.ArgumentParser(description='Train a model')
+    parser.add_argument('config', help='Train config file path')
+    parser.add_argument('--work-dir', help='The dir to save logs and models')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
         action=DictAction,
         help='Override some settings in the used config, the key-value pair '
         'in xxx=yyy format will be merged into config file. If the value to '
-        'be overwritten is a list, it should be of the form of either '
-        'key="[a,b]" or key=a,b .The argument also allows nested list/tuple '
-        'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
-        'are necessary and that no white space is allowed.')
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
-        help='Options for job launcher.')
+        help='Job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
-
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
 
-    if args.options and args.cfg_options:
-        raise ValueError(
-            '--options and --cfg-options cannot be both '
-            'specified, --options is deprecated in favor of --cfg-options')
-    if args.options:
-        warnings.warn('--options is deprecated in favor of --cfg-options')
-        args.cfg_options = args.options
-
     return args
 
 
 def main():
     args = parse_args()
 
+    # register all modules in mmocr into the registries
+    # do not init the default scope here because it will be initialized
+    # in the runner
+    register_all_modules(init_default_scope=False)
+
+    # load config
     cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    setup_multi_processes(cfg)
-
-    # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
-        torch.backends.cudnn.benchmark = True
-
     # work_dir is determined in this priority: CLI > segment in file > filename
     if args.work_dir is not None:
         # update configs according to CLI args if args.work_dir is not None
@@ -121,109 +57,12 @@ def main():
         # use config filename as default work_dir if cfg.work_dir is None
         cfg.work_dir = osp.join('./work_dirs',
                                 osp.splitext(osp.basename(args.config))[0])
-    if args.load_from is not None:
-        cfg.load_from = args.load_from
-    if args.resume_from is not None:
-        cfg.resume_from = args.resume_from
-    if args.gpus is not None:
-        cfg.gpu_ids = range(1)
-        warnings.warn('`--gpus` is deprecated because we only support '
-                      'single GPU mode in non-distributed training. '
-                      'Use `gpus=1` now.')
-    if args.gpu_ids is not None:
-        cfg.gpu_ids = args.gpu_ids[0:1]
-        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
-                      'Because we only support single GPU mode in '
-                      'non-distributed training. Use the first GPU '
-                      'in `gpu_ids` now.')
-    if args.gpus is None and args.gpu_ids is None:
-        cfg.gpu_ids = [args.gpu_id]
-
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
-        # re-set gpu_ids with distributed training mode
-        _, world_size = get_dist_info()
-        cfg.gpu_ids = range(world_size)
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
 
-    # create work_dir
-    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
-    # dump config
-    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
-    # init the logger before other steps
-    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
-    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
-    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
-
-    # init the meta dict to record some important information such as
-    # environment info and seed, which will be logged
-    meta = dict()
-    # log env info
-    env_info_dict = collect_env()
-    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
-    dash_line = '-' * 60 + '\n'
-    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
-                dash_line)
-    meta['env_info'] = env_info
-    meta['config'] = cfg.pretty_text
-    # log some basic info
-    logger.info(f'Distributed training: {distributed}')
-    logger.info(f'Config:\n{cfg.pretty_text}')
-
-    # set random seeds
-    seed = init_random_seed(args.seed)
-    seed = seed + dist.get_rank() if args.diff_seed else seed
-    logger.info(f'Set random seed to {seed}, '
-                f'deterministic: {args.deterministic}')
-    set_random_seed(seed, deterministic=args.deterministic)
-    cfg.seed = seed
-    meta['seed'] = seed
-    meta['exp_name'] = osp.basename(args.config)
-
-    model = build_detector(
-        cfg.model,
-        train_cfg=cfg.get('train_cfg'),
-        test_cfg=cfg.get('test_cfg'))
-    model.init_weights()
-
-    datasets = [DATASETS.build(cfg.data.train)]
-    if len(cfg.workflow) == 2:
-        val_dataset = copy.deepcopy(cfg.data.val)
-        if cfg.data.train.get('pipeline', None) is None:
-            if is_2dlist(cfg.data.train.datasets):
-                train_pipeline = cfg.data.train.datasets[0][0].pipeline
-            else:
-                train_pipeline = cfg.data.train.datasets[0].pipeline
-        elif is_2dlist(cfg.data.train.pipeline):
-            train_pipeline = cfg.data.train.pipeline[0]
-        else:
-            train_pipeline = cfg.data.train.pipeline
-
-        if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']:
-            for dataset in val_dataset['datasets']:
-                dataset.pipeline = train_pipeline
-        else:
-            val_dataset.pipeline = train_pipeline
-        datasets.append(build_dataset(val_dataset))
-    if cfg.get('checkpoint_config', None) is not None:
-        # save mmdet version, config file content and class names in
-        # checkpoints as meta data
-        cfg.checkpoint_config.meta = dict(
-            mmocr_version=__version__ + get_git_hash()[:7],
-            CLASSES=datasets[0].CLASSES)
-    # add an attribute for visualization convenience
-    model.CLASSES = datasets[0].CLASSES
-    train_detector(
-        model,
-        datasets,
-        cfg,
-        distributed=distributed,
-        validate=(not args.no_validate),
-        timestamp=timestamp,
-        meta=meta)
+    # start training
+    runner.train()
 
 
 if __name__ == '__main__':
     main()
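The training entrypoint shrinks the same way. Removed flags such as --load-from, --resume-from, --seed and the GPU options are expected to travel through the config instead, directly or via --cfg-options; the load_from and randomness.seed keys below follow MMEngine conventions and are illustrative, as are the paths:

    python tools/train.py configs/example_config.py \
        --work-dir work_dirs/example \
        --cfg-options load_from=checkpoints/pretrained.pth randomness.seed=42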