diff --git a/mmdeploy/codebase/mmedit/deploy/super_resolution.py b/mmdeploy/codebase/mmedit/deploy/super_resolution.py index 4fb5de02a..6ab534587 100644 --- a/mmdeploy/codebase/mmedit/deploy/super_resolution.py +++ b/mmdeploy/codebase/mmedit/deploy/super_resolution.py @@ -309,14 +309,13 @@ class SuperResolution(BaseTask): preprocess = model_cfg.test_pipeline preprocess.insert(1, model_cfg.model.data_preprocessor) + preprocess.insert(2, dict(type='ImageToTensor', keys=['img'])) transforms = preprocess for i, transform in enumerate(transforms): if 'keys' in transform and transform['keys'] == ['lq']: transform['keys'] = ['img'] if 'key' in transform and transform['key'] == 'lq': transform['key'] = 'img' - if transform['type'] == 'ToTensor': - transform['type'] = 'ImageToTensor' if transform['type'] == 'EditDataPreprocessor': transform['type'] = 'Normalize' transform['to_rgb'] = transform.get('to_rgb', False)