# Copyright (c) OpenMMLab. All rights reserved.

from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from mmseg.core import SegDataSample

def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """
    # Rebuild the mapping with every key namespaced as '<prefix>.<key>';
    # values are passed through untouched.
    return {f'{prefix}.{name}': value for name, value in inputs.items()}


def stack_batch(inputs: List[torch.Tensor],
                batch_data_samples: List['SegDataSample'],
                size: Optional[tuple] = None,
                pad_value: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255,
                padding_mode: str = 'constant'
                ) -> Tuple[torch.Tensor, List['SegDataSample']]:
    """Stack multiple inputs to form a batch and pad the images and
    gt_sem_segs to the max shape using the right-bottom padding mode.

    Args:
        inputs (List[Tensor]): The input multiple tensors. Each is a
            CHW 3D-tensor.
        batch_data_samples (list[:obj:`SegDataSample`]): The Data
            Samples. It usually includes information such as `gt_sem_seg`.
        size (tuple, optional): The img crop size. If ``None``, no padding
            is applied. Defaults to None.
        pad_value (int, float): The image padding value. Defaults to 0.
        seg_pad_val (int, float): The ``gt_sem_seg`` padding value.
            Defaults to 255.
        padding_mode (str): Type of padding. Default: 'constant'.

            - constant: pads with a constant value, this value is specified
              with ``pad_value``.

    Returns:
        tuple:

            - Tensor: The stacked (and possibly padded) 4D NCHW tensor.
            - list[:obj:`SegDataSample`]: The data samples whose
              ``gt_sem_seg`` has been padded to match its image.
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len(set([tensor.ndim for tensor in inputs])) == 1, \
        f'Expected the dimensions of all inputs must be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
        f'Expected the channels of all inputs must be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'

    padded_samples = []
    for tensor, data_sample in zip(inputs, batch_data_samples):
        if size is not None:
            # Pad only on the right and the bottom so the image content
            # stays anchored at the top-left corner. F.pad interprets the
            # tuple as (left, right, top, bottom).
            width = max(size[-1] - tensor.shape[-1], 0)
            height = max(size[-2] - tensor.shape[-2], 0)
            padding_size = (0, width, 0, height)
        else:
            padding_size = (0, 0, 0, 0)
        if sum(padding_size) == 0:
            # Nothing to pad; gt_sem_seg is left untouched in this case.
            padded_samples.append(tensor)
        else:
            # pad img
            pad_img = F.pad(
                tensor, padding_size, mode=padding_mode, value=pad_value)
            padded_samples.append(pad_img)
            # pad gt_sem_seg so it keeps the same spatial shape as the
            # padded image
            gt_sem_seg = data_sample.gt_sem_seg.data
            gt_width = max(pad_img.shape[-1] - gt_sem_seg.shape[-1], 0)
            gt_height = max(pad_img.shape[-2] - gt_sem_seg.shape[-2], 0)
            padding_gt_size = (0, gt_width, 0, gt_height)
            # Delete before re-assigning: the data container may forbid
            # overwriting an existing field in place.
            del data_sample.gt_sem_seg.data
            data_sample.gt_sem_seg.data = F.pad(
                gt_sem_seg,
                padding_gt_size,
                mode=padding_mode,
                value=seg_pad_val)

    return torch.stack(padded_samples, dim=0), batch_data_samples