From 01df218da54a1197929facd72875cc0cddaa6c27 Mon Sep 17 00:00:00 2001 From: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Date: Thu, 22 Sep 2022 11:40:45 +0800 Subject: [PATCH] [Fix] Fix mmcls for cifar10 config on dev-1.x (#1056) * fix mmcls cifar10 config * fix dump-info for cifar10 config * use regex instead of packlist * remove constants unused * fix while logic * fix lint * fix docformatter * fix classification to avoid conflict --- mmdeploy/codebase/mmcls/deploy/classification.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 9d6cb4430..6e3ef7038 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -161,7 +161,6 @@ class Classification(BaseTask): `np.ndarray`, `torch.Tensor`. input_shape (list[int]): A list of two integer in (width, height) format specifying input shape. Default: None. - Returns: tuple: (data, img), meta information for the input image and input. """ @@ -170,6 +169,11 @@ class Classification(BaseTask): model_cfg = process_model_config(self.model_cfg, imgs, input_shape) from mmengine.dataset import Compose pipeline = deepcopy(model_cfg.test_pipeline) + move_pipeline = [] + while pipeline[-1]['type'] != 'PackClsInputs': + sub_pipeline = pipeline.pop(-1) + move_pipeline = [sub_pipeline] + move_pipeline + pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:] pipeline = Compose(pipeline) if isinstance(imgs, str): @@ -248,6 +252,13 @@ class Classification(BaseTask): transforms = [ item for item in pipeline if 'Random' not in item['type'] ] + move_pipeline = [] + import re + while re.search('Pack[a-z | A-Z]+Inputs', + transforms[-1]['type']) is None: + sub_pipeline = transforms.pop(-1) + move_pipeline = [sub_pipeline] + move_pipeline + transforms = transforms[:-1] + move_pipeline + transforms[-1:] for i, transform in enumerate(transforms): if transform['type'] == 'PackClsInputs': meta_keys += transform[