mirror of https://github.com/open-mmlab/mmsegmentation.git
support padding in test and fix remove gt padding at post_process
parent 7927591a22
commit 70daaaad59
@@ -48,18 +48,28 @@ class SegDataPreProcessor(BaseDataPreprocessor):
         rgb_to_bgr (bool): whether to convert image from RGB to BGR.
             Defaults to False.
         batch_augments (list[dict], optional): Batch-level augmentations
+        train_cfg (dict, optional): The padding size config in training, if
+            not specified, will use `size` and `size_divisor` params as default.
+            Defaults to None, only supports keys `size` or `size_divisor`.
+        test_cfg (dict, optional): The padding size config in testing, if not
+            specified, will use `size` and `size_divisor` params as default.
+            Defaults to None, only supports keys `size` or `size_divisor`.
     """

-    def __init__(self,
-                 mean: Sequence[Number] = None,
-                 std: Sequence[Number] = None,
-                 size: Optional[tuple] = None,
-                 size_divisor: Optional[int] = None,
-                 pad_val: Number = 0,
-                 seg_pad_val: Number = 255,
-                 bgr_to_rgb: bool = False,
-                 rgb_to_bgr: bool = False,
-                 batch_augments: Optional[List[dict]] = None):
+    def __init__(
+        self,
+        mean: Sequence[Number] = None,
+        std: Sequence[Number] = None,
+        size: Optional[tuple] = None,
+        size_divisor: Optional[int] = None,
+        pad_val: Number = 0,
+        seg_pad_val: Number = 255,
+        bgr_to_rgb: bool = False,
+        rgb_to_bgr: bool = False,
+        batch_augments: Optional[List[dict]] = None,
+        train_cfg: dict = None,
+        test_cfg: dict = None,
+    ):
         super().__init__()
         self.size = size
         self.size_divisor = size_divisor
@@ -86,6 +96,11 @@ class SegDataPreProcessor(BaseDataPreprocessor):
         # TODO: support batch augmentations.
         self.batch_augments = batch_augments

+        # Support different padding methods in training and testing
+        default_size_cfg = dict(size=size, size_divisor=size_divisor)
+        self.train_cfg = train_cfg if train_cfg else default_size_cfg
+        self.test_cfg = test_cfg if test_cfg else default_size_cfg
+
     def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
         """Perform normalization, padding and bgr2rgb conversion based on
         ``BaseDataPreprocessor``.
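With the new `train_cfg`/`test_cfg` options, padding can differ between training and testing. A minimal, hypothetical config sketch based on the docstring above (the mean/std values and sizes are placeholders, not taken from this commit):

data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],  # placeholder normalization stats
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=(512, 512),                 # training falls back to this fixed size
    test_cfg=dict(size_divisor=32),  # at test time, pad to a multiple of 32 instead
)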
@@ -111,21 +126,24 @@ class SegDataPreProcessor(BaseDataPreprocessor):
         if training:
             assert data_samples is not None, ('During training, ',
                                               '`data_samples` must be defined.')
-            inputs, data_samples = stack_batch(
-                inputs=inputs,
-                data_samples=data_samples,
-                size=self.size,
-                size_divisor=self.size_divisor,
-                pad_val=self.pad_val,
-                seg_pad_val=self.seg_pad_val)
-
-            if self.batch_augments is not None:
-                inputs, data_samples = self.batch_augments(
-                    inputs, data_samples)
-            return dict(inputs=inputs, data_samples=data_samples)
         else:
             assert len(inputs) == 1, (
                 'Batch inference is not supported currently, '
                 'as the image size might be different in a batch')
-            return dict(
-                inputs=torch.stack(inputs, dim=0), data_samples=data_samples)
+
+        size_cfg = self.train_cfg if training else self.test_cfg
+        size = size_cfg.get('size', None)
+        size_divisor = size_cfg.get('size_divisor', None)
+
+        inputs, data_samples = stack_batch(
+            inputs=inputs,
+            data_samples=data_samples,
+            size=size,
+            size_divisor=size_divisor,
+            pad_val=self.pad_val,
+            seg_pad_val=self.seg_pad_val)
+
+        if self.batch_augments is not None:
+            inputs, data_samples = self.batch_augments(inputs, data_samples)
+
+        return dict(inputs=inputs, data_samples=data_samples)
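A standalone sketch of how the padding config is resolved in `forward`: the helper below only mirrors the selection and fallback logic shown in the diff and is not part of the library.

from typing import Optional

def select_pad_cfg(training: bool,
                   train_cfg: Optional[dict],
                   test_cfg: Optional[dict],
                   size=None,
                   size_divisor=None):
    # Unset cfgs fall back to the constructor-level `size`/`size_divisor`,
    # matching SegDataPreProcessor's default_size_cfg behaviour.
    default_size_cfg = dict(size=size, size_divisor=size_divisor)
    train_cfg = train_cfg if train_cfg else default_size_cfg
    test_cfg = test_cfg if test_cfg else default_size_cfg
    size_cfg = train_cfg if training else test_cfg
    return size_cfg.get('size', None), size_cfg.get('size_divisor', None)

# Training uses the fixed size; testing uses the divisor from test_cfg.
assert select_pad_cfg(True, None, dict(size_divisor=32), size=(512, 512)) == ((512, 512), None)
assert select_pad_cfg(False, None, dict(size_divisor=32), size=(512, 512)) == (None, 32)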
@@ -165,6 +165,11 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
                 i_seg_logits = seg_logits[i:i + 1, :,
                                           padding_top:H - padding_bottom,
                                           padding_left:W - padding_right]
+                i_gt_sem_seg = data_samples[i].gt_sem_seg[:, padding_top:H -
+                                                          padding_bottom,
+                                                          padding_left:W -
+                                                          padding_right]
+
                 # resize as original shape
                 i_seg_logits = resize(
                     i_seg_logits,
@@ -184,7 +189,9 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
                     'seg_logits':
                     PixelData(**{'data': i_seg_logits}),
                     'pred_sem_seg':
-                    PixelData(**{'data': i_seg_pred})
+                    PixelData(**{'data': i_seg_pred}),
+                    'gt_sem_seg':
+                    PixelData(**{'data': i_gt_sem_seg})
                 })

         return data_samples
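A toy, self-contained sketch of the crop that the post-processing change performs: the padded border is stripped from both the logits and the padded ground truth so only valid pixels reach evaluation. Tensor shapes and padding values below are made up for illustration.

import torch

H, W = 520, 520                                # padded input size
padding_top, padding_bottom = 4, 4             # hypothetical padding info
padding_left, padding_right = 4, 4

seg_logits = torch.randn(1, 19, H, W)          # fake logits, 19 classes
gt_sem_seg = torch.randint(0, 19, (1, H, W))   # fake padded ground truth

i_seg_logits = seg_logits[0:1, :,
                          padding_top:H - padding_bottom,
                          padding_left:W - padding_right]
i_gt_sem_seg = gt_sem_seg[:,
                          padding_top:H - padding_bottom,
                          padding_left:W - padding_right]

print(i_seg_logits.shape)  # torch.Size([1, 19, 512, 512])
print(i_gt_sem_seg.shape)  # torch.Size([1, 512, 512])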