diff --git a/demo/ner_demo.py b/demo/ner_demo.py
index 9301b65c..0fdcd86f 100755
--- a/demo/ner_demo.py
+++ b/demo/ner_demo.py
@@ -18,9 +18,6 @@ def main():

     # build the model from a config file and a checkpoint file
     model = init_detector(args.config, args.checkpoint, device=args.device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline

     # test a single text
     input_sentence = input('Please enter a sentence you want to test: ')
diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py
index b7b61f39..475c29c2 100644
--- a/demo/webcam_demo.py
+++ b/demo/webcam_demo.py
@@ -29,9 +29,6 @@ def main():
     device = torch.device(args.device)

     model = init_detector(args.config, args.checkpoint, device=device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline

     camera = cv2.VideoCapture(args.camera_id)

diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py
index ca2ae7c7..1a8d5eec 100644
--- a/mmocr/apis/inference.py
+++ b/mmocr/apis/inference.py
@@ -97,7 +97,10 @@ def model_inference(model,
     device = next(model.parameters()).device  # model device

     if cfg.data.test.get('pipeline', None) is None:
-        cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
+        if is_2dlist(cfg.data.test.datasets):
+            cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
+        else:
+            cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
         if is_2dlist(cfg.data.test.pipeline):
             cfg.data.test.pipeline = cfg.data.test.pipeline[0]

@@ -205,6 +208,13 @@ def text_model_inference(model, input_sentence):
     assert isinstance(input_sentence, str)

     cfg = model.cfg
+    if cfg.data.test.get('pipeline', None) is None:
+        if is_2dlist(cfg.data.test.datasets):
+            cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
+        else:
+            cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
+        if is_2dlist(cfg.data.test.pipeline):
+            cfg.data.test.pipeline = cfg.data.test.pipeline[0]
     test_pipeline = Compose(cfg.data.test.pipeline)

     data = {'text': input_sentence, 'label': {}}
diff --git a/mmocr/utils/ocr.py b/mmocr/utils/ocr.py
index d4bcebfb..09158fea 100755
--- a/mmocr/utils/ocr.py
+++ b/mmocr/utils/ocr.py
@@ -396,9 +396,6 @@ class MMOCR:
         for model in list(filter(None, [self.recog_model, self.detect_model])):
             if hasattr(model, 'module'):
                 model = model.module
-            if model.cfg.data.test['type'] == 'ConcatDataset':
-                model.cfg.data.test.pipeline = \
-                    model.cfg.data.test['datasets'][0].pipeline

     def readtext(self,
                  img,
diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py
index eb787fb8..62bdda6d 100644
--- a/tests/test_apis/test_model_inference.py
+++ b/tests/test_apis/test_model_inference.py
@@ -15,10 +15,6 @@ def build_model(config_file):
     model = init_detector(config_file, checkpoint=None, device=device)
     model = revert_sync_batchnorm(model)

-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline
-
     return model


diff --git a/tools/deployment/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py
index adca671e..8d49a796 100644
--- a/tools/deployment/onnx2tensorrt.py
+++ b/tools/deployment/onnx2tensorrt.py
@@ -16,6 +16,7 @@ from mmdet.datasets.pipelines import Compose
 from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                                    TensorRTDetector, TensorRTRecognizer)
 from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401
+from mmocr.utils import is_2dlist


 def get_GiB(x: int):
@@ -258,9 +259,15 @@ if __name__ == '__main__':
     }

     cfg = mmcv.Config.fromfile(args.model_config)
-    if cfg.data.test['type'] == 'ConcatDataset':
-        cfg.data.test.pipeline = \
-            cfg.data.test['datasets'][0].pipeline
+    if cfg.data.test.get('pipeline', None) is None:
+        if is_2dlist(cfg.data.test.datasets):
+            cfg.data.test.pipeline = \
+                cfg.data.test.datasets[0][0].pipeline
+        else:
+            cfg.data.test.pipeline = \
+                cfg.data.test['datasets'][0].pipeline
+        if is_2dlist(cfg.data.test.pipeline):
+            cfg.data.test.pipeline = cfg.data.test.pipeline[0]
     onnx2tensorrt(
         args.onnx_file,
         args.model_type,
diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py
index 503be5b3..954bbd9c 100644
--- a/tools/deployment/pytorch2onnx.py
+++ b/tools/deployment/pytorch2onnx.py
@@ -14,6 +14,7 @@ from torch import nn
 from mmocr.apis import init_detector
 from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
 from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401
+from mmocr.utils import is_2dlist


 def _convert_batchnorm(module):
@@ -327,9 +328,15 @@ def main():
     model = init_detector(args.model_config, args.model_ckpt, device=device)
     if hasattr(model, 'module'):
         model = model.module
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = \
-            model.cfg.data.test['datasets'][0].pipeline
+    if model.cfg.data.test.get('pipeline', None) is None:
+        if is_2dlist(model.cfg.data.test.datasets):
+            model.cfg.data.test.pipeline = \
+                model.cfg.data.test.datasets[0][0].pipeline
+        else:
+            model.cfg.data.test.pipeline = \
+                model.cfg.data.test['datasets'][0].pipeline
+        if is_2dlist(model.cfg.data.test.pipeline):
+            model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]

     pytorch2onnx(
         model,
diff --git a/tools/det_test_imgs.py b/tools/det_test_imgs.py
index c26d2b25..75ddf298 100755
--- a/tools/det_test_imgs.py
+++ b/tools/det_test_imgs.py
@@ -76,9 +76,6 @@ def main():
     model = init_detector(args.config, args.checkpoint, device=args.device)
     if hasattr(model, 'module'):
         model = model.module
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline

     # Start Inference
     out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
diff --git a/tools/recog_test_imgs.py b/tools/recog_test_imgs.py
index 5b72a901..6b6da088 100755
--- a/tools/recog_test_imgs.py
+++ b/tools/recog_test_imgs.py
@@ -60,9 +60,6 @@ def main():
     model = init_detector(args.config, args.checkpoint, device=args.device)
     if hasattr(model, 'module'):
         model = model.module
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline

     # Start Inference
     out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
diff --git a/tools/train.py b/tools/train.py
index 78eb1355..0f5085d7 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -17,7 +17,7 @@ from mmocr import __version__
 from mmocr.apis import init_random_seed, train_detector
 from mmocr.datasets import build_dataset
 from mmocr.models import build_detector
-from mmocr.utils import collect_env, get_root_logger
+from mmocr.utils import collect_env, get_root_logger, is_2dlist


 def parse_args():
@@ -72,11 +72,6 @@ def parse_args():
         default='none',
         help='Options for job launcher.')
     parser.add_argument('--local_rank', type=int, default=0)
-    parser.add_argument(
-        '--mc-config',
-        type=str,
-        default='',
-        help='Memory cache config for image loading speed-up during training.')
     args = parser.parse_args()

     if 'LOCAL_RANK' not in os.environ:
@@ -100,17 +95,6 @@ def main():
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)

-    # update mc config
-    if args.mc_config:
-        mc = Config.fromfile(args.mc_config)
-        if isinstance(cfg.data.train, list):
-            for i in range(len(cfg.data.train)):
-                cfg.data.train[i].pipeline[0].update(
-                    file_client_args=mc['mc_file_client_args'])
-        else:
-            cfg.data.train.pipeline[0].update(
-                file_client_args=mc['mc_file_client_args'])
-
     # set cudnn_benchmark
     if cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
@@ -184,12 +168,17 @@ def main():
     datasets = [build_dataset(cfg.data.train)]
     if len(cfg.workflow) == 2:
         val_dataset = copy.deepcopy(cfg.data.val)
-        if cfg.data.train['type'] == 'ConcatDataset':
-            train_pipeline = cfg.data.train['datasets'][0].pipeline
+        if cfg.data.train.get('pipeline', None) is None:
+            if is_2dlist(cfg.data.train.datasets):
+                train_pipeline = cfg.data.train.datasets[0][0].pipeline
+            else:
+                train_pipeline = cfg.data.train.datasets[0].pipeline
+        elif is_2dlist(cfg.data.train.pipeline):
+            train_pipeline = cfg.data.train.pipeline[0]
         else:
             train_pipeline = cfg.data.train.pipeline
-        if val_dataset['type'] == 'ConcatDataset':
+        if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']:
             for dataset in val_dataset['datasets']:
                 dataset.pipeline = train_pipeline
         else:
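
The patch repeats the same pipeline-resolution snippet in `model_inference`, `text_model_inference`, `onnx2tensorrt.py` and `pytorch2onnx.py` (and uses an `elif` variant in `tools/train.py`). As a minimal sketch only, not part of this diff, the shared logic could be read as the hypothetical helper below; `resolve_test_pipeline` is an illustrative name, the only real API used is `mmocr.utils.is_2dlist`.

```python
# Illustrative sketch, not introduced by this patch: the pipeline-resolution
# pattern the diff duplicates across inference and deployment scripts.
from mmocr.utils import is_2dlist


def resolve_test_pipeline(cfg):
    """Ensure cfg.data.test.pipeline is a flat (1d) pipeline list.

    When the test config has no top-level `pipeline` (e.g. it wraps several
    datasets, as UniformConcatDataset-style configs do), borrow it from the
    first dataset; `datasets` may be a plain list or a list of lists.
    """
    test = cfg.data.test
    if test.get('pipeline', None) is None:
        if is_2dlist(test.datasets):
            # Nested dataset lists: take the first dataset of the first group.
            test.pipeline = test.datasets[0][0].pipeline
        else:
            test.pipeline = test.datasets[0].pipeline
        # A nested (2d) pipeline is flattened to its first entry.
        if is_2dlist(test.pipeline):
            test.pipeline = test.pipeline[0]
    return cfg
```

Calling such a helper right after `init_detector` or `Config.fromfile` would replace each inline copy shown above; the diff keeps the logic inline in every script instead.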