mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Support both ConcatDataset and UniformConcatDataset (#675)
* support UniformConcatDataset * update * rm useless * handle 2d-list datasetspull/690/head
parent
abbae7bfd1
commit
f8dfbd4177
|
@ -18,9 +18,6 @@ def main():
|
||||||
|
|
||||||
# build the model from a config file and a checkpoint file
|
# build the model from a config file and a checkpoint file
|
||||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
||||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
|
||||||
0].pipeline
|
|
||||||
|
|
||||||
# test a single text
|
# test a single text
|
||||||
input_sentence = input('Please enter a sentence you want to test: ')
|
input_sentence = input('Please enter a sentence you want to test: ')
|
||||||
|
|
|
@ -29,9 +29,6 @@ def main():
|
||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
|
|
||||||
model = init_detector(args.config, args.checkpoint, device=device)
|
model = init_detector(args.config, args.checkpoint, device=device)
|
||||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
||||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
|
||||||
0].pipeline
|
|
||||||
|
|
||||||
camera = cv2.VideoCapture(args.camera_id)
|
camera = cv2.VideoCapture(args.camera_id)
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,10 @@ def model_inference(model,
|
||||||
device = next(model.parameters()).device # model device
|
device = next(model.parameters()).device # model device
|
||||||
|
|
||||||
if cfg.data.test.get('pipeline', None) is None:
|
if cfg.data.test.get('pipeline', None) is None:
|
||||||
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
|
if is_2dlist(cfg.data.test.datasets):
|
||||||
|
cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
|
||||||
|
else:
|
||||||
|
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
|
||||||
if is_2dlist(cfg.data.test.pipeline):
|
if is_2dlist(cfg.data.test.pipeline):
|
||||||
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
|
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
|
||||||
|
|
||||||
|
@ -205,6 +208,13 @@ def text_model_inference(model, input_sentence):
|
||||||
assert isinstance(input_sentence, str)
|
assert isinstance(input_sentence, str)
|
||||||
|
|
||||||
cfg = model.cfg
|
cfg = model.cfg
|
||||||
|
if cfg.data.test.get('pipeline', None) is None:
|
||||||
|
if is_2dlist(cfg.data.test.datasets):
|
||||||
|
cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
|
||||||
|
else:
|
||||||
|
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
|
||||||
|
if is_2dlist(cfg.data.test.pipeline):
|
||||||
|
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
|
||||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||||
data = {'text': input_sentence, 'label': {}}
|
data = {'text': input_sentence, 'label': {}}
|
||||||
|
|
||||||
|
|
|
@ -396,9 +396,6 @@ class MMOCR:
|
||||||
for model in list(filter(None, [self.recog_model, self.detect_model])):
|
for model in list(filter(None, [self.recog_model, self.detect_model])):
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
||||||
model.cfg.data.test.pipeline = \
|
|
||||||
model.cfg.data.test['datasets'][0].pipeline
|
|
||||||
|
|
||||||
def readtext(self,
|
def readtext(self,
|
||||||
img,
|
img,
|
||||||
|
|
|
@ -15,10 +15,6 @@ def build_model(config_file):
|
||||||
model = init_detector(config_file, checkpoint=None, device=device)
|
model = init_detector(config_file, checkpoint=None, device=device)
|
||||||
model = revert_sync_batchnorm(model)
|
model = revert_sync_batchnorm(model)
|
||||||
|
|
||||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
||||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
|
||||||
0].pipeline
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ from mmdet.datasets.pipelines import Compose
|
||||||
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
|
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
|
||||||
TensorRTDetector, TensorRTRecognizer)
|
TensorRTDetector, TensorRTRecognizer)
|
||||||
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
|
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
|
||||||
|
from mmocr.utils import is_2dlist
|
||||||
|
|
||||||
|
|
||||||
def get_GiB(x: int):
|
def get_GiB(x: int):
|
||||||
|
@ -258,9 +259,15 @@ if __name__ == '__main__':
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg = mmcv.Config.fromfile(args.model_config)
|
cfg = mmcv.Config.fromfile(args.model_config)
|
||||||
if cfg.data.test['type'] == 'ConcatDataset':
|
if cfg.data.test.get('pipeline', None) is None:
|
||||||
cfg.data.test.pipeline = \
|
if is_2dlist(cfg.data.test.datasets):
|
||||||
cfg.data.test['datasets'][0].pipeline
|
cfg.data.test.pipeline = \
|
||||||
|
cfg.data.test.datasets[0][0].pipeline
|
||||||
|
else:
|
||||||
|
cfg.data.test.pipeline = \
|
||||||
|
cfg.data.test['datasets'][0].pipeline
|
||||||
|
if is_2dlist(cfg.data.test.pipeline):
|
||||||
|
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
|
||||||
onnx2tensorrt(
|
onnx2tensorrt(
|
||||||
args.onnx_file,
|
args.onnx_file,
|
||||||
args.model_type,
|
args.model_type,
|
||||||
|
|
|
@ -14,6 +14,7 @@ from torch import nn
|
||||||
from mmocr.apis import init_detector
|
from mmocr.apis import init_detector
|
||||||
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
|
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
|
||||||
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
|
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
|
||||||
|
from mmocr.utils import is_2dlist
|
||||||
|
|
||||||
|
|
||||||
def _convert_batchnorm(module):
|
def _convert_batchnorm(module):
|
||||||
|
@ -327,9 +328,15 @@ def main():
|
||||||
model = init_detector(args.model_config, args.model_ckpt, device=device)
|
model = init_detector(args.model_config, args.model_ckpt, device=device)
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
if model.cfg.data.test.get('pipeline', None) is None:
|
||||||
model.cfg.data.test.pipeline = \
|
if is_2dlist(model.cfg.data.test.datasets):
|
||||||
model.cfg.data.test['datasets'][0].pipeline
|
model.cfg.data.test.pipeline = \
|
||||||
|
model.cfg.data.test.datasets[0][0].pipeline
|
||||||
|
else:
|
||||||
|
model.cfg.data.test.pipeline = \
|
||||||
|
model.cfg.data.test['datasets'][0].pipeline
|
||||||
|
if is_2dlist(model.cfg.data.test.pipeline):
|
||||||
|
model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]
|
||||||
|
|
||||||
pytorch2onnx(
|
pytorch2onnx(
|
||||||
model,
|
model,
|
||||||
|
|
|
@ -76,9 +76,6 @@ def main():
|
||||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
||||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
|
||||||
0].pipeline
|
|
||||||
|
|
||||||
# Start Inference
|
# Start Inference
|
||||||
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
|
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
|
||||||
|
|
|
@ -60,9 +60,6 @@ def main():
|
||||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||||
if hasattr(model, 'module'):
|
if hasattr(model, 'module'):
|
||||||
model = model.module
|
model = model.module
|
||||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
||||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
|
||||||
0].pipeline
|
|
||||||
|
|
||||||
# Start Inference
|
# Start Inference
|
||||||
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
|
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
|
||||||
|
|
|
@ -17,7 +17,7 @@ from mmocr import __version__
|
||||||
from mmocr.apis import init_random_seed, train_detector
|
from mmocr.apis import init_random_seed, train_detector
|
||||||
from mmocr.datasets import build_dataset
|
from mmocr.datasets import build_dataset
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
from mmocr.utils import collect_env, get_root_logger
|
from mmocr.utils import collect_env, get_root_logger, is_2dlist
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -72,11 +72,6 @@ def parse_args():
|
||||||
default='none',
|
default='none',
|
||||||
help='Options for job launcher.')
|
help='Options for job launcher.')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
parser.add_argument(
|
|
||||||
'--mc-config',
|
|
||||||
type=str,
|
|
||||||
default='',
|
|
||||||
help='Memory cache config for image loading speed-up during training.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if 'LOCAL_RANK' not in os.environ:
|
if 'LOCAL_RANK' not in os.environ:
|
||||||
|
@ -100,17 +95,6 @@ def main():
|
||||||
if args.cfg_options is not None:
|
if args.cfg_options is not None:
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
|
||||||
# update mc config
|
|
||||||
if args.mc_config:
|
|
||||||
mc = Config.fromfile(args.mc_config)
|
|
||||||
if isinstance(cfg.data.train, list):
|
|
||||||
for i in range(len(cfg.data.train)):
|
|
||||||
cfg.data.train[i].pipeline[0].update(
|
|
||||||
file_client_args=mc['mc_file_client_args'])
|
|
||||||
else:
|
|
||||||
cfg.data.train.pipeline[0].update(
|
|
||||||
file_client_args=mc['mc_file_client_args'])
|
|
||||||
|
|
||||||
# set cudnn_benchmark
|
# set cudnn_benchmark
|
||||||
if cfg.get('cudnn_benchmark', False):
|
if cfg.get('cudnn_benchmark', False):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
@ -184,12 +168,17 @@ def main():
|
||||||
datasets = [build_dataset(cfg.data.train)]
|
datasets = [build_dataset(cfg.data.train)]
|
||||||
if len(cfg.workflow) == 2:
|
if len(cfg.workflow) == 2:
|
||||||
val_dataset = copy.deepcopy(cfg.data.val)
|
val_dataset = copy.deepcopy(cfg.data.val)
|
||||||
if cfg.data.train['type'] == 'ConcatDataset':
|
if cfg.data.train.get('pipeline', None) is None:
|
||||||
train_pipeline = cfg.data.train['datasets'][0].pipeline
|
if is_2dlist(cfg.data.train.datasets):
|
||||||
|
train_pipeline = cfg.data.train.datasets[0][0].pipeline
|
||||||
|
else:
|
||||||
|
train_pipeline = cfg.data.train.datasets[0].pipeline
|
||||||
|
elif is_2dlist(cfg.data.train.pipeline):
|
||||||
|
train_pipeline = cfg.data.train.pipeline[0]
|
||||||
else:
|
else:
|
||||||
train_pipeline = cfg.data.train.pipeline
|
train_pipeline = cfg.data.train.pipeline
|
||||||
|
|
||||||
if val_dataset['type'] == 'ConcatDataset':
|
if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']:
|
||||||
for dataset in val_dataset['datasets']:
|
for dataset in val_dataset['datasets']:
|
||||||
dataset.pipeline = train_pipeline
|
dataset.pipeline = train_pipeline
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue