# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from .typing_utils import SampleList


def add_prefix(inputs, prefix):
    """Add a prefix to every key of a dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """

    outputs = dict()
    for name, value in inputs.items():
        outputs[f'{prefix}.{name}'] = value

    return outputs
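

# A minimal usage sketch (illustrative, not part of the original module):
# ``add_prefix`` namespaces the loss dict of one head before it is merged
# with the losses of other heads, so keys from different heads (e.g. the
# decode head vs. an auxiliary head) cannot collide:
#
#   losses = add_prefix({'loss_ce': 0.5, 'acc_seg': 0.9}, 'decode')
#   # -> {'decode.loss_ce': 0.5, 'decode.acc_seg': 0.9}

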
def stack_batch(inputs: List[torch.Tensor],
                data_samples: Optional[SampleList] = None,
                size: Optional[tuple] = None,
                size_divisor: Optional[int] = None,
                pad_val: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255
                ) -> Tuple[torch.Tensor, list]:
    """Stack multiple inputs to form a batch and pad the images and
    ``gt_sem_seg`` maps to the max shape, using right-bottom padding.

    Args:
        inputs (List[Tensor]): The input tensors; each is a CHW 3D tensor.
        data_samples (list[:obj:`SegDataSample`]): The list of data samples.
            It usually includes information such as `gt_sem_seg`.
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (int, float): Padding value for the images. Defaults to 0.
        seg_pad_val (int, float): Padding value for the segmentation maps.
            Defaults to 255.

    Returns:
        Tensor: The stacked 4D tensor of shape (N, C, H, W).
        list: The data samples with padded ``gt_sem_seg`` (or, when
            ``data_samples`` is None, dicts recording the padding applied
            to each image).
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        'Expected all inputs to have the same number of dimensions, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, \
        f'Expected tensor dimension to be 3, but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        'Expected all inputs to have the same number of channels, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'

    # exactly one of size and size_divisor should be given
    assert (size is not None) ^ (size_divisor is not None), \
        'only one of size and size_divisor should be valid'

    padded_inputs = []
    padded_samples = []
    inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
    max_size = np.stack(inputs_sizes).max(0)
    if size_divisor is not None and size_divisor > 1:
        # the last two dims are H, W; both are subject to the divisibility
        # requirement, rounded up to the nearest multiple of size_divisor
        max_size = (max_size +
                    (size_divisor - 1)) // size_divisor * size_divisor

    for i in range(len(inputs)):
        tensor = inputs[i]
        if size is not None:
            width = max(size[-1] - tensor.shape[-1], 0)
            height = max(size[-2] - tensor.shape[-2], 0)
            # (padding_left, padding_right, padding_top, padding_bottom)
            padding_size = (0, width, 0, height)
        elif size_divisor is not None:
            width = max(max_size[-1] - tensor.shape[-1], 0)
            height = max(max_size[-2] - tensor.shape[-2], 0)
            padding_size = (0, width, 0, height)
        else:  # unreachable given the XOR assert above; kept as a safe default
            padding_size = [0, 0, 0, 0]
        # pad img
        pad_img = F.pad(tensor, padding_size, value=pad_val)
        padded_inputs.append(pad_img)
        # pad gt_sem_seg
        if data_samples is not None:
            data_sample = data_samples[i]
            gt_sem_seg = data_sample.gt_sem_seg.data
            # delete the un-padded tensor first so the padded tensor, which
            # has a different shape, can be assigned in its place
            del data_sample.gt_sem_seg.data
            data_sample.gt_sem_seg.data = F.pad(
                gt_sem_seg, padding_size, value=seg_pad_val)
            if 'gt_edge_map' in data_sample:
                gt_edge_map = data_sample.gt_edge_map.data
                del data_sample.gt_edge_map.data
                data_sample.gt_edge_map.data = F.pad(
                    gt_edge_map, padding_size, value=seg_pad_val)
            data_sample.set_metainfo({
                'img_shape': tensor.shape[-2:],
                'pad_shape': data_sample.gt_sem_seg.shape,
                'padding_size': padding_size
            })
            padded_samples.append(data_sample)
        else:
            padded_samples.append(
                dict(
                    img_padding_size=padding_size,
                    pad_shape=pad_img.shape[-2:]))

    return torch.stack(padded_inputs, dim=0), padded_samples
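

# A minimal usage sketch (illustrative, not part of the original module):
# pad two differently sized CHW images so that H and W are divisible by 32,
# then stack them into a single (N, C, H, W) batch. ``data_samples=None``
# keeps the sketch self-contained, so the second return value is a list of
# dicts recording the padding applied to each image.
#
#   imgs = [torch.rand(3, 100, 120), torch.rand(3, 90, 140)]
#   batch, samples = stack_batch(imgs, size_divisor=32)
#   assert batch.shape == (2, 3, 128, 160)  # max H, W rounded up to x32
#   assert samples[0]['pad_shape'] == batch.shape[-2:]
#
#   batch, _ = stack_batch(imgs, size=(160, 160))  # fixed target size instead
#   assert batch.shape == (2, 3, 160, 160)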