[Fix][New_config] Fix demo bug (#1647)

* Fix demo

* Update implementation

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
pull/1653/head
Mashiro 2023-06-19 15:15:28 +08:00 committed by GitHub
parent 6d7fe91a98
commit d6056af2b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 92 additions and 22 deletions

View File

@ -83,9 +83,11 @@ class FeatureExtractor(BaseInferencer):
def _init_pipeline(self, cfg: Config) -> Callable:
    """Build the inference-time transform pipeline from the test config.

    Every ``LoadImageFromFile`` entry is stripped from the configured
    pipeline because image loading is finished in ``self.preprocess``.
    The lookup is delegated to ``remove_transform`` rather than indexing
    ``test_pipeline_cfg[0]`` directly, so an empty pipeline does not raise
    ``IndexError`` and the loader does not have to be the first entry.

    Args:
        cfg (Config): Config providing
            ``test_dataloader.dataset.pipeline``.

    Returns:
        Callable: A ``Compose`` of the remaining transforms, each built
        through ``TRANSFORMS.build``.
    """
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    # Kept as a function-scope import, exactly as the original code does.
    from mmpretrain.datasets import remove_transform

    # Image loading is finished in `self.preprocess`.
    test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                         'LoadImageFromFile')
    test_pipeline = Compose(
        [TRANSFORMS.build(t) for t in test_pipeline_cfg])
    return test_pipeline

View File

@ -70,9 +70,11 @@ class ImageCaptionInferencer(BaseInferencer):
def _init_pipeline(self, cfg: Config) -> Callable:
    """Build the inference-time transform pipeline from the test config.

    Every ``LoadImageFromFile`` entry is stripped from the configured
    pipeline because image loading is finished in ``self.preprocess``.
    The lookup is delegated to ``remove_transform`` rather than indexing
    ``test_pipeline_cfg[0]`` directly, so an empty pipeline does not raise
    ``IndexError`` and the loader does not have to be the first entry.

    Args:
        cfg (Config): Config providing
            ``test_dataloader.dataset.pipeline``.

    Returns:
        Callable: A ``Compose`` of the remaining transforms, each built
        through ``TRANSFORMS.build``.
    """
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    # Kept as a function-scope import, exactly as the original code does.
    from mmpretrain.datasets import remove_transform

    # Image loading is finished in `self.preprocess`.
    test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                         'LoadImageFromFile')
    test_pipeline = Compose(
        [TRANSFORMS.build(t) for t in test_pipeline_cfg])
    return test_pipeline

View File

@ -110,9 +110,11 @@ class ImageClassificationInferencer(BaseInferencer):
def _init_pipeline(self, cfg: Config) -> Callable:
    """Build the inference-time transform pipeline from the test config.

    Every ``LoadImageFromFile`` entry is stripped from the configured
    pipeline because image loading is finished in ``self.preprocess``.
    The lookup is delegated to ``remove_transform`` rather than indexing
    ``test_pipeline_cfg[0]`` directly, so an empty pipeline does not raise
    ``IndexError`` and the loader does not have to be the first entry.

    Args:
        cfg (Config): Config providing
            ``test_dataloader.dataset.pipeline``.

    Returns:
        Callable: A ``Compose`` of the remaining transforms, each built
        through ``TRANSFORMS.build``.
    """
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    # Kept as a function-scope import, exactly as the original code does.
    from mmpretrain.datasets import remove_transform

    # Image loading is finished in `self.preprocess`.
    test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                         'LoadImageFromFile')
    test_pipeline = Compose(
        [TRANSFORMS.build(t) for t in test_pipeline_cfg])
    return test_pipeline

View File

@ -172,9 +172,11 @@ class ImageRetrievalInferencer(BaseInferencer):
def _init_pipeline(self, cfg: Config) -> Callable:
    """Build the inference-time transform pipeline from the test config.

    Every ``LoadImageFromFile`` entry is stripped from the configured
    pipeline because image loading is finished in ``self.preprocess``.
    The lookup is delegated to ``remove_transform`` rather than indexing
    ``test_pipeline_cfg[0]`` directly, so an empty pipeline does not raise
    ``IndexError`` and the loader does not have to be the first entry.

    Args:
        cfg (Config): Config providing
            ``test_dataloader.dataset.pipeline``.

    Returns:
        Callable: A ``Compose`` of the remaining transforms, each built
        through ``TRANSFORMS.build``.
    """
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    # Kept as a function-scope import, exactly as the original code does.
    from mmpretrain.datasets import remove_transform

    # Image loading is finished in `self.preprocess`.
    test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                         'LoadImageFromFile')
    test_pipeline = Compose(
        [TRANSFORMS.build(t) for t in test_pipeline_cfg])
    return test_pipeline

View File

@ -86,9 +86,11 @@ class VisualGroundingInferencer(BaseInferencer):
def _init_pipeline(self, cfg: Config) -> Callable:
    """Build the inference-time transform pipeline from the test config.

    Every ``LoadImageFromFile`` entry is stripped from the configured
    pipeline because image loading is finished in ``self.preprocess``.
    The lookup is delegated to ``remove_transform`` rather than indexing
    ``test_pipeline_cfg[0]`` directly, so an empty pipeline does not raise
    ``IndexError`` and the loader does not have to be the first entry.

    Args:
        cfg (Config): Config providing
            ``test_dataloader.dataset.pipeline``.

    Returns:
        Callable: A ``Compose`` of the remaining transforms, each built
        through ``TRANSFORMS.build``.
    """
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    # Kept as a function-scope import, exactly as the original code does.
    from mmpretrain.datasets import remove_transform

    # Image loading is finished in `self.preprocess`.
    test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                         'LoadImageFromFile')
    test_pipeline = Compose(
        [TRANSFORMS.build(t) for t in test_pipeline_cfg])
    return test_pipeline

View File

@ -88,9 +88,11 @@ class VisualQuestionAnsweringInferencer(BaseInferencer):
def _init_pipeline(self, cfg: Config) -> Callable:
    """Build the inference-time transform pipeline from the test config.

    Every ``LoadImageFromFile`` entry is stripped from the configured
    pipeline because image loading is finished in ``self.preprocess``.
    The lookup is delegated to ``remove_transform`` rather than indexing
    ``test_pipeline_cfg[0]`` directly, so an empty pipeline does not raise
    ``IndexError`` and the loader does not have to be the first entry.

    Args:
        cfg (Config): Config providing
            ``test_dataloader.dataset.pipeline``.

    Returns:
        Callable: A ``Compose`` of the remaining transforms, each built
        through ``TRANSFORMS.build``.
    """
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    # Kept as a function-scope import, exactly as the original code does.
    from mmpretrain.datasets import remove_transform

    # Image loading is finished in `self.preprocess`.
    test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                         'LoadImageFromFile')
    test_pipeline = Compose(
        [TRANSFORMS.build(t) for t in test_pipeline_cfg])
    return test_pipeline

View File

@ -16,6 +16,7 @@ from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
RandomErasing, RandomResizedCrop,
RandomResizedCropAndInterpolationWithTwoPic,
RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
from .utils import get_transform_idx, remove_transform
from .wrappers import ApplyToList, MultiView
for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
@ -34,5 +35,6 @@ __all__ = [
'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
'RandomResizedCropAndInterpolationWithTwoPic'
'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
'remove_transform'
]

View File

@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Union
from mmcv.transforms import BaseTransform
PIPELINE_TYPE = List[Union[dict, BaseTransform]]
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
    """Find the position of the first transform named ``target``.

    Each pipeline entry is reduced to a single value to compare against
    ``target``:

    - a dict whose ``'type'`` is a class contributes that class's
      ``__name__``;
    - a dict whose ``'type'`` is anything else contributes the ``'type'``
      value itself;
    - any non-dict entry contributes the name of its own class.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.

    Returns:
        int: Index of the first matching entry, or -1 if none matches.
    """
    for position, entry in enumerate(pipeline):
        if not isinstance(entry, dict):
            # Non-dict entry: identify it by its class name.
            candidate = entry.__class__.__name__
        else:
            declared = entry['type']
            # A config may carry either the class object or a plain value.
            if isinstance(declared, type):
                candidate = declared.__name__
            else:
                candidate = declared
        if candidate == target:
            return position
    return -1
def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
    """Strip every occurrence of the target transform from a pipeline.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.
        inplace (bool): When true, matching entries are popped from the
            list that was passed in. Otherwise a deep copy is edited and
            the argument is left untouched. Defaults to False.

    Returns:
        The pipeline with all ``target`` entries removed: the input list
        itself when ``inplace`` is true, a deep copy otherwise.
    """
    position = get_transform_idx(pipeline, target)
    result = pipeline if inplace else copy.deepcopy(pipeline)
    # Pop one match per pass and search again until none is left.
    while position >= 0:
        result.pop(position)
        position = get_transform_idx(result, target)
    return result

View File

@ -52,9 +52,12 @@ class DatasetValidator():
def __init__(self, dataset_cfg, log_file_path):
super(DatasetValidator, self).__init__()
# keep only LoadImageFromFile pipeline
assert dataset_cfg.pipeline[0]['type'] == 'LoadImageFromFile', (
'This tool is only for datasets needs to load image from files.')
self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[0])
from mmpretrain.datasets import get_transform_idx
load_idx = get_transform_idx(dataset_cfg.pipeline, 'LoadImageFromFile')
assert load_idx >= 0, \
'This tool is only for datasets needs to load image from files.'
self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[load_idx])
dataset_cfg.pipeline = []
dataset = build_dataset(dataset_cfg)