[Fix] Support both ConcatDataset and UniformConcatDataset (#675)

* support UniformConcatDataset

* update

* rm useless

* handle 2d-list datasets
pull/690/head
Hongbin Sun 2021-12-22 20:32:02 +08:00 committed by GitHub
parent abbae7bfd1
commit f8dfbd4177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 40 additions and 46 deletions

View File

@ -18,9 +18,6 @@ def main():
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device) model = init_detector(args.config, args.checkpoint, device=args.device)
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
# test a single text # test a single text
input_sentence = input('Please enter a sentence you want to test: ') input_sentence = input('Please enter a sentence you want to test: ')

View File

@ -29,9 +29,6 @@ def main():
device = torch.device(args.device) device = torch.device(args.device)
model = init_detector(args.config, args.checkpoint, device=device) model = init_detector(args.config, args.checkpoint, device=device)
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
camera = cv2.VideoCapture(args.camera_id) camera = cv2.VideoCapture(args.camera_id)

View File

@ -97,7 +97,10 @@ def model_inference(model,
device = next(model.parameters()).device # model device device = next(model.parameters()).device # model device
if cfg.data.test.get('pipeline', None) is None: if cfg.data.test.get('pipeline', None) is None:
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline if is_2dlist(cfg.data.test.datasets):
cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
else:
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
if is_2dlist(cfg.data.test.pipeline): if is_2dlist(cfg.data.test.pipeline):
cfg.data.test.pipeline = cfg.data.test.pipeline[0] cfg.data.test.pipeline = cfg.data.test.pipeline[0]
@ -205,6 +208,13 @@ def text_model_inference(model, input_sentence):
assert isinstance(input_sentence, str) assert isinstance(input_sentence, str)
cfg = model.cfg cfg = model.cfg
if cfg.data.test.get('pipeline', None) is None:
if is_2dlist(cfg.data.test.datasets):
cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
else:
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
if is_2dlist(cfg.data.test.pipeline):
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
test_pipeline = Compose(cfg.data.test.pipeline) test_pipeline = Compose(cfg.data.test.pipeline)
data = {'text': input_sentence, 'label': {}} data = {'text': input_sentence, 'label': {}}

View File

@ -396,9 +396,6 @@ class MMOCR:
for model in list(filter(None, [self.recog_model, self.detect_model])): for model in list(filter(None, [self.recog_model, self.detect_model])):
if hasattr(model, 'module'): if hasattr(model, 'module'):
model = model.module model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
def readtext(self, def readtext(self,
img, img,

View File

@ -15,10 +15,6 @@ def build_model(config_file):
model = init_detector(config_file, checkpoint=None, device=device) model = init_detector(config_file, checkpoint=None, device=device)
model = revert_sync_batchnorm(model) model = revert_sync_batchnorm(model)
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
return model return model

View File

@ -16,6 +16,7 @@ from mmdet.datasets.pipelines import Compose
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer, from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
TensorRTDetector, TensorRTRecognizer) TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401 from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
from mmocr.utils import is_2dlist
def get_GiB(x: int): def get_GiB(x: int):
@ -258,9 +259,15 @@ if __name__ == '__main__':
} }
cfg = mmcv.Config.fromfile(args.model_config) cfg = mmcv.Config.fromfile(args.model_config)
if cfg.data.test['type'] == 'ConcatDataset': if cfg.data.test.get('pipeline', None) is None:
cfg.data.test.pipeline = \ if is_2dlist(cfg.data.test.datasets):
cfg.data.test['datasets'][0].pipeline cfg.data.test.pipeline = \
cfg.data.test.datasets[0][0].pipeline
else:
cfg.data.test.pipeline = \
cfg.data.test['datasets'][0].pipeline
if is_2dlist(cfg.data.test.pipeline):
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
onnx2tensorrt( onnx2tensorrt(
args.onnx_file, args.onnx_file,
args.model_type, args.model_type,

View File

@ -14,6 +14,7 @@ from torch import nn
from mmocr.apis import init_detector from mmocr.apis import init_detector
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401 from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
from mmocr.utils import is_2dlist
def _convert_batchnorm(module): def _convert_batchnorm(module):
@ -327,9 +328,15 @@ def main():
model = init_detector(args.model_config, args.model_ckpt, device=device) model = init_detector(args.model_config, args.model_ckpt, device=device)
if hasattr(model, 'module'): if hasattr(model, 'module'):
model = model.module model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset': if model.cfg.data.test.get('pipeline', None) is None:
model.cfg.data.test.pipeline = \ if is_2dlist(model.cfg.data.test.datasets):
model.cfg.data.test['datasets'][0].pipeline model.cfg.data.test.pipeline = \
model.cfg.data.test.datasets[0][0].pipeline
else:
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
if is_2dlist(model.cfg.data.test.pipeline):
model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]
pytorch2onnx( pytorch2onnx(
model, model,

View File

@ -76,9 +76,6 @@ def main():
model = init_detector(args.config, args.checkpoint, device=args.device) model = init_detector(args.config, args.checkpoint, device=args.device)
if hasattr(model, 'module'): if hasattr(model, 'module'):
model = model.module model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
# Start Inference # Start Inference
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir') out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')

View File

@ -60,9 +60,6 @@ def main():
model = init_detector(args.config, args.checkpoint, device=args.device) model = init_detector(args.config, args.checkpoint, device=args.device)
if hasattr(model, 'module'): if hasattr(model, 'module'):
model = model.module model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
# Start Inference # Start Inference
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir') out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')

View File

@ -17,7 +17,7 @@ from mmocr import __version__
from mmocr.apis import init_random_seed, train_detector from mmocr.apis import init_random_seed, train_detector
from mmocr.datasets import build_dataset from mmocr.datasets import build_dataset
from mmocr.models import build_detector from mmocr.models import build_detector
from mmocr.utils import collect_env, get_root_logger from mmocr.utils import collect_env, get_root_logger, is_2dlist
def parse_args(): def parse_args():
@ -72,11 +72,6 @@ def parse_args():
default='none', default='none',
help='Options for job launcher.') help='Options for job launcher.')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--mc-config',
type=str,
default='',
help='Memory cache config for image loading speed-up during training.')
args = parser.parse_args() args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ: if 'LOCAL_RANK' not in os.environ:
@ -100,17 +95,6 @@ def main():
if args.cfg_options is not None: if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options) cfg.merge_from_dict(args.cfg_options)
# update mc config
if args.mc_config:
mc = Config.fromfile(args.mc_config)
if isinstance(cfg.data.train, list):
for i in range(len(cfg.data.train)):
cfg.data.train[i].pipeline[0].update(
file_client_args=mc['mc_file_client_args'])
else:
cfg.data.train.pipeline[0].update(
file_client_args=mc['mc_file_client_args'])
# set cudnn_benchmark # set cudnn_benchmark
if cfg.get('cudnn_benchmark', False): if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -184,12 +168,17 @@ def main():
datasets = [build_dataset(cfg.data.train)] datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2: if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val) val_dataset = copy.deepcopy(cfg.data.val)
if cfg.data.train['type'] == 'ConcatDataset': if cfg.data.train.get('pipeline', None) is None:
train_pipeline = cfg.data.train['datasets'][0].pipeline if is_2dlist(cfg.data.train.datasets):
train_pipeline = cfg.data.train.datasets[0][0].pipeline
else:
train_pipeline = cfg.data.train.datasets[0].pipeline
elif is_2dlist(cfg.data.train.pipeline):
train_pipeline = cfg.data.train.pipeline[0]
else: else:
train_pipeline = cfg.data.train.pipeline train_pipeline = cfg.data.train.pipeline
if val_dataset['type'] == 'ConcatDataset': if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']:
for dataset in val_dataset['datasets']: for dataset in val_dataset['datasets']:
dataset.pipeline = train_pipeline dataset.pipeline = train_pipeline
else: else: