From 70daaaad59261a00b44d11ed8d130fd721556e7e Mon Sep 17 00:00:00 2001 From: xiexinch Date: Thu, 10 Nov 2022 14:21:05 +0800 Subject: [PATCH] support padding in test and fix remove gt padding at post_process --- mmseg/models/data_preprocessor.py | 66 ++++++++++++++++++++----------- mmseg/models/segmentors/base.py | 9 ++++- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/mmseg/models/data_preprocessor.py b/mmseg/models/data_preprocessor.py index 34087d0c0..f3e0b02c4 100644 --- a/mmseg/models/data_preprocessor.py +++ b/mmseg/models/data_preprocessor.py @@ -48,18 +48,28 @@ class SegDataPreProcessor(BaseDataPreprocessor): rgb_to_bgr (bool): whether to convert image from RGB to RGB. Defaults to False. batch_augments (list[dict], optional): Batch-level augmentations + train_cfg (dict, optional): The padding size config in training, if + not specify, will use `size` and `size_divisor` params as default. + Defaults to None, only supports keys `size` or `size_divisor`. + test_cfg (dict, optional): The padding size config in testing, if not + specify, will use `size` and `size_divisor` params as default. + Defaults to None, only supports keys `size` or `size_divisor`. """ - def __init__(self, - mean: Sequence[Number] = None, - std: Sequence[Number] = None, - size: Optional[tuple] = None, - size_divisor: Optional[int] = None, - pad_val: Number = 0, - seg_pad_val: Number = 255, - bgr_to_rgb: bool = False, - rgb_to_bgr: bool = False, - batch_augments: Optional[List[dict]] = None): + def __init__( + self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + size: Optional[tuple] = None, + size_divisor: Optional[int] = None, + pad_val: Number = 0, + seg_pad_val: Number = 255, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + batch_augments: Optional[List[dict]] = None, + train_cfg: dict = None, + test_cfg: dict = None, + ): super().__init__() self.size = size self.size_divisor = size_divisor @@ -86,6 +96,11 @@ class SegDataPreProcessor(BaseDataPreprocessor): # TODO: support batch augmentations. self.batch_augments = batch_augments + # Support different padding methods in training and testing + default_size_cfg = dict(size=size, size_divisor=size_divisor) + self.train_cfg = train_cfg if train_cfg else default_size_cfg + self.test_cfg = test_cfg if test_cfg else default_size_cfg + def forward(self, data: dict, training: bool = False) -> Dict[str, Any]: """Perform normalization、padding and bgr2rgb conversion based on ``BaseDataPreprocessor``. @@ -111,21 +126,24 @@ class SegDataPreProcessor(BaseDataPreprocessor): if training: assert data_samples is not None, ('During training, ', '`data_samples` must be define.') - inputs, data_samples = stack_batch( - inputs=inputs, - data_samples=data_samples, - size=self.size, - size_divisor=self.size_divisor, - pad_val=self.pad_val, - seg_pad_val=self.seg_pad_val) - - if self.batch_augments is not None: - inputs, data_samples = self.batch_augments( - inputs, data_samples) - return dict(inputs=inputs, data_samples=data_samples) else: assert len(inputs) == 1, ( 'Batch inference is not support currently, ' 'as the image size might be different in a batch') - return dict( - inputs=torch.stack(inputs, dim=0), data_samples=data_samples) + + size_cfg = self.train_cfg if training else self.test_cfg + size = size_cfg.get('size', None) + size_divisor = size_cfg.get('size_divisor', None) + + inputs, data_samples = stack_batch( + inputs=inputs, + data_samples=data_samples, + size=size, + size_divisor=size_divisor, + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + + if self.batch_augments is not None: + inputs, data_samples = self.batch_augments(inputs, data_samples) + + return dict(inputs=inputs, data_samples=data_samples) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index dfceddd99..66bfb1424 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -165,6 +165,11 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta): i_seg_logits = seg_logits[i:i + 1, :, padding_top:H - padding_bottom, padding_left:W - padding_right] + i_gt_sem_seg = data_samples[i].gt_sem_seg[:, padding_top:H - + padding_bottom, + padding_left:W - + padding_right] + # resize as original shape i_seg_logits = resize( i_seg_logits, @@ -184,7 +189,9 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta): 'seg_logits': PixelData(**{'data': i_seg_logits}), 'pred_sem_seg': - PixelData(**{'data': i_seg_pred}) + PixelData(**{'data': i_seg_pred}), + 'gt_sem_seg': + PixelData(**{'data': i_gt_sem_seg}) }) return data_samples