mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Fix mmcls for cifar10 config on dev-1.x (#1056)
* fix mmcls cifar10 config * fix dump-info for cifar10 config * use regex instead of packlist * remove constants unused * fix while logic * fix lint * fix docformatter * fix classification to avoid conflict
This commit is contained in:
parent
1d38426891
commit
01df218da5
@ -161,7 +161,6 @@ class Classification(BaseTask):
|
||||
`np.ndarray`, `torch.Tensor`.
|
||||
input_shape (list[int]): A list of two integer in (width, height)
|
||||
format specifying input shape. Default: None.
|
||||
|
||||
Returns:
|
||||
tuple: (data, img), meta information for the input image and input.
|
||||
"""
|
||||
@ -170,6 +169,11 @@ class Classification(BaseTask):
|
||||
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
|
||||
from mmengine.dataset import Compose
|
||||
pipeline = deepcopy(model_cfg.test_pipeline)
|
||||
move_pipeline = []
|
||||
while pipeline[-1]['type'] != 'PackClsInputs':
|
||||
sub_pipeline = pipeline.pop(-1)
|
||||
move_pipeline = [sub_pipeline] + move_pipeline
|
||||
pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:]
|
||||
pipeline = Compose(pipeline)
|
||||
|
||||
if isinstance(imgs, str):
|
||||
@ -248,6 +252,13 @@ class Classification(BaseTask):
|
||||
transforms = [
|
||||
item for item in pipeline if 'Random' not in item['type']
|
||||
]
|
||||
move_pipeline = []
|
||||
import re
|
||||
while re.search('Pack[a-z | A-Z]+Inputs',
|
||||
transforms[-1]['type']) is None:
|
||||
sub_pipeline = transforms.pop(-1)
|
||||
move_pipeline = [sub_pipeline] + move_pipeline
|
||||
transforms = transforms[:-1] + move_pipeline + transforms[-1:]
|
||||
for i, transform in enumerate(transforms):
|
||||
if transform['type'] == 'PackClsInputs':
|
||||
meta_keys += transform[
|
||||
|
Loading…
x
Reference in New Issue
Block a user