diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index 3b8adec1e..3dc471696 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -303,10 +303,10 @@ class Segmentation(BaseTask): params = self.model_cfg.model.decode_head if isinstance(params, list): params = params[-1] - postprocess = dict(params=params, type='ResizeMask') with_argmax = get_codebase_config(self.deploy_cfg).get( 'with_argmax', True) - postprocess['with_argmax'] = with_argmax + params['with_argmax'] = with_argmax + postprocess = dict(params=params, type='ResizeMask') return postprocess def get_model_name(self, *args, **kwargs) -> str: