[Fix] Fix mmcls for cifar10 config on dev-1.x (#1056)
* fix mmcls cifar10 config * fix dump-info for cifar10 config * use regex instead of packlist * remove constants unused * fix while logic * fix lint * fix docformatter * fix classification to avoid conflictpull/1091/head
parent
144fcf7b3a
commit
89a9e92769
|
@ -161,7 +161,6 @@ class Classification(BaseTask):
|
|||
accepted data type are `str`, `np.ndarray`, Sequence.
|
||||
input_shape (list[int]): A list of two integer in (width, height)
|
||||
format specifying input shape. Default: None.
|
||||
|
||||
Returns:
|
||||
tuple: (data, img), meta information for the input image and input.
|
||||
"""
|
||||
|
@ -171,6 +170,11 @@ class Classification(BaseTask):
|
|||
f'test_pipeline not found in {self.model_cfg}.'
|
||||
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
|
||||
pipeline = deepcopy(model_cfg.test_pipeline)
|
||||
move_pipeline = []
|
||||
while pipeline[-1]['type'] != 'PackClsInputs':
|
||||
sub_pipeline = pipeline.pop(-1)
|
||||
move_pipeline = [sub_pipeline] + move_pipeline
|
||||
pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:]
|
||||
pipeline = Compose(pipeline)
|
||||
|
||||
if isinstance(imgs, str):
|
||||
|
@ -248,6 +252,13 @@ class Classification(BaseTask):
|
|||
transforms = [
|
||||
item for item in pipeline if 'Random' not in item['type']
|
||||
]
|
||||
move_pipeline = []
|
||||
import re
|
||||
while re.search('Pack[a-z | A-Z]+Inputs',
|
||||
transforms[-1]['type']) is None:
|
||||
sub_pipeline = transforms.pop(-1)
|
||||
move_pipeline = [sub_pipeline] + move_pipeline
|
||||
transforms = transforms[:-1] + move_pipeline + transforms[-1:]
|
||||
for i, transform in enumerate(transforms):
|
||||
if transform['type'] == 'PackClsInputs':
|
||||
meta_keys += transform[
|
||||
|
|
Loading…
Reference in New Issue