diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 8e91c2f4c..a51bea1ca 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -233,8 +233,6 @@ class Classification(BaseTask): log_file (str | None): The file to write the evaluation results. Defaults to `None` and the results will only print on stdout. """ - import warnings - from mmcv.utils import get_logger logger = get_logger('test', log_file=log_file, log_level=logging.INFO) @@ -243,7 +241,7 @@ class Classification(BaseTask): for k, v in results.items(): logger.info(f'{k} : {v:.2f}') else: - warnings.warn('Evaluation metrics are not specified.') + logger.warning('Evaluation metrics are not specified.') scores = np.vstack(outputs) pred_score = np.max(scores, axis=1) pred_label = np.argmax(scores, axis=1) @@ -281,8 +279,14 @@ class Classification(BaseTask): dict: Composed of the postprocess information. """ postprocess = self.model_cfg.model.head - assert 'topk' in postprocess, 'model config lack topk' - postprocess.topk = max(postprocess.topk) + if 'topk' not in postprocess: + topk = (1, ) + logger = get_root_logger() + logger.warning('no topk in postprocess config, using default \ + topk value.') + else: + topk = postprocess.topk + postprocess.topk = max(topk) return postprocess def get_model_name(self) -> str: diff --git a/requirements/optional.txt b/requirements/optional.txt index 5f5251130..9a077ff56 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,4 +1,4 @@ -mmcls>=0.21.0,<=0.22.1 +mmcls>=0.21.0,<=0.23.0 mmdet>=2.19.0,<=2.20.0 mmedit mmocr>=0.3.0,<=0.4.1 diff --git a/tests/regression/mmcls.yml b/tests/regression/mmcls.yml index c204383d0..78943b63d 100644 --- a/tests/regression/mmcls.yml +++ b/tests/regression/mmcls.yml @@ -144,6 +144,18 @@ models: - *pipeline_pplnn_dynamic_fp32 - *pipeline_openvino_dynamic_fp32 + - name: DenseNet + metafile: configs/densenet/metafile.yml + model_configs: + - configs/densenet/densenet121_4xb256_in1k.py + pipelines: + - *pipeline_ts_fp32 + - *pipeline_ort_dynamic_fp32 + - *pipeline_trt_dynamic_fp16 + - *pipeline_ncnn_static_fp32 + - *pipeline_pplnn_dynamic_fp32 + - *pipeline_openvino_dynamic_fp32 + - name: SE-ResNet metafile: configs/seresnet/metafile.yml model_configs: