mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] train and test
parent 4246b1eaee
commit fe43259a05

@@ -1,157 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-
-import mmcv
-import numpy as np
-import torch
-from mmcv.image import tensor2imgs
-from mmcv.parallel import DataContainer
-from mmdet.core import encode_mask_results
-
-from .utils import tensor2grayimgs
-
-
-def retrieve_img_tensor_and_meta(data):
-    """Retrieve img_tensor, img_metas and img_norm_cfg.
-
-    Args:
-        data (dict): One batch data from data_loader.
-
-    Returns:
-        tuple: Returns (img_tensor, img_metas, img_norm_cfg).
-
-            - | img_tensor (Tensor): Input image tensor with shape
-                :math:`(N, C, H, W)`.
-            - | img_metas (list[dict]): The metadata of images.
-            - | img_norm_cfg (dict): Config for image normalization.
-    """
-
-    if isinstance(data['img'], torch.Tensor):
-        # for textrecog with batch_size > 1
-        # and not use 'DefaultFormatBundle' in pipeline
-        img_tensor = data['img']
-        img_metas = data['img_metas'].data[0]
-    elif isinstance(data['img'], list):
-        if isinstance(data['img'][0], torch.Tensor):
-            # for textrecog with aug_test and batch_size = 1
-            img_tensor = data['img'][0]
-        elif isinstance(data['img'][0], DataContainer):
-            # for textdet with 'MultiScaleFlipAug'
-            # and 'DefaultFormatBundle' in pipeline
-            img_tensor = data['img'][0].data[0]
-        img_metas = data['img_metas'][0].data[0]
-    elif isinstance(data['img'], DataContainer):
-        # for textrecog with 'DefaultFormatBundle' in pipeline
-        img_tensor = data['img'].data[0]
-        img_metas = data['img_metas'].data[0]
-
-    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
-    for key in must_keys:
-        if key not in img_metas[0]:
-            raise KeyError(
-                f'Please add {key} to the "meta_keys" in the pipeline')
-
-    img_norm_cfg = img_metas[0]['img_norm_cfg']
-    if max(img_norm_cfg['mean']) <= 1:
-        img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
-        img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]
-
-    return img_tensor, img_metas, img_norm_cfg
-
-
-def single_gpu_test(model,
-                    data_loader,
-                    show=False,
-                    out_dir=None,
-                    is_kie=False,
-                    show_score_thr=0.3):
-    model.eval()
-    results = []
-    dataset = data_loader.dataset
-    prog_bar = mmcv.ProgressBar(len(dataset))
-    for data in data_loader:
-        with torch.no_grad():
-            result = model(return_loss=False, rescale=True, **data)
-
-        batch_size = len(result)
-        if show or out_dir:
-            if is_kie:
-                img_tensor = data['img'].data[0]
-                if img_tensor.shape[0] != 1:
-                    raise KeyError('Visualizing KIE outputs in batches is '
-                                   'currently not supported.')
-                gt_bboxes = data['gt_bboxes'].data[0]
-                img_metas = data['img_metas'].data[0]
-                must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
-                for key in must_keys:
-                    if key not in img_metas[0]:
-                        raise KeyError(
-                            f'Please add {key} to the "meta_keys" in config.')
-                # for no visual model
-                if np.prod(img_tensor.shape) == 0:
-                    imgs = []
-                    for img_meta in img_metas:
-                        try:
-                            img = mmcv.imread(img_meta['filename'])
-                        except Exception as e:
-                            print(f'Load image with error: {e}, '
-                                  'use empty image instead.')
-                            img = np.ones(
-                                img_meta['img_shape'], dtype=np.uint8)
-                        imgs.append(img)
-                else:
-                    imgs = tensor2imgs(img_tensor,
-                                       **img_metas[0]['img_norm_cfg'])
-                for i, img in enumerate(imgs):
-                    h, w, _ = img_metas[i]['img_shape']
-                    img_show = img[:h, :w, :]
-                    if out_dir:
-                        out_file = osp.join(out_dir,
-                                            img_metas[i]['ori_filename'])
-                    else:
-                        out_file = None
-
-                    model.module.show_result(
-                        img_show,
-                        result[i],
-                        gt_bboxes[i],
-                        show=show,
-                        out_file=out_file)
-            else:
-                img_tensor, img_metas, img_norm_cfg = \
-                    retrieve_img_tensor_and_meta(data)
-
-                if img_tensor.size(1) == 1:
-                    imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
-                else:
-                    imgs = tensor2imgs(img_tensor, **img_norm_cfg)
-                assert len(imgs) == len(img_metas)
-
-                for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
-                    img_shape, ori_shape = img_meta['img_shape'], img_meta[
-                        'ori_shape']
-                    img_show = img[:img_shape[0], :img_shape[1]]
-                    img_show = mmcv.imresize(img_show,
-                                             (ori_shape[1], ori_shape[0]))
-
-                    if out_dir:
-                        out_file = osp.join(out_dir, img_meta['ori_filename'])
-                    else:
-                        out_file = None
-
-                    model.module.show_result(
-                        img_show,
-                        result[j],
-                        show=show,
-                        out_file=out_file,
-                        score_thr=show_score_thr)
-
-        # encode mask results
-        if isinstance(result[0], tuple):
-            result = [(bbox_results, encode_mask_results(mask_results))
-                      for bbox_results, mask_results in result]
-        results.extend(result)
-
-        for _ in range(batch_size):
-            prog_bar.update()
-    return results
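
For reference, the deleted `single_gpu_test` above was driven from the old
`tools/test.py` kept at the end of this diff. A minimal sketch of the call,
assuming `model` and `data_loader` are already built ('vis/' is an example
path, not from the source):

    from mmcv.parallel import MMDataParallel

    model = MMDataParallel(model, device_ids=[0])
    # show/out_dir control visualization; is_kie selects the KIE drawing
    # path, which additionally expects gt_bboxes in each batch.
    outputs = single_gpu_test(
        model,
        data_loader,
        show=False,
        out_dir='vis/',
        is_kie=False,
        show_score_thr=0.3)
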
@@ -1,186 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-
-import mmcv
-import numpy as np
-import torch
-import torch.distributed as dist
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
-                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
-                         build_runner, get_dist_info)
-from mmdet.core import DistEvalHook, EvalHook
-from mmdet.datasets import build_dataloader
-
-from mmocr import digit_version
-from mmocr.apis.utils import (disable_text_recog_aug_test,
-                              replace_image_to_tensor)
-from mmocr.registry import DATASETS
-from mmocr.utils import get_root_logger
-
-
-def train_detector(model,
-                   dataset,
-                   cfg,
-                   distributed=False,
-                   validate=False,
-                   timestamp=None,
-                   meta=None):
-    logger = get_root_logger(cfg.log_level)
-
-    # prepare data loaders
-    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    # step 1: give default values and override (if exist) from cfg.data
-    default_loader_cfg = {
-        **dict(
-            num_gpus=len(cfg.gpu_ids),
-            dist=distributed,
-            seed=cfg.get('seed'),
-            drop_last=False,
-            persistent_workers=False),
-        **({} if torch.__version__ != 'parrots' else dict(
-            prefetch_num=2,
-            pin_memory=False,
-        )),
-    }
-    # update the overall dataloader (for train, val and test) settings
-    default_loader_cfg.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader'
-        ]
-    })
-
-    # step 2: cfg.data.train_dataloader has the highest priority
-    train_loader_cfg = dict(default_loader_cfg,
-                            **cfg.data.get('train_dataloader', {}))
-
-    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
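
A worked example of the two-step merge above, with hypothetical values
(`cfg_data` stands in for `cfg.data`): flat keys in `cfg.data` override the
defaults, and `cfg.data.train_dataloader` overrides both.

    cfg_data = dict(samples_per_gpu=8, workers_per_gpu=4,
                    train_dataloader=dict(samples_per_gpu=16))
    default_loader_cfg = dict(dist=False, seed=None, drop_last=False)
    # step 1: flat cfg.data keys override the defaults
    default_loader_cfg.update({
        k: v
        for k, v in cfg_data.items() if k not in [
            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
            'test_dataloader'
        ]
    })
    assert default_loader_cfg['samples_per_gpu'] == 8
    # step 2: the dataloader-specific dict wins over everything else
    train_loader_cfg = dict(default_loader_cfg,
                            **cfg_data.get('train_dataloader', {}))
    assert train_loader_cfg['samples_per_gpu'] == 16
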
-
-    # put model on gpus
-    if distributed:
-        find_unused_parameters = cfg.get('find_unused_parameters', False)
-        # Sets the `find_unused_parameters` parameter in
-        # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False,
-            find_unused_parameters=find_unused_parameters)
-    else:
-        if not torch.cuda.is_available():
-            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
-                'Please use MMCV >= 1.4.4 for CPU training!'
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
-
-    # build runner
-    optimizer = build_optimizer(model, cfg.optimizer)
-
-    if 'runner' not in cfg:
-        cfg.runner = {
-            'type': 'EpochBasedRunner',
-            'max_epochs': cfg.total_epochs
-        }
-        warnings.warn(
-            'config is now expected to have a `runner` section, '
-            'please set `runner` in your config.', UserWarning)
-    else:
-        if 'total_epochs' in cfg:
-            assert cfg.total_epochs == cfg.runner.max_epochs
-
-    runner = build_runner(
-        cfg.runner,
-        default_args=dict(
-            model=model,
-            optimizer=optimizer,
-            work_dir=cfg.work_dir,
-            logger=logger,
-            meta=meta))
-
-    # an ugly workaround to make .log and .log.json filenames the same
-    runner.timestamp = timestamp
-
-    # fp16 setting
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        optimizer_config = Fp16OptimizerHook(
-            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
-    elif distributed and 'type' not in cfg.optimizer_config:
-        optimizer_config = OptimizerHook(**cfg.optimizer_config)
-    else:
-        optimizer_config = cfg.optimizer_config
-
-    # register hooks
-    runner.register_training_hooks(
-        cfg.lr_config,
-        optimizer_config,
-        cfg.checkpoint_config,
-        cfg.log_config,
-        cfg.get('momentum_config', None),
-        custom_hooks_config=cfg.get('custom_hooks', None))
-    if distributed:
-        if isinstance(runner, EpochBasedRunner):
-            runner.register_hook(DistSamplerSeedHook())
-
-    # register eval hooks
-    if validate:
-        val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
-            'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
-        if val_samples_per_gpu > 1:
-            # Support batch_size > 1 in test for text recognition
-            # by disabling MultiRotateAugOCR, since it is useless in most cases
-            cfg = disable_text_recog_aug_test(cfg)
-            cfg = replace_image_to_tensor(cfg)
-
-        val_dataset = DATASETS.build(cfg.data.val, dict(test_mode=True))
-
-        val_loader_cfg = {
-            **default_loader_cfg,
-            **dict(shuffle=False, drop_last=False),
-            **cfg.data.get('val_dataloader', {}),
-            **dict(samples_per_gpu=val_samples_per_gpu)
-        }
-
-        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
-
-        eval_cfg = cfg.get('evaluation', {})
-        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
-        eval_hook = DistEvalHook if distributed else EvalHook
-        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
-
-    if cfg.resume_from:
-        runner.resume(cfg.resume_from)
-    elif cfg.load_from:
-        runner.load_checkpoint(cfg.load_from)
-    runner.run(data_loaders, cfg.workflow)
-
-
-def init_random_seed(seed=None, device='cuda'):
-    """Initialize random seed. If the seed is None, it will be replaced by a
-    random number, and then broadcasted to all processes.
-
-    Args:
-        seed (int, Optional): The seed.
-        device (str): The device where the seed will be put on.
-
-    Returns:
-        int: Seed to be used.
-    """
-    if seed is not None:
-        return seed
-
-    # Make sure all ranks share the same random seed to prevent
-    # some potential bugs. Please refer to
-    # https://github.com/open-mmlab/mmdetection/issues/6339
-    rank, world_size = get_dist_info()
-    seed = np.random.randint(2**31)
-    if world_size == 1:
-        return seed
-
-    if rank == 0:
-        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
-    else:
-        random_num = torch.tensor(0, dtype=torch.int32, device=device)
-    dist.broadcast(random_num, src=0)
-    return random_num.item()
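
`init_random_seed` was consumed by the old `tools/train.py` (at the end of
this diff) as follows; `args.seed`, `args.diff_seed` and `args.deterministic`
are that script's CLI flags:

    import torch.distributed as dist
    from mmcv.runner import set_random_seed

    seed = init_random_seed(args.seed)  # identical on every rank
    # optionally decorrelate random augmentations across ranks
    seed = seed + dist.get_rank() if args.diff_seed else seed
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed  # recorded so the run can be reproduced
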
@@ -18,7 +18,7 @@ from .polygon_utils import (crop_polygon, is_poly_outside_rect, poly2bbox,
                             poly2shapely, poly_intersection, poly_iou,
                             poly_make_valid, poly_union, polys2shapely,
                             rescale_polygon, rescale_polygons)
-from .setup_env import setup_multi_processes
+from .setup_env import register_all_modules
 from .string_util import StringStrip
 
 __all__ = [
@@ -27,9 +27,9 @@ __all__ = [
     'valid_boundary', 'drop_orientation', 'convert_annotations', 'is_not_png',
     'list_to_file', 'list_from_file', 'is_on_same_line',
     'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
-    'bezier_to_polygon', 'sort_points', 'setup_multi_processes', 'recog2lmdb',
-    'dump_ocr_data', 'recog_anno_to_imginfo', 'rescale_polygons',
-    'rescale_polygon', 'rescale_bboxes', 'bbox2poly', 'crop_polygon',
-    'is_poly_outside_rect', 'poly2bbox', 'poly_intersection', 'poly_iou',
-    'poly_make_valid', 'poly_union', 'poly2shapely', 'polys2shapely'
+    'bezier_to_polygon', 'sort_points', 'recog2lmdb', 'dump_ocr_data',
+    'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
+    'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_outside_rect',
+    'poly2bbox', 'poly_intersection', 'poly_iou', 'poly_make_valid',
+    'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules'
 ]
@@ -1,47 +1,39 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
-import platform
+import datetime
 import warnings
 
-import cv2
-import torch.multiprocessing as mp
+from mmengine import DefaultScope
 
 
-def setup_multi_processes(cfg):
-    """Setup multi-processing environment variables."""
-    # set multi-process start method as `fork` to speed up the training
-    if platform.system() != 'Windows':
-        mp_start_method = cfg.get('mp_start_method', 'fork')
-        current_method = mp.get_start_method(allow_none=True)
-        if current_method is not None and current_method != mp_start_method:
-            warnings.warn(
-                f'Multi-processing start method `{mp_start_method}` is '
-                f'different from the previous setting `{current_method}`. '
-                f'It will be force set to `{mp_start_method}`. You can change '
-                f'this behavior by changing `mp_start_method` in your config.')
-        mp.set_start_method(mp_start_method, force=True)
-
-    # disable opencv multithreading to avoid system being overloaded
-    opencv_num_threads = cfg.get('opencv_num_threads', 0)
-    cv2.setNumThreads(opencv_num_threads)
-
-    # setup OMP threads
-    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
-    if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
-        omp_num_threads = 1
-        warnings.warn(
-            f'Setting OMP_NUM_THREADS environment variable for each process '
-            f'to be {omp_num_threads} in default, to avoid your system being '
-            f'overloaded, please further tune the variable for optimal '
-            f'performance in your application as needed.')
-        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
-
-    # setup MKL threads
-    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
-        mkl_num_threads = 1
-        warnings.warn(
-            f'Setting MKL_NUM_THREADS environment variable for each process '
-            f'to be {mkl_num_threads} in default, to avoid your system being '
-            f'overloaded, please further tune the variable for optimal '
-            f'performance in your application as needed.')
-        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+def register_all_modules(init_default_scope: bool = True) -> None:
+    """Register all modules in mmocr into the registries.
+
+    Args:
+        init_default_scope (bool): Whether to initialize the mmocr default
+            scope. When `init_default_scope=True`, the global default scope
+            will be set to `mmocr`, and all registries will build modules from
+            mmocr's registry node. To understand more about the registry,
+            please refer to
+            https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
+            Defaults to True.
+    """  # noqa
+    import mmocr.core  # noqa: F401,F403
+    import mmocr.datasets  # noqa: F401,F403
+    import mmocr.metrics  # noqa: F401,F403
+    import mmocr.models  # noqa: F401,F403
+
+    if init_default_scope:
+        never_created = DefaultScope.get_current_instance() is None \
+            or not DefaultScope.check_instance_created('mmocr')
+        if never_created:
+            DefaultScope.get_instance('mmocr', scope_name='mmocr')
+            return
+        current_scope = DefaultScope.get_current_instance()
+        if current_scope.scope_name != 'mmocr':
+            warnings.warn('The current default scope '
+                          f'"{current_scope.scope_name}" is not "mmocr", '
+                          '`register_all_modules` will force the current '
+                          'default scope to be "mmocr". If this is not '
+                          'expected, please set `init_default_scope=False`.')
+            # avoid name conflict
+            new_instance_name = f'mmocr-{datetime.datetime.now()}'
+            DefaultScope.get_instance(new_instance_name, scope_name='mmocr')
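
The refactored entry points below call this with `init_default_scope=False`,
because the MMEngine Runner initializes the scope itself. A minimal sketch of
standalone usage, e.g. in a notebook or a custom script:

    from mmocr.registry import DATASETS
    from mmocr.utils import register_all_modules

    register_all_modules()  # init_default_scope=True by default
    dataset_cls = DATASETS.get('OCRDataset')  # now resolvable by name
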
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import sys
+from unittest import TestCase
+
+from mmengine import DefaultScope
+
+from mmocr.utils import register_all_modules
+
+
+class TestSetupEnv(TestCase):
+
+    def test_register_all_modules(self):
+        from mmocr.registry import DATASETS
+
+        # not init default scope
+        sys.modules.pop('mmocr.datasets', None)
+        sys.modules.pop('mmocr.datasets.ocr_dataset', None)
+        DATASETS._module_dict.pop('OCRDataset', None)
+        self.assertFalse('OCRDataset' in DATASETS.module_dict)
+        register_all_modules(init_default_scope=False)
+        self.assertTrue('OCRDataset' in DATASETS.module_dict)
+
+        # init default scope
+        sys.modules.pop('mmocr.datasets')
+        sys.modules.pop('mmocr.datasets.ocr_dataset')
+        DATASETS._module_dict.pop('OCRDataset', None)
+        self.assertFalse('OCRDataset' in DATASETS.module_dict)
+        register_all_modules(init_default_scope=True)
+        self.assertTrue('OCRDataset' in DATASETS.module_dict)
+        self.assertEqual(DefaultScope.get_current_instance().scope_name,
+                         'mmocr')
+
+        # init default scope when another scope is init
+        name = f'test-{datetime.datetime.now()}'
+        DefaultScope.get_instance(name, scope_name='test')
+        with self.assertWarnsRegex(
+                Warning, 'The current default scope "test" is not "mmocr"'):
+            register_all_modules(init_default_scope=True)
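
To run the new test in isolation (the file path below is an assumption; the
diff does not name the test file):

    # pytest tests/test_utils/test_setup_env.py -k test_register_all_modules
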
tools/test.py
@@ -1,235 +1,73 @@
 #!/usr/bin/env python
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import os
-import warnings
+import os.path as osp
 
-import mmcv
-import torch
-from mmcv import Config, DictAction
-from mmcv.cnn import fuse_conv_bn
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
-                         wrap_fp16_model)
-from mmdet.apis import multi_gpu_test
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
 
-from mmocr.apis.test import single_gpu_test
-from mmocr.apis.utils import (disable_text_recog_aug_test,
-                              replace_image_to_tensor)
-from mmocr.datasets import build_dataloader
-from mmocr.models import build_detector
-from mmocr.registry import DATASETS
-from mmocr.utils import revert_sync_batchnorm, setup_multi_processes
+from mmocr.utils import register_all_modules
 
 
+# TODO: support fuse_conv_bn, visualization, and format_only
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description='MMOCR test (and eval) a model.')
-    parser.add_argument('config', help='Test config file path.')
-    parser.add_argument('checkpoint', help='Checkpoint file.')
-    parser.add_argument('--out', help='Output result file in pickle format.')
+    parser = argparse.ArgumentParser(description='Test (and eval) a model')
+    parser.add_argument('config', help='Test config file path')
+    parser.add_argument('checkpoint', help='Checkpoint file')
-    parser.add_argument(
-        '--fuse-conv-bn',
-        action='store_true',
-        help='Whether to fuse conv and bn; this will slightly increase '
-        'the inference speed.')
-    parser.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='id of gpu to use '
-        '(only applicable to non-distributed testing)')
-    parser.add_argument(
-        '--format-only',
-        action='store_true',
-        help='Format the output results without performing evaluation. It is '
-        'useful when you want to format the results to a specific format and '
-        'submit them to the test server.')
-    parser.add_argument(
-        '--eval',
-        type=str,
-        nargs='+',
-        help='The evaluation metrics. Options: \'hmean-ic13\', \'hmean-iou'
-        '\' for text detection tasks, \'acc\' for text recognition tasks, and '
-        '\'macro-f1\' for key information extraction tasks.')
-    parser.add_argument('--show', action='store_true', help='Show results.')
-    parser.add_argument(
-        '--show-dir', help='Directory where the output images will be saved.')
-    parser.add_argument(
-        '--show-score-thr',
-        type=float,
-        default=0.3,
-        help='Score threshold (default: 0.3).')
-    parser.add_argument(
-        '--gpu-collect',
-        action='store_true',
-        help='Whether to use gpu to collect results.')
     parser.add_argument(
-        '--tmpdir',
-        help='The tmp directory used for collecting results from multiple '
-        'workers, available when gpu-collect is not specified.')
+        '--work-dir',
+        help='The directory to save the file containing evaluation metrics')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
         action=DictAction,
         help='Override some settings in the used config, the key-value pair '
-        'in xxx=yyy format will be merged into the config file. If the value '
-        'to be overwritten is a list, it should be of the form of either '
-        'key="[a,b]" or key=a,b. The argument also allows nested list/tuple '
-        'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
-        'are necessary and that no white space is allowed.')
-    parser.add_argument(
-        '--options',
-        nargs='+',
-        action=DictAction,
-        help='Custom options for evaluation, the key-value pair in xxx=yyy '
-        'format will be kwargs for dataset.evaluate() function (deprecated), '
-        'change to --eval-options instead.')
-    parser.add_argument(
-        '--eval-options',
-        nargs='+',
-        action=DictAction,
-        help='Custom options for evaluation, the key-value pair in xxx=yyy '
-        'format will be kwargs for dataset.evaluate() function.')
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
-        help='Options for job launcher.')
+        help='Job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
-
-    if args.options and args.eval_options:
-        raise ValueError(
-            '--options and --eval-options cannot be both '
-            'specified, --options is deprecated in favor of --eval-options.')
-    if args.options:
-        warnings.warn('--options is deprecated in favor of --eval-options.')
-        args.eval_options = args.options
     return args
 
 
 def main():
     args = parse_args()
 
-    assert (
-        args.out or args.eval or args.format_only or args.show
-        or args.show_dir), (
-            'Please specify at least one operation (save/eval/format/show the '
-            'results / save the results) with the argument "--out", "--eval"'
-            ', "--format-only", "--show" or "--show-dir".')
-
-    if args.eval and args.format_only:
-        raise ValueError('--eval and --format_only cannot be both specified.')
-
-    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
-        raise ValueError('The output file must be a pkl file.')
+    # register all modules in mmocr into the registries
+    # do not init the default scope here because it will be init in the runner
+    register_all_modules(init_default_scope=False)
 
+    # load config
     cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
-    setup_multi_processes(cfg)
 
-    # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
-        torch.backends.cudnn.benchmark = True
-    if cfg.model.get('pretrained'):
-        cfg.model.pretrained = None
-    if cfg.model.get('neck'):
-        if isinstance(cfg.model.neck, list):
-            for neck_cfg in cfg.model.neck:
-                if neck_cfg.get('rfp_backbone'):
-                    if neck_cfg.rfp_backbone.get('pretrained'):
-                        neck_cfg.rfp_backbone.pretrained = None
-        elif cfg.model.neck.get('rfp_backbone'):
-            if cfg.model.neck.rfp_backbone.get('pretrained'):
-                cfg.model.neck.rfp_backbone.pretrained = None
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
 
-    # in case the test dataset is concatenated
-    samples_per_gpu = (cfg.data.get('test_dataloader', {})).get(
-        'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
-    if samples_per_gpu > 1:
-        cfg = disable_text_recog_aug_test(cfg)
-        cfg = replace_image_to_tensor(cfg)
+    cfg.load_from = args.checkpoint
 
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        cfg.gpu_ids = [args.gpu_id]
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
 
-    # build the dataloader
-    dataset = DATASETS.build(cfg.data.test, dict(test_mode=True))
-    # step 1: give default values and override (if exist) from cfg.data
-    default_loader_cfg = {
-        **dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
-        **({} if torch.__version__ != 'parrots' else dict(
-            prefetch_num=2,
-            pin_memory=False,
-        ))
-    }
-    default_loader_cfg.update({
-        k: v
-        for k, v in cfg.data.items() if k not in [
-            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
-            'test_dataloader'
-        ]
-    })
-    test_loader_cfg = {
-        **default_loader_cfg,
-        **dict(shuffle=False, drop_last=False),
-        **cfg.data.get('test_dataloader', {}),
-        **dict(samples_per_gpu=samples_per_gpu)
-    }
-
-    data_loader = build_dataloader(dataset, **test_loader_cfg)
-
-    # build the model and load checkpoint
-    cfg.model.train_cfg = None
-    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
-    model = revert_sync_batchnorm(model)
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        wrap_fp16_model(model)
-    load_checkpoint(model, args.checkpoint, map_location='cpu')
-    if args.fuse_conv_bn:
-        model = fuse_conv_bn(model)
-
-    if not distributed:
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
-        is_kie = cfg.model.type in ['SDMGR']
-        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
-                                  is_kie, args.show_score_thr)
-    else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
-        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
-                                 args.gpu_collect)
-
-    rank, _ = get_dist_info()
-    if rank == 0:
-        if args.out:
-            print(f'\nwriting results to {args.out}')
-            mmcv.dump(outputs, args.out)
-        kwargs = {} if args.eval_options is None else args.eval_options
-        if args.format_only:
-            dataset.format_results(outputs, **kwargs)
-        if args.eval:
-            eval_kwargs = cfg.get('evaluation', {}).copy()
-            # hard-code way to remove EvalHook args
-            for key in [
-                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
-                    'rule'
-            ]:
-                eval_kwargs.pop(key, None)
-            eval_kwargs.update(dict(metric=args.eval, **kwargs))
-            print(dataset.evaluate(outputs, **eval_kwargs))
+    # start testing
+    runner.test()
 
 
 if __name__ == '__main__':
     main()
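
With the Runner-based flow, the test dataloader, evaluator and checkpoint
(`cfg.load_from`) all come from the config, so an invocation needs far fewer
flags than before (paths below are illustrative):

    # python tools/test.py path/to/config.py path/to/checkpoint.pth \
    #     --work-dir work_dirs/eval
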
tools/train.py
@@ -1,118 +1,54 @@
 #!/usr/bin/env python
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
-import copy
 import os
 import os.path as osp
-import time
-import warnings
 
-import mmcv
 import torch
-import torch.distributed as dist
-from mmcv import Config, DictAction
-from mmcv.runner import get_dist_info, init_dist, set_random_seed
-from mmcv.utils import get_git_hash
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
 
-from mmocr import __version__
-from mmocr.apis import init_random_seed, train_detector
-from mmocr.models import build_detector
-from mmocr.registry import DATASETS
-from mmocr.utils import (collect_env, get_root_logger, is_2dlist,
-                         setup_multi_processes)
+from mmocr.utils import register_all_modules
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Train a detector.')
-    parser.add_argument('config', help='Train config file path.')
-    parser.add_argument('--work-dir', help='The dir to save logs and models.')
-    parser.add_argument(
-        '--load-from', help='The checkpoint file to load from.')
-    parser.add_argument(
-        '--resume-from', help='The checkpoint file to resume from.')
-    parser.add_argument(
-        '--no-validate',
-        action='store_true',
-        help='Whether not to evaluate the checkpoint during training.')
-    group_gpus = parser.add_mutually_exclusive_group()
-    group_gpus.add_argument(
-        '--gpus',
-        type=int,
-        help='(Deprecated, please use --gpu-id) number of gpus to use '
-        '(only applicable to non-distributed training).')
-    group_gpus.add_argument(
-        '--gpu-ids',
-        type=int,
-        nargs='+',
-        help='(Deprecated, please use --gpu-id) ids of gpus to use '
-        '(only applicable to non-distributed training)')
-    group_gpus.add_argument(
-        '--gpu-id',
-        type=int,
-        default=0,
-        help='id of gpu to use '
-        '(only applicable to non-distributed training)')
-    parser.add_argument('--seed', type=int, default=None, help='Random seed.')
-    parser.add_argument(
-        '--diff-seed',
-        action='store_true',
-        help='Whether or not to set different seeds for different ranks')
-    parser.add_argument(
-        '--deterministic',
-        action='store_true',
-        help='Whether to set deterministic options for CUDNN backend.')
-    parser.add_argument(
-        '--options',
-        nargs='+',
-        action=DictAction,
-        help='Override some settings in the used config, the key-value pair '
-        'in xxx=yyy format will be merged into config file (deprecated), '
-        'change to --cfg-options instead.')
+    parser = argparse.ArgumentParser(description='Train a model')
+    parser.add_argument('config', help='Train config file path')
+    parser.add_argument('--work-dir', help='The dir to save logs and models')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
         action=DictAction,
         help='Override some settings in the used config, the key-value pair '
         'in xxx=yyy format will be merged into config file. If the value to '
-        'be overwritten is a list, it should be of the form of either '
-        'key="[a,b]" or key=a,b. The argument also allows nested list/tuple '
-        'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
-        'are necessary and that no white space is allowed.')
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     parser.add_argument(
         '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
-        help='Options for job launcher.')
+        help='Job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
 
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
 
-    if args.options and args.cfg_options:
-        raise ValueError(
-            '--options and --cfg-options cannot be both '
-            'specified, --options is deprecated in favor of --cfg-options')
-    if args.options:
-        warnings.warn('--options is deprecated in favor of --cfg-options')
-        args.cfg_options = args.options
-
     return args
 
 
 def main():
     args = parse_args()
 
+    # register all modules in mmocr into the registries
+    # do not init the default scope here because it will be init in the runner
+    register_all_modules(init_default_scope=False)
+
+    # load config
     cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    setup_multi_processes(cfg)
-
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
 
     # work_dir is determined in this priority: CLI > segment in file > filename
     if args.work_dir is not None:
         # update configs according to CLI args if args.work_dir is not None
@@ -121,109 +57,12 @@ def main():
         # use config filename as default work_dir if cfg.work_dir is None
         cfg.work_dir = osp.join('./work_dirs',
                                 osp.splitext(osp.basename(args.config))[0])
-    if args.load_from is not None:
-        cfg.load_from = args.load_from
-    if args.resume_from is not None:
-        cfg.resume_from = args.resume_from
-    if args.gpus is not None:
-        cfg.gpu_ids = range(1)
-        warnings.warn('`--gpus` is deprecated because we only support '
-                      'single GPU mode in non-distributed training. '
-                      'Use `gpus=1` now.')
-    if args.gpu_ids is not None:
-        cfg.gpu_ids = args.gpu_ids[0:1]
-        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
-                      'Because we only support single GPU mode in '
-                      'non-distributed training. Use the first GPU '
-                      'in `gpu_ids` now.')
-    if args.gpus is None and args.gpu_ids is None:
-        cfg.gpu_ids = [args.gpu_id]
-
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
-        # re-set gpu_ids with distributed training mode
-        _, world_size = get_dist_info()
-        cfg.gpu_ids = range(world_size)
+    # build the runner from config
+    runner = Runner.from_cfg(cfg)
 
-    # create work_dir
-    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
-    # dump config
-    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
-    # init the logger before other steps
-    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
-    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
-    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
-
-    # init the meta dict to record some important information such as
-    # environment info and seed, which will be logged
-    meta = dict()
-    # log env info
-    env_info_dict = collect_env()
-    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
-    dash_line = '-' * 60 + '\n'
-    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
-                dash_line)
-    meta['env_info'] = env_info
-    meta['config'] = cfg.pretty_text
-    # log some basic info
-    logger.info(f'Distributed training: {distributed}')
-    logger.info(f'Config:\n{cfg.pretty_text}')
-
-    # set random seeds
-    seed = init_random_seed(args.seed)
-    seed = seed + dist.get_rank() if args.diff_seed else seed
-    logger.info(f'Set random seed to {seed}, '
-                f'deterministic: {args.deterministic}')
-    set_random_seed(seed, deterministic=args.deterministic)
-    cfg.seed = seed
-    meta['seed'] = seed
-    meta['exp_name'] = osp.basename(args.config)
-
-    model = build_detector(
-        cfg.model,
-        train_cfg=cfg.get('train_cfg'),
-        test_cfg=cfg.get('test_cfg'))
-    model.init_weights()
-
-    datasets = [DATASETS.build(cfg.data.train)]
-    if len(cfg.workflow) == 2:
-        val_dataset = copy.deepcopy(cfg.data.val)
-        if cfg.data.train.get('pipeline', None) is None:
-            if is_2dlist(cfg.data.train.datasets):
-                train_pipeline = cfg.data.train.datasets[0][0].pipeline
-            else:
-                train_pipeline = cfg.data.train.datasets[0].pipeline
-        elif is_2dlist(cfg.data.train.pipeline):
-            train_pipeline = cfg.data.train.pipeline[0]
-        else:
-            train_pipeline = cfg.data.train.pipeline
-
-        if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']:
-            for dataset in val_dataset['datasets']:
-                dataset.pipeline = train_pipeline
-        else:
-            val_dataset.pipeline = train_pipeline
-        datasets.append(build_dataset(val_dataset))
-    if cfg.get('checkpoint_config', None) is not None:
-        # save mmdet version, config file content and class names in
-        # checkpoints as meta data
-        cfg.checkpoint_config.meta = dict(
-            mmocr_version=__version__ + get_git_hash()[:7],
-            CLASSES=datasets[0].CLASSES)
-    # add an attribute for visualization convenience
-    model.CLASSES = datasets[0].CLASSES
-    train_detector(
-        model,
-        datasets,
-        cfg,
-        distributed=distributed,
-        validate=(not args.no_validate),
-        timestamp=timestamp,
-        meta=meta)
+    # start training
+    runner.train()
 
 
 if __name__ == '__main__':
     main()
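
Training follows the same pattern: the manual wiring that train_detector used
to do (optimizer, hooks, seeding, distributed init) now comes from the config
consumed by Runner.from_cfg. Typical launches (paths are illustrative):

    # single GPU
    # python tools/train.py path/to/config.py --work-dir work_dirs/my_exp
    #
    # distributed, 4 GPUs
    # torchrun --nproc_per_node=4 tools/train.py path/to/config.py \
    #     --launcher pytorch
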