[Fix] Support customize runner and visualization in train/test.py, an… (#1279)

* [Fix] Support customize runner and visualization in train/test.py, and update configs missing from dataflow refactor

* Fix vis

* Apply suggestions from code review

Co-authored-by: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com>

* [Config] Refactor & fix DB, DBPP, DRRG configs (#1181)

* refactor base datasets, fix drrg config

* rename

* update dbnet and drrg

* fix

* fix

* Raise Error

Co-authored-by: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com>
pull/1299/head
Tong Gao 2022-08-22 10:48:50 +08:00 committed by GitHub
parent d73903a9a0
commit 7fcfa09431
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 76 additions and 3 deletions

View File

@ -6,6 +6,13 @@ default_hooks = dict(
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=20),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(
type='VisualizationHook',
interval=1,
enable=False,
show=False,
draw_gt=False,
draw_pred=False),
)
env_cfg = dict(

View File

@ -49,6 +49,11 @@ test_pipeline = [
scale_divisor=1,
ratio_range=(1.0, 1.0),
aspect_ratio_range=(1.0, 1.0)),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor',

View File

@ -49,6 +49,11 @@ test_pipeline = [
scale_divisor=1,
ratio_range=(1.0, 1.0),
aspect_ratio_range=(1.0, 1.0)),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor',

View File

@ -4,12 +4,12 @@ import os
import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmocr.utils import register_all_modules
# TODO: support fuse_conv_bn, visualization, and format_only
def parse_args():
parser = argparse.ArgumentParser(description='Test (and eval) a model')
parser.add_argument('config', help='Test config file path')
@ -21,6 +21,15 @@ def parse_args():
'--save-preds',
action='store_true',
help='Dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--show', action='store_true', help='Show prediction results')
parser.add_argument(
'--show-dir',
help='Directory where painted images will be saved. '
'If specified, it will be automatically saved '
'to the work_dir/timestamp/show_dir')
parser.add_argument(
'--wait-time', type=float, default=2, help='The interval of show (s)')
parser.add_argument(
'--cfg-options',
nargs='+',
@ -43,6 +52,29 @@ def parse_args():
return args
def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['enable'] = True
visualization_hook['draw_gt'] = True
visualization_hook['draw_pred'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
cfg.visualizer['save_dir'] = args.show_dir
cfg.visualizer['vis_backends'] = [dict(type='LocalVisBackend')]
else:
raise RuntimeError(
'VisualizationHook must be included in default_hooks.'
'refer to usage '
'"visualization=dict(type=\'VisualizationHook\')"')
return cfg
def main():
args = parse_args()
@ -67,6 +99,14 @@ def main():
cfg.load_from = args.checkpoint
# TODO: It will be supported after refactoring the visualizer
if args.show and args.show_dir:
raise NotImplementedError('--show and --show-dir cannot be set '
'at the same time')
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
# save predictions
if args.save_preds:
dump_metric = dict(
@ -81,7 +121,13 @@ def main():
cfg.test_evaluator = [cfg.test_evaluator, dump_metric]
# build the runner from config
runner = Runner.from_cfg(cfg)
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# start testing
runner.test()

View File

@ -6,6 +6,7 @@ import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmocr.utils import register_all_modules
@ -85,8 +86,10 @@ def main():
f'`OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
if args.resume:
cfg.resume = True
if args.auto_scale_lr:
if cfg.get('auto_scale_lr'):
cfg.auto_scale_lr = True
@ -96,8 +99,15 @@ def main():
'please set `auto_scale_lr = dict(base_batch_size=xx)',
logger='current',
level=logging.WARNING)
# build the runner from config
runner = Runner.from_cfg(cfg)
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# start training
runner.train()