mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Support both ConcatDataset and UniformConcatDataset (#675)
* support UniformConcatDataset * update * rm useless * handle 2d-list datasetspull/690/head
parent
abbae7bfd1
commit
f8dfbd4177
|
@ -18,9 +18,6 @@ def main():
|
|||
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
||||
0].pipeline
|
||||
|
||||
# test a single text
|
||||
input_sentence = input('Please enter a sentence you want to test: ')
|
||||
|
|
|
@ -29,9 +29,6 @@ def main():
|
|||
device = torch.device(args.device)
|
||||
|
||||
model = init_detector(args.config, args.checkpoint, device=device)
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
||||
0].pipeline
|
||||
|
||||
camera = cv2.VideoCapture(args.camera_id)
|
||||
|
||||
|
|
|
@ -97,7 +97,10 @@ def model_inference(model,
|
|||
device = next(model.parameters()).device # model device
|
||||
|
||||
if cfg.data.test.get('pipeline', None) is None:
|
||||
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
|
||||
if is_2dlist(cfg.data.test.datasets):
|
||||
cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
|
||||
else:
|
||||
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
|
||||
if is_2dlist(cfg.data.test.pipeline):
|
||||
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
|
||||
|
||||
|
@ -205,6 +208,13 @@ def text_model_inference(model, input_sentence):
|
|||
assert isinstance(input_sentence, str)
|
||||
|
||||
cfg = model.cfg
|
||||
if cfg.data.test.get('pipeline', None) is None:
|
||||
if is_2dlist(cfg.data.test.datasets):
|
||||
cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
|
||||
else:
|
||||
cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
|
||||
if is_2dlist(cfg.data.test.pipeline):
|
||||
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
|
||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||
data = {'text': input_sentence, 'label': {}}
|
||||
|
||||
|
|
|
@ -396,9 +396,6 @@ class MMOCR:
|
|||
for model in list(filter(None, [self.recog_model, self.detect_model])):
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = \
|
||||
model.cfg.data.test['datasets'][0].pipeline
|
||||
|
||||
def readtext(self,
|
||||
img,
|
||||
|
|
|
@ -15,10 +15,6 @@ def build_model(config_file):
|
|||
model = init_detector(config_file, checkpoint=None, device=device)
|
||||
model = revert_sync_batchnorm(model)
|
||||
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
||||
0].pipeline
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from mmdet.datasets.pipelines import Compose
|
|||
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
|
||||
TensorRTDetector, TensorRTRecognizer)
|
||||
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
|
||||
from mmocr.utils import is_2dlist
|
||||
|
||||
|
||||
def get_GiB(x: int):
|
||||
|
@ -258,9 +259,15 @@ if __name__ == '__main__':
|
|||
}
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.model_config)
|
||||
if cfg.data.test['type'] == 'ConcatDataset':
|
||||
cfg.data.test.pipeline = \
|
||||
cfg.data.test['datasets'][0].pipeline
|
||||
if cfg.data.test.get('pipeline', None) is None:
|
||||
if is_2dlist(cfg.data.test.datasets):
|
||||
cfg.data.test.pipeline = \
|
||||
cfg.data.test.datasets[0][0].pipeline
|
||||
else:
|
||||
cfg.data.test.pipeline = \
|
||||
cfg.data.test['datasets'][0].pipeline
|
||||
if is_2dlist(cfg.data.test.pipeline):
|
||||
cfg.data.test.pipeline = cfg.data.test.pipeline[0]
|
||||
onnx2tensorrt(
|
||||
args.onnx_file,
|
||||
args.model_type,
|
||||
|
|
|
@ -14,6 +14,7 @@ from torch import nn
|
|||
from mmocr.apis import init_detector
|
||||
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
|
||||
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
|
||||
from mmocr.utils import is_2dlist
|
||||
|
||||
|
||||
def _convert_batchnorm(module):
|
||||
|
@ -327,9 +328,15 @@ def main():
|
|||
model = init_detector(args.model_config, args.model_ckpt, device=device)
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = \
|
||||
model.cfg.data.test['datasets'][0].pipeline
|
||||
if model.cfg.data.test.get('pipeline', None) is None:
|
||||
if is_2dlist(model.cfg.data.test.datasets):
|
||||
model.cfg.data.test.pipeline = \
|
||||
model.cfg.data.test.datasets[0][0].pipeline
|
||||
else:
|
||||
model.cfg.data.test.pipeline = \
|
||||
model.cfg.data.test['datasets'][0].pipeline
|
||||
if is_2dlist(model.cfg.data.test.pipeline):
|
||||
model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]
|
||||
|
||||
pytorch2onnx(
|
||||
model,
|
||||
|
|
|
@ -76,9 +76,6 @@ def main():
|
|||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
||||
0].pipeline
|
||||
|
||||
# Start Inference
|
||||
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
|
||||
|
|
|
@ -60,9 +60,6 @@ def main():
|
|||
model = init_detector(args.config, args.checkpoint, device=args.device)
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
if model.cfg.data.test['type'] == 'ConcatDataset':
|
||||
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
||||
0].pipeline
|
||||
|
||||
# Start Inference
|
||||
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
|
||||
|
|
|
@ -17,7 +17,7 @@ from mmocr import __version__
|
|||
from mmocr.apis import init_random_seed, train_detector
|
||||
from mmocr.datasets import build_dataset
|
||||
from mmocr.models import build_detector
|
||||
from mmocr.utils import collect_env, get_root_logger
|
||||
from mmocr.utils import collect_env, get_root_logger, is_2dlist
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -72,11 +72,6 @@ def parse_args():
|
|||
default='none',
|
||||
help='Options for job launcher.')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
parser.add_argument(
|
||||
'--mc-config',
|
||||
type=str,
|
||||
default='',
|
||||
help='Memory cache config for image loading speed-up during training.')
|
||||
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
|
@ -100,17 +95,6 @@ def main():
|
|||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# update mc config
|
||||
if args.mc_config:
|
||||
mc = Config.fromfile(args.mc_config)
|
||||
if isinstance(cfg.data.train, list):
|
||||
for i in range(len(cfg.data.train)):
|
||||
cfg.data.train[i].pipeline[0].update(
|
||||
file_client_args=mc['mc_file_client_args'])
|
||||
else:
|
||||
cfg.data.train.pipeline[0].update(
|
||||
file_client_args=mc['mc_file_client_args'])
|
||||
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
@ -184,12 +168,17 @@ def main():
|
|||
datasets = [build_dataset(cfg.data.train)]
|
||||
if len(cfg.workflow) == 2:
|
||||
val_dataset = copy.deepcopy(cfg.data.val)
|
||||
if cfg.data.train['type'] == 'ConcatDataset':
|
||||
train_pipeline = cfg.data.train['datasets'][0].pipeline
|
||||
if cfg.data.train.get('pipeline', None) is None:
|
||||
if is_2dlist(cfg.data.train.datasets):
|
||||
train_pipeline = cfg.data.train.datasets[0][0].pipeline
|
||||
else:
|
||||
train_pipeline = cfg.data.train.datasets[0].pipeline
|
||||
elif is_2dlist(cfg.data.train.pipeline):
|
||||
train_pipeline = cfg.data.train.pipeline[0]
|
||||
else:
|
||||
train_pipeline = cfg.data.train.pipeline
|
||||
|
||||
if val_dataset['type'] == 'ConcatDataset':
|
||||
if val_dataset['type'] in ['ConcatDataset', 'UniformConcatDataset']:
|
||||
for dataset in val_dataset['datasets']:
|
||||
dataset.pipeline = train_pipeline
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue