[Fix] Fix mmcls for cifar10 config on dev-1.x (#1056)

* fix mmcls cifar10 config

* fix dump-info for cifar10 config

* use regex instead of packlist

* remove constants unused

* fix while logic

* fix lint

* fix docformatter

* fix classification to avoid conflict
This commit is contained in:
hanrui1sensetime 2022-09-22 11:40:45 +08:00 committed by GitHub
parent 1d38426891
commit 01df218da5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -161,7 +161,6 @@ class Classification(BaseTask):
`np.ndarray`, `torch.Tensor`.
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Default: None.
Returns:
tuple: (data, img), meta information for the input image and input.
"""
@ -170,6 +169,11 @@ class Classification(BaseTask):
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
from mmengine.dataset import Compose
pipeline = deepcopy(model_cfg.test_pipeline)
move_pipeline = []
while pipeline[-1]['type'] != 'PackClsInputs':
sub_pipeline = pipeline.pop(-1)
move_pipeline = [sub_pipeline] + move_pipeline
pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:]
pipeline = Compose(pipeline)
if isinstance(imgs, str):
@ -248,6 +252,13 @@ class Classification(BaseTask):
transforms = [
item for item in pipeline if 'Random' not in item['type']
]
move_pipeline = []
import re
while re.search('Pack[a-z | A-Z]+Inputs',
transforms[-1]['type']) is None:
sub_pipeline = transforms.pop(-1)
move_pipeline = [sub_pipeline] + move_pipeline
transforms = transforms[:-1] + move_pipeline + transforms[-1:]
for i, transform in enumerate(transforms):
if transform['type'] == 'PackClsInputs':
meta_keys += transform[