mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Fix mmcls for cifar10 config on dev-1.x (#1056)
* fix mmcls cifar10 config * fix dump-info for cifar10 config * use regex instead of packlist * remove constants unused * fix while logic * fix lint * fix docformatter * fix classification to avoid conflict
This commit is contained in:
parent
1d38426891
commit
01df218da5
@ -161,7 +161,6 @@ class Classification(BaseTask):
|
|||||||
`np.ndarray`, `torch.Tensor`.
|
`np.ndarray`, `torch.Tensor`.
|
||||||
input_shape (list[int]): A list of two integer in (width, height)
|
input_shape (list[int]): A list of two integer in (width, height)
|
||||||
format specifying input shape. Default: None.
|
format specifying input shape. Default: None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (data, img), meta information for the input image and input.
|
tuple: (data, img), meta information for the input image and input.
|
||||||
"""
|
"""
|
||||||
@ -170,6 +169,11 @@ class Classification(BaseTask):
|
|||||||
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
|
model_cfg = process_model_config(self.model_cfg, imgs, input_shape)
|
||||||
from mmengine.dataset import Compose
|
from mmengine.dataset import Compose
|
||||||
pipeline = deepcopy(model_cfg.test_pipeline)
|
pipeline = deepcopy(model_cfg.test_pipeline)
|
||||||
|
move_pipeline = []
|
||||||
|
while pipeline[-1]['type'] != 'PackClsInputs':
|
||||||
|
sub_pipeline = pipeline.pop(-1)
|
||||||
|
move_pipeline = [sub_pipeline] + move_pipeline
|
||||||
|
pipeline = pipeline[:-1] + move_pipeline + pipeline[-1:]
|
||||||
pipeline = Compose(pipeline)
|
pipeline = Compose(pipeline)
|
||||||
|
|
||||||
if isinstance(imgs, str):
|
if isinstance(imgs, str):
|
||||||
@ -248,6 +252,13 @@ class Classification(BaseTask):
|
|||||||
transforms = [
|
transforms = [
|
||||||
item for item in pipeline if 'Random' not in item['type']
|
item for item in pipeline if 'Random' not in item['type']
|
||||||
]
|
]
|
||||||
|
move_pipeline = []
|
||||||
|
import re
|
||||||
|
while re.search('Pack[a-z | A-Z]+Inputs',
|
||||||
|
transforms[-1]['type']) is None:
|
||||||
|
sub_pipeline = transforms.pop(-1)
|
||||||
|
move_pipeline = [sub_pipeline] + move_pipeline
|
||||||
|
transforms = transforms[:-1] + move_pipeline + transforms[-1:]
|
||||||
for i, transform in enumerate(transforms):
|
for i, transform in enumerate(transforms):
|
||||||
if transform['type'] == 'PackClsInputs':
|
if transform['type'] == 'PackClsInputs':
|
||||||
meta_keys += transform[
|
meta_keys += transform[
|
||||||
|
Loading…
x
Reference in New Issue
Block a user