diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 9d6cb4430..6e3ef7038 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -161,7 +161,6 @@ class Classification(BaseTask): `np.ndarray`, `torch.Tensor`. input_shape (list[int]): A list of two integer in (width, height) format specifying input shape. Default: None. - Returns: tuple: (data, img), meta information for the input image and input. """ @@ -170,6 +169,11 @@ class Classification(BaseTask): model_cfg = process_model_config(self.model_cfg, imgs, input_shape) from mmengine.dataset import Compose pipeline = deepcopy(model_cfg.test_pipeline) + move_pipeline = [] + while pipeline[-1]['type'] != 'PackClsInputs': + sub_pipeline = pipeline.pop(-1) + move_pipeline = [sub_pipeline] + move_pipeline + pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:] pipeline = Compose(pipeline) if isinstance(imgs, str): @@ -248,6 +252,13 @@ class Classification(BaseTask): transforms = [ item for item in pipeline if 'Random' not in item['type'] ] + move_pipeline = [] + import re + while re.search('Pack[a-z | A-Z]+Inputs', + transforms[-1]['type']) is None: + sub_pipeline = transforms.pop(-1) + move_pipeline = [sub_pipeline] + move_pipeline + transforms = transforms[:-1] + move_pipeline + transforms[-1:] for i, transform in enumerate(transforms): if transform['type'] == 'PackClsInputs': meta_keys += transform[