diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index c11ee40c3..7c7af13ba 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -161,7 +161,6 @@ class Classification(BaseTask): accepted data type are `str`, `np.ndarray`, Sequence. input_shape (list[int]): A list of two integer in (width, height) format specifying input shape. Default: None. - Returns: tuple: (data, img), meta information for the input image and input. """ @@ -171,6 +170,11 @@ class Classification(BaseTask): f'test_pipeline not found in {self.model_cfg}.' model_cfg = process_model_config(self.model_cfg, imgs, input_shape) pipeline = deepcopy(model_cfg.test_pipeline) + move_pipeline = [] + while pipeline[-1]['type'] != 'PackClsInputs': + sub_pipeline = pipeline.pop(-1) + move_pipeline = [sub_pipeline] + move_pipeline + pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:] pipeline = Compose(pipeline) if isinstance(imgs, str): @@ -248,6 +252,13 @@ class Classification(BaseTask): transforms = [ item for item in pipeline if 'Random' not in item['type'] ] + move_pipeline = [] + import re + while re.search('Pack[a-z | A-Z]+Inputs', + transforms[-1]['type']) is None: + sub_pipeline = transforms.pop(-1) + move_pipeline = [sub_pipeline] + move_pipeline + transforms = transforms[:-1] + move_pipeline + transforms[-1:] for i, transform in enumerate(transforms): if transform['type'] == 'PackClsInputs': meta_keys += transform[