[Fix][New_config] Fix demo bug (#1647)
* Fix demo * Update implement --------- Co-authored-by: mzr1996 <mzr1996@163.com>pull/1653/head
parent
6d7fe91a98
commit
d6056af2b8
|
@ -83,9 +83,11 @@ class FeatureExtractor(BaseInferencer):
|
|||
|
||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
||||
from mmpretrain.datasets import remove_transform
|
||||
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
test_pipeline = Compose(
|
||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||
return test_pipeline
|
||||
|
|
|
@ -70,9 +70,11 @@ class ImageCaptionInferencer(BaseInferencer):
|
|||
|
||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
||||
from mmpretrain.datasets import remove_transform
|
||||
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
test_pipeline = Compose(
|
||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||
return test_pipeline
|
||||
|
|
|
@ -110,9 +110,11 @@ class ImageClassificationInferencer(BaseInferencer):
|
|||
|
||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
||||
from mmpretrain.datasets import remove_transform
|
||||
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
test_pipeline = Compose(
|
||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||
return test_pipeline
|
||||
|
|
|
@ -172,9 +172,11 @@ class ImageRetrievalInferencer(BaseInferencer):
|
|||
|
||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
||||
from mmpretrain.datasets import remove_transform
|
||||
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
test_pipeline = Compose(
|
||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||
return test_pipeline
|
||||
|
|
|
@ -86,9 +86,11 @@ class VisualGroundingInferencer(BaseInferencer):
|
|||
|
||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
||||
from mmpretrain.datasets import remove_transform
|
||||
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
test_pipeline = Compose(
|
||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||
return test_pipeline
|
||||
|
|
|
@ -88,9 +88,11 @@ class VisualQuestionAnsweringInferencer(BaseInferencer):
|
|||
|
||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
||||
from mmpretrain.datasets import remove_transform
|
||||
|
||||
# Image loading is finished in `self.preprocess`.
|
||||
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
test_pipeline = Compose(
|
||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||
return test_pipeline
|
||||
|
|
|
@ -16,6 +16,7 @@ from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
|
|||
RandomErasing, RandomResizedCrop,
|
||||
RandomResizedCropAndInterpolationWithTwoPic,
|
||||
RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
|
||||
from .utils import get_transform_idx, remove_transform
|
||||
from .wrappers import ApplyToList, MultiView
|
||||
|
||||
for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
|
||||
|
@ -34,5 +35,6 @@ __all__ = [
|
|||
'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
|
||||
'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
|
||||
'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
|
||||
'RandomResizedCropAndInterpolationWithTwoPic'
|
||||
'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
|
||||
'remove_transform'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import List, Union
|
||||
|
||||
from mmcv.transforms import BaseTransform
|
||||
|
||||
PIPELINE_TYPE = List[Union[dict, BaseTransform]]
|
||||
|
||||
|
||||
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
|
||||
"""Returns the index of the transform in a pipeline.
|
||||
|
||||
Args:
|
||||
pipeline (List[dict] | List[BaseTransform]): The transforms list.
|
||||
target (str): The target transform class name.
|
||||
|
||||
Returns:
|
||||
int: The transform index. Returns -1 if not found.
|
||||
"""
|
||||
for i, transform in enumerate(pipeline):
|
||||
if isinstance(transform, dict):
|
||||
if isinstance(transform['type'], type):
|
||||
if transform['type'].__name__ == target:
|
||||
return i
|
||||
else:
|
||||
if transform['type'] == target:
|
||||
return i
|
||||
else:
|
||||
if transform.__class__.__name__ == target:
|
||||
return i
|
||||
|
||||
return -1
|
||||
|
||||
|
||||
def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
|
||||
"""Remove the target transform type from the pipeline.
|
||||
|
||||
Args:
|
||||
pipeline (List[dict] | List[BaseTransform]): The transforms list.
|
||||
target (str): The target transform class name.
|
||||
inplace (bool): Whether to modify the pipeline inplace.
|
||||
|
||||
Returns:
|
||||
The modified transform.
|
||||
"""
|
||||
idx = get_transform_idx(pipeline, target)
|
||||
if not inplace:
|
||||
pipeline = copy.deepcopy(pipeline)
|
||||
while idx >= 0:
|
||||
pipeline.pop(idx)
|
||||
idx = get_transform_idx(pipeline, target)
|
||||
|
||||
return pipeline
|
|
@ -52,9 +52,12 @@ class DatasetValidator():
|
|||
def __init__(self, dataset_cfg, log_file_path):
|
||||
super(DatasetValidator, self).__init__()
|
||||
# keep only LoadImageFromFile pipeline
|
||||
assert dataset_cfg.pipeline[0]['type'] == 'LoadImageFromFile', (
|
||||
'This tool is only for datasets needs to load image from files.')
|
||||
self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[0])
|
||||
from mmpretrain.datasets import get_transform_idx
|
||||
|
||||
load_idx = get_transform_idx(dataset_cfg.pipeline, 'LoadImageFromFile')
|
||||
assert load_idx >= 0, \
|
||||
'This tool is only for datasets needs to load image from files.'
|
||||
self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[load_idx])
|
||||
dataset_cfg.pipeline = []
|
||||
dataset = build_dataset(dataset_cfg)
|
||||
|
||||
|
|
Loading…
Reference in New Issue