[Fix] Support customize runner and visualization in train/test.py, an… ()

* [Fix] Support customize runner and visualization in train/test.py, and update configs missing from dataflow refactor

* Fix vis

* Apply suggestions from code review

Co-authored-by: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com>

* [Config] Refactor & fix DB, DBPP, DRRG configs ()

* refactor base datasets, fix drrg config

* rename

* update dbnet and drrg

* fix

* fix

* Raise Error

Co-authored-by: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com>
pull/1299/head
Tong Gao 2022-08-22 10:48:50 +08:00 committed by GitHub
parent d73903a9a0
commit 7fcfa09431
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 76 additions and 3 deletions

View File

@ -6,6 +6,13 @@ default_hooks = dict(
param_scheduler=dict(type='ParamSchedulerHook'), param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=20), checkpoint=dict(type='CheckpointHook', interval=20),
sampler_seed=dict(type='DistSamplerSeedHook'), sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(
type='VisualizationHook',
interval=1,
enable=False,
show=False,
draw_gt=False,
draw_pred=False),
) )
env_cfg = dict( env_cfg = dict(

View File

@ -49,6 +49,11 @@ test_pipeline = [
scale_divisor=1, scale_divisor=1,
ratio_range=(1.0, 1.0), ratio_range=(1.0, 1.0),
aspect_ratio_range=(1.0, 1.0)), aspect_ratio_range=(1.0, 1.0)),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True),
dict( dict(
type='PackTextDetInputs', type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor',

View File

@ -49,6 +49,11 @@ test_pipeline = [
scale_divisor=1, scale_divisor=1,
ratio_range=(1.0, 1.0), ratio_range=(1.0, 1.0),
aspect_ratio_range=(1.0, 1.0)), aspect_ratio_range=(1.0, 1.0)),
dict(
type='LoadOCRAnnotations',
with_polygon=True,
with_bbox=True,
with_label=True),
dict( dict(
type='PackTextDetInputs', type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor', meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor',

View File

@ -4,12 +4,12 @@ import os
import os.path as osp import os.path as osp
from mmengine.config import Config, DictAction from mmengine.config import Config, DictAction
from mmengine.registry import RUNNERS
from mmengine.runner import Runner from mmengine.runner import Runner
from mmocr.utils import register_all_modules from mmocr.utils import register_all_modules
# TODO: support fuse_conv_bn, visualization, and format_only
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Test (and eval) a model') parser = argparse.ArgumentParser(description='Test (and eval) a model')
parser.add_argument('config', help='Test config file path') parser.add_argument('config', help='Test config file path')
@ -21,6 +21,15 @@ def parse_args():
'--save-preds', '--save-preds',
action='store_true', action='store_true',
help='Dump predictions to a pickle file for offline evaluation') help='Dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--show', action='store_true', help='Show prediction results')
parser.add_argument(
'--show-dir',
help='Directory where painted images will be saved. '
'If specified, it will be automatically saved '
'to the work_dir/timestamp/show_dir')
parser.add_argument(
'--wait-time', type=float, default=2, help='The interval of show (s)')
parser.add_argument( parser.add_argument(
'--cfg-options', '--cfg-options',
nargs='+', nargs='+',
@ -43,6 +52,29 @@ def parse_args():
return args return args
def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['enable'] = True
visualization_hook['draw_gt'] = True
visualization_hook['draw_pred'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
cfg.visualizer['save_dir'] = args.show_dir
cfg.visualizer['vis_backends'] = [dict(type='LocalVisBackend')]
else:
raise RuntimeError(
'VisualizationHook must be included in default_hooks.'
'refer to usage '
'"visualization=dict(type=\'VisualizationHook\')"')
return cfg
def main(): def main():
args = parse_args() args = parse_args()
@ -67,6 +99,14 @@ def main():
cfg.load_from = args.checkpoint cfg.load_from = args.checkpoint
# TODO: It will be supported after refactoring the visualizer
if args.show and args.show_dir:
raise NotImplementedError('--show and --show-dir cannot be set '
'at the same time')
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
# save predictions # save predictions
if args.save_preds: if args.save_preds:
dump_metric = dict( dump_metric = dict(
@ -81,7 +121,13 @@ def main():
cfg.test_evaluator = [cfg.test_evaluator, dump_metric] cfg.test_evaluator = [cfg.test_evaluator, dump_metric]
# build the runner from config # build the runner from config
runner = Runner.from_cfg(cfg) if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# start testing # start testing
runner.test() runner.test()

View File

@ -6,6 +6,7 @@ import os.path as osp
from mmengine.config import Config, DictAction from mmengine.config import Config, DictAction
from mmengine.logging import print_log from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner from mmengine.runner import Runner
from mmocr.utils import register_all_modules from mmocr.utils import register_all_modules
@ -85,8 +86,10 @@ def main():
f'`OptimWrapper` but got {optim_wrapper}.') f'`OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic' cfg.optim_wrapper.loss_scale = 'dynamic'
if args.resume: if args.resume:
cfg.resume = True cfg.resume = True
if args.auto_scale_lr: if args.auto_scale_lr:
if cfg.get('auto_scale_lr'): if cfg.get('auto_scale_lr'):
cfg.auto_scale_lr = True cfg.auto_scale_lr = True
@ -96,8 +99,15 @@ def main():
'please set `auto_scale_lr = dict(base_batch_size=xx)', 'please set `auto_scale_lr = dict(base_batch_size=xx)',
logger='current', logger='current',
level=logging.WARNING) level=logging.WARNING)
# build the runner from config # build the runner from config
runner = Runner.from_cfg(cfg) if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# start training # start training
runner.train() runner.train()