# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from mmseg.core import SegDataSample


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """
    outputs = dict()
    for name, value in inputs.items():
        outputs[f'{prefix}.{name}'] = value

    return outputs
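

# Usage sketch for ``add_prefix`` (illustrative only; the dict contents are
# made up). Prefixing keeps losses from different heads, e.g. decode vs.
# auxiliary, distinguishable after the dicts are merged:
#
#   >>> add_prefix({'loss_ce': 0.2, 'acc': 0.9}, 'decode')
#   {'decode.loss_ce': 0.2, 'decode.acc': 0.9}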


def stack_batch(inputs: List[torch.Tensor],
                batch_data_samples: List[SegDataSample],
                size: Optional[tuple] = None,
                pad_value: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255,
                padding_mode: str = 'constant'
                ) -> Tuple[torch.Tensor, List[SegDataSample]]:
    """Stack multiple inputs to form a batch and pad the images and
    gt_sem_segs to the given size with right-bottom padding.

    Args:
        inputs (List[Tensor]): The input multiple tensors. Each is a
            CHW 3D-tensor.
        batch_data_samples (list[:obj:`SegDataSample`]): The Data
            Samples. It usually includes information such as
            ``gt_sem_seg``.
        size (tuple, optional): The image crop size. If None, no padding
            is applied. Defaults to None.
        pad_value (int, float): The image padding value. Defaults to 0.
        seg_pad_val (int, float): The ``gt_sem_seg`` padding value.
            Defaults to 255.
        padding_mode (str): Type of padding. Defaults to 'constant'.

            - constant: pads with a constant value, which is specified
              with ``pad_value``.

    Returns:
        tuple: A tuple of two elements.

            - batch_inputs (Tensor): The stacked 4D-tensor.
            - batch_data_samples (list[:obj:`SegDataSample`]): The Data
              Samples with ``gt_sem_seg`` padded to the same size as the
              images.
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len(set([tensor.ndim for tensor in inputs])) == 1, \
        f'Expected the dimensions of all inputs to be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len(set([tensor.shape[0] for tensor in inputs])) == 1, \
        f'Expected the channels of all inputs to be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'
    padded_samples = []
    for tensor, data_sample in zip(inputs, batch_data_samples):
        if size is not None:
            # ``F.pad`` takes paddings as (left, right, top, bottom), so
            # (0, width, 0, height) pads only on the right and bottom.
            width = max(size[-1] - tensor.shape[-1], 0)
            height = max(size[-2] - tensor.shape[-2], 0)
            padding_size = (0, width, 0, height)
        else:
            padding_size = (0, 0, 0, 0)

        if sum(padding_size) == 0:
            # The image already matches the target size; nothing to pad.
            padded_samples.append(tensor)
        else:
            # pad img
            pad_img = F.pad(
                tensor, padding_size, mode=padding_mode, value=pad_value)
            padded_samples.append(pad_img)
            # pad gt_sem_seg to the same spatial size as the padded image
            gt_sem_seg = data_sample.gt_sem_seg.data
            gt_width = max(pad_img.shape[-1] - gt_sem_seg.shape[-1], 0)
            gt_height = max(pad_img.shape[-2] - gt_sem_seg.shape[-2], 0)
            padding_gt_size = (0, gt_width, 0, gt_height)
            del data_sample.gt_sem_seg.data
            data_sample.gt_sem_seg.data = F.pad(
                gt_sem_seg,
                padding_gt_size,
                mode=padding_mode,
                value=seg_pad_val)

    return torch.stack(padded_samples, dim=0), batch_data_samples
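

# A minimal usage sketch for ``stack_batch``. Assumption: instead of
# constructing a real SegDataSample, a ``SimpleNamespace`` stand-in is used,
# since the function only touches ``data_sample.gt_sem_seg.data``.
if __name__ == '__main__':
    from types import SimpleNamespace

    def _fake_sample(h: int, w: int) -> SimpleNamespace:
        # Hypothetical stand-in carrying a 1xHxW semantic segmentation map.
        return SimpleNamespace(
            gt_sem_seg=SimpleNamespace(data=torch.zeros(1, h, w).long()))

    imgs = [torch.rand(3, 20, 18), torch.rand(3, 16, 20)]
    samples = [_fake_sample(20, 18), _fake_sample(16, 20)]
    batch, samples = stack_batch(imgs, samples, size=(24, 24))
    assert batch.shape == torch.Size([2, 3, 24, 24])
    # gt_sem_seg is right-bottom padded (with the default seg_pad_val of
    # 255) to the same spatial size as the images.
    assert samples[0].gt_sem_seg.data.shape[-2:] == (24, 24)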