[Fix] Fix preprocess_model_config for CIFAR dataset (#1659)
* fix cifar10 for mmcls * remove unnecessary codepull/1681/head
parent
15ad234a7a
commit
cbddf5a458
mmdeploy/codebase/mmcls/deploy
|
@ -63,13 +63,15 @@ def process_model_config(model_cfg: Config,
|
|||
cfg.test_pipeline.pop(0)
|
||||
# check whether input_shape is valid
|
||||
if input_shape is not None:
|
||||
if 'crop_size' in cfg.test_pipeline[2]:
|
||||
crop_size = cfg.test_pipeline[2]['crop_size']
|
||||
if tuple(input_shape) != (crop_size, crop_size):
|
||||
logger = get_root_logger()
|
||||
logger.warning(
|
||||
f'`input shape` should be equal to `crop_size`: {crop_size},\
|
||||
but given: {input_shape}')
|
||||
for pipeline_component in cfg.test_pipeline:
|
||||
if 'Crop' in pipeline_component['type']:
|
||||
if 'crop_size' in pipeline_component:
|
||||
crop_size = pipeline_component['crop_size']
|
||||
if tuple(input_shape) != (crop_size, crop_size):
|
||||
logger = get_root_logger()
|
||||
logger.warning(
|
||||
f'`input shape` should be equal to `crop_size`: {crop_size},\
|
||||
but given: {input_shape}')
|
||||
return cfg
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue