diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 6cb17b298..dad8ba3ff 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -63,13 +63,15 @@ def process_model_config(model_cfg: Config, cfg.test_pipeline.pop(0) # check whether input_shape is valid if input_shape is not None: - if 'crop_size' in cfg.test_pipeline[2]: - crop_size = cfg.test_pipeline[2]['crop_size'] - if tuple(input_shape) != (crop_size, crop_size): - logger = get_root_logger() - logger.warning( - f'`input shape` should be equal to `crop_size`: {crop_size},\ - but given: {input_shape}') + for pipeline_component in cfg.test_pipeline: + if 'Crop' in pipeline_component['type']: + if 'crop_size' in pipeline_component: + crop_size = pipeline_component['crop_size'] + if tuple(input_shape) != (crop_size, crop_size): + logger = get_root_logger() + logger.warning( + f'`input shape` should be equal to `crop_size`: {crop_size},\ + but given: {input_shape}') return cfg