[Fix] Fix mmcls for cifar10 config on dev-1.x (#1056)

* fix mmcls cifar10 config

* fix dump-info for cifar10 config

* use regex instead of packlist

* remove constants unused

* fix while logic

* fix lint

* fix docformatter

* fix classification to avoid conflict
pull/1091/head
hanrui1sensetime 2022-09-22 11:40:45 +08:00 committed by RunningLeon
parent 144fcf7b3a
commit 89a9e92769
1 changed files with 12 additions and 1 deletions

View File

@ -161,7 +161,6 @@ class Classification(BaseTask):
accepted data type are `str`, `np.ndarray`, Sequence.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Default: None.
Returns:
tuple: (data, img), meta information for the input image and input.
"""
@ -171,6 +170,11 @@ class Classification(BaseTask):
f'test_pipeline not found in {self.model_cfg}.'
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
pipeline = deepcopy(model_cfg.test_pipeline)
move_pipeline = []
while pipeline[-1]['type'] != 'PackClsInputs':
sub_pipeline = pipeline.pop(-1)
move_pipeline = [sub_pipeline] + move_pipeline
pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:]
pipeline = Compose(pipeline)
if isinstance(imgs, str):
@ -248,6 +252,13 @@ class Classification(BaseTask):
transforms = [
item for item in pipeline if 'Random' not in item['type']
]
move_pipeline = []
import re
while re.search('Pack[a-z | A-Z]+Inputs',
transforms[-1]['type']) is None:
sub_pipeline = transforms.pop(-1)
move_pipeline = [sub_pipeline] + move_pipeline
transforms = transforms[:-1] + move_pipeline + transforms[-1:]
for i, transform in enumerate(transforms):
if transform['type'] == 'PackClsInputs':
meta_keys += transform[