[Refactor] train and test

pull/1178/head
liukuikun 2022-05-26 08:55:13 +00:00 committed by gaotongxiao
parent 4246b1eaee
commit fe43259a05
7 changed files with 132 additions and 767 deletions


@@ -1,157 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv.image import tensor2imgs
from mmcv.parallel import DataContainer
from mmdet.core import encode_mask_results
from .utils import tensor2grayimgs
def retrieve_img_tensor_and_meta(data):
"""Retrieval img_tensor, img_metas and img_norm_cfg.
Args:
data (dict): One batch data from data_loader.
Returns:
tuple: Returns (img_tensor, img_metas, img_norm_cfg).
- | img_tensor (Tensor): Input image tensor with shape
:math:`(N, C, H, W)`.
- | img_metas (list[dict]): The metadata of images.
- | img_norm_cfg (dict): Config for image normalization.
"""
if isinstance(data['img'], torch.Tensor):
# for textrecog with batch_size > 1
# and 'DefaultFormatBundle' not used in the pipeline
img_tensor = data['img']
img_metas = data['img_metas'].data[0]
elif isinstance(data['img'], list):
if isinstance(data['img'][0], torch.Tensor):
# for textrecog with aug_test and batch_size = 1
img_tensor = data['img'][0]
elif isinstance(data['img'][0], DataContainer):
# for textdet with 'MultiScaleFlipAug'
# and 'DefaultFormatBundle' in pipeline
img_tensor = data['img'][0].data[0]
img_metas = data['img_metas'][0].data[0]
elif isinstance(data['img'], DataContainer):
# for textrecog with 'DefaultFormatBundle' in pipeline
img_tensor = data['img'].data[0]
img_metas = data['img_metas'].data[0]
must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
for key in must_keys:
if key not in img_metas[0]:
raise KeyError(
f'Please add {key} to the "meta_keys" in the pipeline')
img_norm_cfg = img_metas[0]['img_norm_cfg']
if max(img_norm_cfg['mean']) <= 1:
img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]
return img_tensor, img_metas, img_norm_cfg
def single_gpu_test(model,
data_loader,
show=False,
out_dir=None,
is_kie=False,
show_score_thr=0.3):
model.eval()
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for data in data_loader:
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
batch_size = len(result)
if show or out_dir:
if is_kie:
img_tensor = data['img'].data[0]
if img_tensor.shape[0] != 1:
raise KeyError('Visualizing KIE outputs in batches is '
'currently not supported.')
gt_bboxes = data['gt_bboxes'].data[0]
img_metas = data['img_metas'].data[0]
must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
for key in must_keys:
if key not in img_metas[0]:
raise KeyError(
f'Please add {key} to the "meta_keys" in config.')
# for models without visual features
if np.prod(img_tensor.shape) == 0:
imgs = []
for img_meta in img_metas:
try:
img = mmcv.imread(img_meta['filename'])
except Exception as e:
print(f'Failed to load image: {e}, '
'using an empty image instead.')
img = np.ones(
img_meta['img_shape'], dtype=np.uint8)
imgs.append(img)
else:
imgs = tensor2imgs(img_tensor,
**img_metas[0]['img_norm_cfg'])
for i, img in enumerate(imgs):
h, w, _ = img_metas[i]['img_shape']
img_show = img[:h, :w, :]
if out_dir:
out_file = osp.join(out_dir,
img_metas[i]['ori_filename'])
else:
out_file = None
model.module.show_result(
img_show,
result[i],
gt_bboxes[i],
show=show,
out_file=out_file)
else:
img_tensor, img_metas, img_norm_cfg = \
retrieve_img_tensor_and_meta(data)
if img_tensor.size(1) == 1:
imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
else:
imgs = tensor2imgs(img_tensor, **img_norm_cfg)
assert len(imgs) == len(img_metas)
for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
img_shape, ori_shape = img_meta['img_shape'], img_meta[
'ori_shape']
img_show = img[:img_shape[0], :img_shape[1]]
img_show = mmcv.imresize(img_show,
(ori_shape[1], ori_shape[0]))
if out_dir:
out_file = osp.join(out_dir, img_meta['ori_filename'])
else:
out_file = None
model.module.show_result(
img_show,
result[j],
show=show,
out_file=out_file,
score_thr=show_score_thr)
# encode mask results
if isinstance(result[0], tuple):
result = [(bbox_results, encode_mask_results(mask_results))
for bbox_results, mask_results in result]
results.extend(result)
for _ in range(batch_size):
prog_bar.update()
return results
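A note on the normalization handling in retrieve_img_tensor_and_meta above: fractional mean/std values are rescaled to pixel units so that tensor2imgs can denormalize tensors back to 0-255 images. A minimal standalone sketch of that convention (the ImageNet-style numbers are illustrative, not taken from this diff):

import numpy as np

# Hypothetical normalization config with fractional (0-1) values.
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

if max(img_norm_cfg['mean']) <= 1:
    # Rescale to pixel units, as the helper above does.
    img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
    img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]

# Denormalization (img * std + mean) now yields 0-255 intensities:
# a normalized pixel of 0 maps back to the per-channel mean.
restored = np.zeros(3) * np.array(img_norm_cfg['std']) + np.array(
    img_norm_cfg['mean'])
print(restored)  # ~[123.7, 116.3, 103.5]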


@@ -1,186 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner, get_dist_info)
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import build_dataloader
from mmocr import digit_version
from mmocr.apis.utils import (disable_text_recog_aug_test,
replace_image_to_tensor)
from mmocr.registry import DATASETS
from mmocr.utils import get_root_logger
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
timestamp=None,
meta=None):
logger = get_root_logger(cfg.log_level)
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
# step 1: give default values and override them from cfg.data (if present)
default_loader_cfg = {
**dict(
num_gpus=len(cfg.gpu_ids),
dist=distributed,
seed=cfg.get('seed'),
drop_last=False,
persistent_workers=False),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
)),
}
# update overall dataloader (for train, val and test) setting
default_loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
# step 2: cfg.data.train_dataloader has highest priority
train_loader_cfg = dict(default_loader_cfg,
**cfg.data.get('train_dataloader', {}))
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
# put model on gpus
if distributed:
find_unused_parameters = cfg.get('find_unused_parameters', False)
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
if not torch.cuda.is_available():
assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
'Please use MMCV >= 1.4.4 for CPU training!'
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
if 'runner' not in cfg:
cfg.runner = {
'type': 'EpochBasedRunner',
'max_epochs': cfg.total_epochs
}
warnings.warn(
'config is now expected to have a `runner` section, '
'please set `runner` in your config.', UserWarning)
else:
if 'total_epochs' in cfg:
assert cfg.total_epochs == cfg.runner.max_epochs
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
# an ugly workaround to make .log and .log.json filenames the same
runner.timestamp = timestamp
# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
optimizer_config = Fp16OptimizerHook(
**cfg.optimizer_config, **fp16_cfg, distributed=distributed)
elif distributed and 'type' not in cfg.optimizer_config:
optimizer_config = OptimizerHook(**cfg.optimizer_config)
else:
optimizer_config = cfg.optimizer_config
# register hooks
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))
if distributed:
if isinstance(runner, EpochBasedRunner):
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
if val_samples_per_gpu > 1:
# Support batch_size > 1 in test for text recognition
# by disabling MultiRotateAugOCR since it is useless in most cases
cfg = disable_text_recog_aug_test(cfg)
cfg = replace_image_to_tensor(cfg)
val_dataset = DATASETS.build(cfg.data.val, dict(test_mode=True))
val_loader_cfg = {
**default_loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('val_dataloader', {}),
**dict(samples_per_gpu=val_samples_per_gpu)
}
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow)
def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed. If the seed is None, it will be replaced by a
random number, and then broadcasted to all processes.
Args:
seed (int, Optional): The seed.
device (str): The device where the seed will be put on.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
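The broadcast at the end of init_random_seed is the standard way to keep all ranks sampling identically when no seed is given. A self-contained sketch of the same pattern (`sync_random_seed` is a hypothetical name; the distributed branch assumes an already-initialized process group):

import numpy as np
import torch
import torch.distributed as dist

def sync_random_seed(seed=None, device='cpu'):
    """Hypothetical re-implementation of the seed-sync pattern above."""
    if seed is not None:
        return seed  # an explicit seed is kept as-is
    seed = np.random.randint(2**31)
    if not (dist.is_available() and dist.is_initialized()):
        return seed  # single process: nothing to synchronize
    # Rank 0 owns the seed; every other rank receives it via broadcast,
    # so all processes end up with the same value.
    random_num = torch.tensor(
        seed if dist.get_rank() == 0 else 0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()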


@@ -18,7 +18,7 @@ from .polygon_utils import (crop_polygon, is_poly_outside_rect, poly2bbox,
poly2shapely, poly_intersection, poly_iou,
poly_make_valid, poly_union, polys2shapely,
rescale_polygon, rescale_polygons)
from .setup_env import setup_multi_processes
from .setup_env import register_all_modules
from .string_util import StringStrip
__all__ = [
@@ -27,9 +27,9 @@ __all__ = [
'valid_boundary', 'drop_orientation', 'convert_annotations', 'is_not_png',
'list_to_file', 'list_from_file', 'is_on_same_line',
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
'bezier_to_polygon', 'sort_points', 'setup_multi_processes', 'recog2lmdb',
'dump_ocr_data', 'recog_anno_to_imginfo', 'rescale_polygons',
'rescale_polygon', 'rescale_bboxes', 'bbox2poly', 'crop_polygon',
'is_poly_outside_rect', 'poly2bbox', 'poly_intersection', 'poly_iou',
'poly_make_valid', 'poly_union', 'poly2shapely', 'polys2shapely'
'bezier_to_polygon', 'sort_points', 'recog2lmdb', 'dump_ocr_data',
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_outside_rect',
'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules'
]


@@ -1,47 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import datetime
import warnings
import cv2
import torch.multiprocessing as mp
from mmengine import DefaultScope
def setup_multi_processes(cfg):
"""Setup multi-processing environment variables."""
# set multi-process start method as `fork` to speed up the training
if platform.system() != 'Windows':
mp_start_method = cfg.get('mp_start_method', 'fork')
current_method = mp.get_start_method(allow_none=True)
if current_method is not None and current_method != mp_start_method:
warnings.warn(
f'Multi-processing start method `{mp_start_method}` is '
f'different from the previous setting `{current_method}`. '
f'It will be forcibly set to `{mp_start_method}`. You can change '
f'this behavior by changing `mp_start_method` in your config.')
mp.set_start_method(mp_start_method, force=True)
def register_all_modules(init_default_scope: bool = True) -> None:
"""Register all modules in mmocr into the registries.
# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = cfg.get('opencv_num_threads', 0)
cv2.setNumThreads(opencv_num_threads)
Args:
init_default_scope (bool): Whether to initialize the mmocr default scope.
When `init_default_scope=True`, the global default scope will be
set to `mmocr`, and all registries will build modules from mmocr's
registry node. To understand more about the registry, please refer
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
Defaults to True.
""" # noqa
import mmocr.core # noqa: F401,F403
import mmocr.datasets # noqa: F401,F403
import mmocr.metrics # noqa: F401,F403
import mmocr.models # noqa: F401,F403
# setup OMP threads
# This code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
omp_num_threads = 1
warnings.warn(
f'Setting OMP_NUM_THREADS environment variable for each process '
f'to be {omp_num_threads} by default, to avoid your system being '
f'overloaded. Please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
# setup MKL threads
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
mkl_num_threads = 1
warnings.warn(
f'Setting MKL_NUM_THREADS environment variable for each process '
f'to be {mkl_num_threads} by default, to avoid your system being '
f'overloaded. Please further tune the variable for optimal '
f'performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
if init_default_scope:
never_created = DefaultScope.get_current_instance() is None \
or not DefaultScope.check_instance_created('mmocr')
if never_created:
DefaultScope.get_instance('mmocr', scope_name='mmocr')
return
current_scope = DefaultScope.get_current_instance()
if current_scope.scope_name != 'mmocr':
warnings.warn('The current default scope '
f'"{current_scope.scope_name}" is not "mmocr", '
'`register_all_modules` will force the current '
'default scope to be "mmocr". If this is not '
'expected, please set `init_default_scope=False`.')
# avoid name conflict
new_instance_name = f'mmocr-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name='mmocr')
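Both call styles of the new helper show up in this commit: the tools scripts pass init_default_scope=False and let the Runner claim the scope, while standalone use keeps the default of True. A short usage sketch (only names this diff already exports):

from mmocr.utils import register_all_modules

# Inside tools/ scripts: the Runner initializes the default scope itself.
register_all_modules(init_default_scope=False)

# In standalone code with no Runner: claim the default scope so that
# string-typed configs resolve against mmocr's registries.
register_all_modules(init_default_scope=True)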


@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import sys
from unittest import TestCase
from mmengine import DefaultScope
from mmocr.utils import register_all_modules
class TestSetupEnv(TestCase):
def test_register_all_modules(self):
from mmocr.registry import DATASETS
# not init default scope
sys.modules.pop('mmocr.datasets', None)
sys.modules.pop('mmocr.datasets.ocr_dataset', None)
DATASETS._module_dict.pop('OCRDataset', None)
self.assertFalse('OCRDataset' in DATASETS.module_dict)
register_all_modules(init_default_scope=False)
self.assertTrue('OCRDataset' in DATASETS.module_dict)
# init default scope
sys.modules.pop('mmocr.datasets')
sys.modules.pop('mmocr.datasets.ocr_dataset')
DATASETS._module_dict.pop('OCRDataset', None)
self.assertFalse('OCRDataset' in DATASETS.module_dict)
register_all_modules(init_default_scope=True)
self.assertTrue('OCRDataset' in DATASETS.module_dict)
self.assertEqual(DefaultScope.get_current_instance().scope_name,
'mmocr')
# init default scope when another scope has already been initialized
name = f'test-{datetime.datetime.now()}'
DefaultScope.get_instance(name, scope_name='test')
with self.assertWarnsRegex(
Warning, 'The current default scope "test" is not "mmocr"'):
register_all_modules(init_default_scope=True)


@@ -1,235 +1,73 @@
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
import os.path as osp
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmdet.apis import multi_gpu_test
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmocr.apis.test import single_gpu_test
from mmocr.apis.utils import (disable_text_recog_aug_test,
replace_image_to_tensor)
from mmocr.datasets import build_dataloader
from mmocr.models import build_detector
from mmocr.registry import DATASETS
from mmocr.utils import revert_sync_batchnorm, setup_multi_processes
from mmocr.utils import register_all_modules
# TODO: support fuse_conv_bn, visualization, and format_only
def parse_args():
parser = argparse.ArgumentParser(
description='MMOCR test (and eval) a model.')
parser.add_argument('config', help='Test config file path.')
parser.add_argument('checkpoint', help='Checkpoint file.')
parser.add_argument('--out', help='Output result file in pickle format.')
parser = argparse.ArgumentParser(description='Test (and eval) a model')
parser.add_argument('config', help='Test config file path')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--fuse-conv-bn',
action='store_true',
help='Whether to fuse conv and bn. This will slightly increase '
'the inference speed.')
parser.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--format-only',
action='store_true',
help='Format the output results without performing evaluation. It is '
'useful when you want to format the results to a specific format and '
'submit them to the test server.')
parser.add_argument(
'--eval',
type=str,
nargs='+',
help='The evaluation metrics. Options: \'hmean-ic13\', \'hmean-iou'
'\' for text detection tasks, \'acc\' for text recognition tasks, and '
'\'macro-f1\' for key information extraction tasks.')
parser.add_argument('--show', action='store_true', help='Show results.')
parser.add_argument(
'--show-dir', help='Directory where the output images will be saved.')
parser.add_argument(
'--show-score-thr',
type=float,
default=0.3,
help='Score threshold (default: 0.3).')
parser.add_argument(
'--gpu-collect',
action='store_true',
help='Whether to use gpu to collect results.')
parser.add_argument(
'--tmpdir',
help='The tmp directory used for collecting results from multiple '
'workers, available when gpu-collect is not specified.')
'--work-dir',
help='The directory to save the file containing evaluation metrics')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='Override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into the config file. If the value '
'to be overwritten is a list, it should be of the form of either '
'key="[a,b]" or key=a,b. The argument also allows nested list/tuple '
'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
'are necessary and that no white space is allowed.')
parser.add_argument(
'--options',
nargs='+',
action=DictAction,
help='Custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function (deprecated), '
'change to --eval-options instead.')
parser.add_argument(
'--eval-options',
nargs='+',
action=DictAction,
help='Custom options for evaluation, the key-value pair in xxx=yyy '
'format will be kwargs for dataset.evaluate() function.')
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='Options for job launcher.')
help='Job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if args.options and args.eval_options:
raise ValueError(
'--options and --eval-options cannot be both '
'specified, --options is deprecated in favor of --eval-options.')
if args.options:
warnings.warn('--options is deprecated in favor of --eval-options.')
args.eval_options = args.options
return args
def main():
args = parse_args()
assert (
args.out or args.eval or args.format_only or args.show
or args.show_dir), (
'Please specify at least one operation (save/eval/format/show the '
'results) with the argument "--out", "--eval", '
'"--format-only", "--show" or "--show-dir".')
if args.eval and args.format_only:
raise ValueError('--eval and --format_only cannot be both specified.')
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
# register all modules in mmocr into the registries
# do not init the default scope here because it will be initialized
# in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
setup_multi_processes(cfg)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
if cfg.model.get('pretrained'):
cfg.model.pretrained = None
if cfg.model.get('neck'):
if isinstance(cfg.model.neck, list):
for neck_cfg in cfg.model.neck:
if neck_cfg.get('rfp_backbone'):
if neck_cfg.rfp_backbone.get('pretrained'):
neck_cfg.rfp_backbone.pretrained = None
elif cfg.model.neck.get('rfp_backbone'):
if cfg.model.neck.rfp_backbone.get('pretrained'):
cfg.model.neck.rfp_backbone.pretrained = None
# work_dir is determined in this priority: CLI > config's `work_dir` > config filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# in case the test dataset is concatenated
samples_per_gpu = (cfg.data.get('test_dataloader', {})).get(
'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
if samples_per_gpu > 1:
cfg = disable_text_recog_aug_test(cfg)
cfg = replace_image_to_tensor(cfg)
cfg.load_from = args.checkpoint
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
cfg.gpu_ids = [args.gpu_id]
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# build the runner from config
runner = Runner.from_cfg(cfg)
# build the dataloader
dataset = DATASETS.build(cfg.data.test, dict(test_mode=True))
# step 1: give default values and override them from cfg.data (if present)
default_loader_cfg = {
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
))
}
default_loader_cfg.update({
k: v
for k, v in cfg.data.items() if k not in [
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
'test_dataloader'
]
})
test_loader_cfg = {
**default_loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('test_dataloader', {}),
**dict(samples_per_gpu=samples_per_gpu)
}
data_loader = build_dataloader(dataset, **test_loader_cfg)
# build the model and load checkpoint
cfg.model.train_cfg = None
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
model = revert_sync_batchnorm(model)
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
load_checkpoint(model, args.checkpoint, map_location='cpu')
if args.fuse_conv_bn:
model = fuse_conv_bn(model)
if not distributed:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
is_kie = cfg.model.type in ['SDMGR']
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
is_kie, args.show_score_thr)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)
rank, _ = get_dist_info()
if rank == 0:
if args.out:
print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out)
kwargs = {} if args.eval_options is None else args.eval_options
if args.format_only:
dataset.format_results(outputs, **kwargs)
if args.eval:
eval_kwargs = cfg.get('evaluation', {}).copy()
# hard-coded way to remove EvalHook args
for key in [
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
'rule'
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=args.eval, **kwargs))
print(dataset.evaluate(outputs, **eval_kwargs))
# start testing
runner.test()
if __name__ == '__main__':
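With the refactor, the whole test entrypoint reduces to the standard MMEngine flow shown above; a minimal sketch, assuming an MMEngine-style config (file paths are placeholders):

from mmengine.config import Config
from mmengine.runner import Runner
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=False)  # the Runner inits the scope

cfg = Config.fromfile('configs/textrecog/example.py')  # placeholder path
cfg.load_from = 'work_dirs/example/epoch_20.pth'  # placeholder checkpoint
cfg.launcher = 'none'
cfg.work_dir = 'work_dirs/example'

runner = Runner.from_cfg(cfg)
runner.test()  # runs the test loop and reports the configured metrics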


@@ -1,118 +1,54 @@
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmocr import __version__
from mmocr.apis import init_random_seed, train_detector
from mmocr.models import build_detector
from mmocr.registry import DATASETS
from mmocr.utils import (collect_env, get_root_logger, is_2dlist,
setup_multi_processes)
from mmocr.utils import register_all_modules
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector.')
parser.add_argument('config', help='Train config file path.')
parser.add_argument('--work-dir', help='The dir to save logs and models.')
parser.add_argument(
'--load-from', help='The checkpoint file to load from.')
parser.add_argument(
'--resume-from', help='The checkpoint file to resume from.')
parser.add_argument(
'--no-validate',
action='store_true',
help='Whether not to evaluate the checkpoint during training.')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='(Deprecated, please use --gpu-id) number of gpus to use '
'(only applicable to non-distributed training).')
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)')
group_gpus.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='Random seed.')
parser.add_argument(
'--diff-seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
help='Whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--options',
nargs='+',
action=DictAction,
help='Override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file (deprecated), '
'change to --cfg-options instead.')
parser = argparse.ArgumentParser(description='Train a model')
parser.add_argument('config', help='Train config file path')
parser.add_argument('--work-dir', help='The dir to save logs and models')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='Override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be of the form of either '
'key="[a,b]" or key=a,b .The argument also allows nested list/tuple '
'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
'are necessary and that no white space is allowed.')
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='Options for job launcher.')
help='Job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if args.options and args.cfg_options:
raise ValueError(
'--options and --cfg-options cannot be both '
'specified, --options is deprecated in favor of --cfg-options')
if args.options:
warnings.warn('--options is deprecated in favor of --cfg-options')
args.cfg_options = args.options
return args
def main():
args = parse_args()
# register all modules in mmocr into the registries
# do not init the default scope here because it will be initialized
# in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
setup_multi_processes(cfg)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > config's `work_dir` > config filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
@@ -121,109 +57,12 @@ def main():
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
if args.load_from is not None:
cfg.load_from = args.load_from
if args.resume_from is not None:
cfg.resume_from = args.resume_from
if args.gpus is not None:
cfg.gpu_ids = range(1)
warnings.warn('`--gpus` is deprecated because we only support '
'single GPU mode in non-distributed training. '
'Use `gpus=1` now.')
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids[0:1]
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
'Because we only support single GPU mode in '
'non-distributed training. Use the first GPU '
'in `gpu_ids` now.')
if args.gpus is None and args.gpu_ids is None:
cfg.gpu_ids = [args.gpu_id]
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# re-set gpu_ids with distributed training mode
_, world_size = get_dist_info()
cfg.gpu_ids = range(world_size)
# build the runner from config
runner = Runner.from_cfg(cfg)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
# log some basic info
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
seed = init_random_seed(args.seed)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)
model = build_detector(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg'))
model.init_weights()
datasets = [DATASETS.build(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
if cfg.data.train.get('pipeline', None) is None:
if is_2dlist(cfg.data.train.datasets):
train_pipeline = cfg.data.train.datasets[0][0].pipeline
else:
train_pipeline = cfg.data.train.datasets[0].pipeline
elif is_2dlist(cfg.data.train.pipeline):
train_pipeline = cfg.data.train.pipeline[0]
else:
train_pipeline = cfg.data.train.pipeline
if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']:
for dataset in val_dataset['datasets']:
dataset.pipeline = train_pipeline
else:
val_dataset.pipeline = train_pipeline
datasets.append(build_dataset(val_dataset))
if cfg.get('checkpoint_config', None) is not None:
# save mmdet version, config file content and class names in
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmocr_version=__version__ + get_git_hash()[:7],
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
# start training
runner.train()
if __name__ == '__main__':
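The training entrypoint mirrors the same pattern: resolve work_dir (CLI > config's `work_dir` > config filename), build the Runner, and call train(). A minimal sketch under the same assumptions as above (the config path is a placeholder):

import os.path as osp
from mmengine.config import Config
from mmengine.runner import Runner
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=False)

config_path = 'configs/textdet/example.py'  # placeholder path
cfg = Config.fromfile(config_path)
cfg.launcher = 'none'

# work_dir priority: CLI flag > `work_dir` in config > config filename
if cfg.get('work_dir') is None:
    cfg.work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(config_path))[0])

Runner.from_cfg(cfg).train()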