mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
54 lines
1.5 KiB
Python
54 lines
1.5 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import copy
|
||
|
from typing import List, Union
|
||
|
|
||
|
from mmcv.transforms import BaseTransform
|
||
|
|
||
|
PIPELINE_TYPE = List[Union[dict, BaseTransform]]
|
||
|
|
||
|
|
||
|
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
|
||
|
"""Returns the index of the transform in a pipeline.
|
||
|
|
||
|
Args:
|
||
|
pipeline (List[dict] | List[BaseTransform]): The transforms list.
|
||
|
target (str): The target transform class name.
|
||
|
|
||
|
Returns:
|
||
|
int: The transform index. Returns -1 if not found.
|
||
|
"""
|
||
|
for i, transform in enumerate(pipeline):
|
||
|
if isinstance(transform, dict):
|
||
|
if isinstance(transform['type'], type):
|
||
|
if transform['type'].__name__ == target:
|
||
|
return i
|
||
|
else:
|
||
|
if transform['type'] == target:
|
||
|
return i
|
||
|
else:
|
||
|
if transform.__class__.__name__ == target:
|
||
|
return i
|
||
|
|
||
|
return -1
|
||
|
|
||
|
|
||
|
def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
|
||
|
"""Remove the target transform type from the pipeline.
|
||
|
|
||
|
Args:
|
||
|
pipeline (List[dict] | List[BaseTransform]): The transforms list.
|
||
|
target (str): The target transform class name.
|
||
|
inplace (bool): Whether to modify the pipeline inplace.
|
||
|
|
||
|
Returns:
|
||
|
The modified transform.
|
||
|
"""
|
||
|
idx = get_transform_idx(pipeline, target)
|
||
|
if not inplace:
|
||
|
pipeline = copy.deepcopy(pipeline)
|
||
|
while idx >= 0:
|
||
|
pipeline.pop(idx)
|
||
|
idx = get_transform_idx(pipeline, target)
|
||
|
|
||
|
return pipeline
|