[Fix][New_config] Fix demo bug (#1647)
* Fix demo * Update implement --------- Co-authored-by: mzr1996 <mzr1996@163.com>pull/1653/head
parent
6d7fe91a98
commit
d6056af2b8
|
@ -83,9 +83,11 @@ class FeatureExtractor(BaseInferencer):
|
||||||
|
|
||||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
from mmpretrain.datasets import remove_transform
|
||||||
# Image loading is finished in `self.preprocess`.
|
|
||||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
# Image loading is finished in `self.preprocess`.
|
||||||
|
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||||
|
'LoadImageFromFile')
|
||||||
test_pipeline = Compose(
|
test_pipeline = Compose(
|
||||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||||
return test_pipeline
|
return test_pipeline
|
||||||
|
|
|
@ -70,9 +70,11 @@ class ImageCaptionInferencer(BaseInferencer):
|
||||||
|
|
||||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
from mmpretrain.datasets import remove_transform
|
||||||
# Image loading is finished in `self.preprocess`.
|
|
||||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
# Image loading is finished in `self.preprocess`.
|
||||||
|
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||||
|
'LoadImageFromFile')
|
||||||
test_pipeline = Compose(
|
test_pipeline = Compose(
|
||||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||||
return test_pipeline
|
return test_pipeline
|
||||||
|
|
|
@ -110,9 +110,11 @@ class ImageClassificationInferencer(BaseInferencer):
|
||||||
|
|
||||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
from mmpretrain.datasets import remove_transform
|
||||||
# Image loading is finished in `self.preprocess`.
|
|
||||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
# Image loading is finished in `self.preprocess`.
|
||||||
|
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||||
|
'LoadImageFromFile')
|
||||||
test_pipeline = Compose(
|
test_pipeline = Compose(
|
||||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||||
return test_pipeline
|
return test_pipeline
|
||||||
|
|
|
@ -172,9 +172,11 @@ class ImageRetrievalInferencer(BaseInferencer):
|
||||||
|
|
||||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
from mmpretrain.datasets import remove_transform
|
||||||
# Image loading is finished in `self.preprocess`.
|
|
||||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
# Image loading is finished in `self.preprocess`.
|
||||||
|
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||||
|
'LoadImageFromFile')
|
||||||
test_pipeline = Compose(
|
test_pipeline = Compose(
|
||||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||||
return test_pipeline
|
return test_pipeline
|
||||||
|
|
|
@ -86,9 +86,11 @@ class VisualGroundingInferencer(BaseInferencer):
|
||||||
|
|
||||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
from mmpretrain.datasets import remove_transform
|
||||||
# Image loading is finished in `self.preprocess`.
|
|
||||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
# Image loading is finished in `self.preprocess`.
|
||||||
|
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||||
|
'LoadImageFromFile')
|
||||||
test_pipeline = Compose(
|
test_pipeline = Compose(
|
||||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||||
return test_pipeline
|
return test_pipeline
|
||||||
|
|
|
@ -88,9 +88,11 @@ class VisualQuestionAnsweringInferencer(BaseInferencer):
|
||||||
|
|
||||||
def _init_pipeline(self, cfg: Config) -> Callable:
|
def _init_pipeline(self, cfg: Config) -> Callable:
|
||||||
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||||
if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
|
from mmpretrain.datasets import remove_transform
|
||||||
# Image loading is finished in `self.preprocess`.
|
|
||||||
test_pipeline_cfg = test_pipeline_cfg[1:]
|
# Image loading is finished in `self.preprocess`.
|
||||||
|
test_pipeline_cfg = remove_transform(test_pipeline_cfg,
|
||||||
|
'LoadImageFromFile')
|
||||||
test_pipeline = Compose(
|
test_pipeline = Compose(
|
||||||
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
[TRANSFORMS.build(t) for t in test_pipeline_cfg])
|
||||||
return test_pipeline
|
return test_pipeline
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
|
||||||
RandomErasing, RandomResizedCrop,
|
RandomErasing, RandomResizedCrop,
|
||||||
RandomResizedCropAndInterpolationWithTwoPic,
|
RandomResizedCropAndInterpolationWithTwoPic,
|
||||||
RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
|
RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
|
||||||
|
from .utils import get_transform_idx, remove_transform
|
||||||
from .wrappers import ApplyToList, MultiView
|
from .wrappers import ApplyToList, MultiView
|
||||||
|
|
||||||
for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
|
for t in (CenterCrop, LoadImageFromFile, Normalize, RandomFlip,
|
||||||
|
@ -34,5 +35,6 @@ __all__ = [
|
||||||
'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
|
'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize',
|
||||||
'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
|
'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
|
||||||
'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
|
'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
|
||||||
'RandomResizedCropAndInterpolationWithTwoPic'
|
'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
|
||||||
|
'remove_transform'
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from mmcv.transforms import BaseTransform
|
||||||
|
|
||||||
|
PIPELINE_TYPE = List[Union[dict, BaseTransform]]
|
||||||
|
|
||||||
|
|
||||||
|
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
|
||||||
|
"""Returns the index of the transform in a pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline (List[dict] | List[BaseTransform]): The transforms list.
|
||||||
|
target (str): The target transform class name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The transform index. Returns -1 if not found.
|
||||||
|
"""
|
||||||
|
for i, transform in enumerate(pipeline):
|
||||||
|
if isinstance(transform, dict):
|
||||||
|
if isinstance(transform['type'], type):
|
||||||
|
if transform['type'].__name__ == target:
|
||||||
|
return i
|
||||||
|
else:
|
||||||
|
if transform['type'] == target:
|
||||||
|
return i
|
||||||
|
else:
|
||||||
|
if transform.__class__.__name__ == target:
|
||||||
|
return i
|
||||||
|
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
|
def remove_transform(pipeline: PIPELINE_TYPE, target: str, inplace=False):
|
||||||
|
"""Remove the target transform type from the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline (List[dict] | List[BaseTransform]): The transforms list.
|
||||||
|
target (str): The target transform class name.
|
||||||
|
inplace (bool): Whether to modify the pipeline inplace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The modified transform.
|
||||||
|
"""
|
||||||
|
idx = get_transform_idx(pipeline, target)
|
||||||
|
if not inplace:
|
||||||
|
pipeline = copy.deepcopy(pipeline)
|
||||||
|
while idx >= 0:
|
||||||
|
pipeline.pop(idx)
|
||||||
|
idx = get_transform_idx(pipeline, target)
|
||||||
|
|
||||||
|
return pipeline
|
|
@ -52,9 +52,12 @@ class DatasetValidator():
|
||||||
def __init__(self, dataset_cfg, log_file_path):
|
def __init__(self, dataset_cfg, log_file_path):
|
||||||
super(DatasetValidator, self).__init__()
|
super(DatasetValidator, self).__init__()
|
||||||
# keep only LoadImageFromFile pipeline
|
# keep only LoadImageFromFile pipeline
|
||||||
assert dataset_cfg.pipeline[0]['type'] == 'LoadImageFromFile', (
|
from mmpretrain.datasets import get_transform_idx
|
||||||
'This tool is only for datasets needs to load image from files.')
|
|
||||||
self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[0])
|
load_idx = get_transform_idx(dataset_cfg.pipeline, 'LoadImageFromFile')
|
||||||
|
assert load_idx >= 0, \
|
||||||
|
'This tool is only for datasets needs to load image from files.'
|
||||||
|
self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[load_idx])
|
||||||
dataset_cfg.pipeline = []
|
dataset_cfg.pipeline = []
|
||||||
dataset = build_dataset(dataset_cfg)
|
dataset = build_dataset(dataset_cfg)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue