[Refactor] data flow (#1956)

* [WIP] Refactor data flow

* model return

* [WIP] Refactor data flow

* support data_samples is optional

* fix benchmark

* fix base

* minors

* rebase

* fix api

* ut

* fix api inference

* comments

* docstring

* docstring

* docstring

* fix bug of slide inference

* add assert c > 1
Miao Zheng 2022-08-26 15:54:23 +08:00 committed by GitHub
parent 50546da85c
commit 8de0050f25
21 changed files with 522 additions and 521 deletions
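For orientation: after this refactor the dataloader yields one dict with 'inputs' and 'data_samples' lists, the data preprocessor returns one dict, and the model consumes it via keyword expansion. A minimal sketch, assuming `segmentor` is an already-built mmseg segmentor (hypothetical placeholder); the call pattern mirrors the updated tests below.

import torch

# `segmentor` is a hypothetical, already-built mmseg segmentor.
packed_inputs = dict(
    inputs=[torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)])
# One dict in, one dict out; data_samples may be absent at inference time.
data = segmentor.data_preprocessor(packed_inputs, False)
results = segmentor(**data, mode='predict')  # -> list[SegDataSample]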

View File

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Optional, Sequence, Union
@@ -11,9 +12,9 @@ from mmengine.dataset import Compose
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist
from mmseg.data import SegDataSample
from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
from mmseg.visualization import SegLocalVisualizer
@@ -50,7 +51,6 @@ def init_model(config: Union[str, Path, Config],
model = MODELS.build(config.model)
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
# save the dataset_meta in the model for convenience
if 'dataset_meta' in checkpoint.get('meta', {}):
@@ -108,14 +108,15 @@ def _preprare_data(imgs: ImageType, model: BaseSegmentor):
# a pipeline for each inference
pipeline = Compose(cfg.test_pipeline)
data = []
data = defaultdict(list)
for img in imgs:
if isinstance(img, np.ndarray):
data_ = dict(img=img)
else:
data_ = dict(img_path=img)
data_ = pipeline(data_)
data.append(data_)
data['inputs'].append(data_['inputs'])
data['data_samples'].append(data_['data_samples'])
return data, is_batch
@@ -187,11 +188,12 @@ def show_result_pyplot(model: BaseSegmentor,
save_dir=save_dir,
alpha=opacity)
visualizer.dataset_meta = dict(
classes=model.CLASSES, palette=model.PALETTE)
classes=model.dataset_meta['classes'],
palette=model.dataset_meta['palette'])
visualizer.add_datasample(
name=title,
image=image,
pred_sample=result[0],
data_sample=result[0],
draw_gt=draw_gt,
draw_pred=draw_pred,
wait_time=wait_time,
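A standalone sketch of the collation pattern `_preprare_data` now uses: per-image pipeline outputs are merged into one batch-level dict of lists (illustrative placeholder values only).

from collections import defaultdict

pipeline_outputs = [
    dict(inputs='tensor_0', data_samples='sample_0'),
    dict(inputs='tensor_1', data_samples='sample_1'),
]
data = defaultdict(list)
for data_ in pipeline_outputs:
    data['inputs'].append(data_['inputs'])
    data['data_samples'].append(data_['data_samples'])
assert dict(data) == dict(
    inputs=['tensor_0', 'tensor_1'],
    data_samples=['sample_0', 'sample_1'])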

View File

@@ -78,7 +78,7 @@ class PackSegInputs(BaseTransform):
if key in results:
img_meta[key] = results[key]
data_sample.set_metainfo(img_meta)
packed_results['data_sample'] = data_sample
packed_results['data_samples'] = data_sample
return packed_results

View File

@@ -66,7 +66,7 @@ class SegVisualizationHook(Hook):
def _after_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Sequence[dict],
data_batch: dict,
outputs: Sequence[SegDataSample],
mode: str = 'val') -> None:
"""Run after every ``self.interval`` validation iterations.
@@ -74,7 +74,7 @@ class SegVisualizationHook(Hook):
Args:
runner (:obj:`Runner`): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict]): Data from dataloader.
data_batch (dict): Data from dataloader.
outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
mode (str): Current mode of runner. Defaults to 'val'.
"""
@@ -85,18 +85,16 @@ class SegVisualizationHook(Hook):
self.file_client = FileClient(**self.file_client_args)
if self.every_n_inner_iters(batch_idx, self.interval):
for input_data, output in zip(data_batch, outputs):
img_path = input_data['data_sample'].img_path
for output in outputs:
img_path = output.img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
window_name = f'{mode}_{osp.basename(img_path)}'
gt_sample = input_data['data_sample']
self._visualizer.add_datasample(
window_name,
img,
gt_sample=gt_sample,
pred_sample=output,
data_sample=output,
show=self.show,
wait_time=self.wait_time,
step=runner.iter)

View File

@@ -49,25 +49,24 @@ class CitysMetric(BaseMetric):
self.to_label_id = to_label_id
self.suffix = suffix
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
"""Process one batch of data and predictions.
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
mkdir_or_exist(self.suffix)
for pred in predictions:
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
# results2img
if self.to_label_id:
pred_label = self._convert_to_label_id(pred_label)
basename = osp.splitext(osp.basename(pred['img_path']))[0]
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
png_filename = osp.join(self.suffix, f'{basename}.png')
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
import cityscapesscripts.helpers.labels as CSLabels

View File

@@ -47,22 +47,20 @@ class IoUMetric(BaseMetric):
self.nan_to_num = nan_to_num
self.beta = beta
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
"""Process one batch of data and predictions.
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
num_classes = len(self.dataset_meta['classes'])
for data, pred in zip(data_batch, predictions):
pred_label = pred['pred_sem_seg']['data'].squeeze()
label = data['data_sample']['gt_sem_seg']['data'].squeeze().to(
pred_label)
for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'].squeeze()
label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label)
self.results.append(
self.intersect_and_union(pred_label, label, num_classes,
self.ignore_index))
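A hedged usage sketch of the updated `process` signature, adapted from the IoUMetric unit test changed later in this commit; the metric never touches `data_batch`, so a placeholder list of matching length suffices.

import torch
from mmseg.evaluation import IoUMetric

data_samples = [
    dict(
        pred_sem_seg=dict(data=torch.randint(0, 5, (1, 4, 4))),
        gt_sem_seg=dict(data=torch.randint(0, 5, (1, 4, 4))))
]
metric = IoUMetric(iou_metrics=['mIoU'])
metric.dataset_meta = dict(
    classes=['wall', 'building', 'sky', 'floor', 'tree'],
    label_map=dict(),
    reduce_zero_label=False)
# data_batch is unused by the metric, hence the placeholder list.
metric.process(data_batch=[0] * len(data_samples), data_samples=data_samples)
print(metric.evaluate(len(data_samples)))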

View File

@@ -1,13 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence
import torch
from mmengine.model import BaseDataPreprocessor
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import OptSampleList, stack_batch
from mmseg.utils import stack_batch
@MODELS.register_module()
@@ -87,22 +86,20 @@ class SegDataPreProcessor(BaseDataPreprocessor):
# TODO: support batch augmentations.
self.batch_augments = batch_augments
def forward(self,
data: Sequence[dict],
training: bool = False) -> Tuple[Tensor, OptSampleList]:
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
"""Perform normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
data (dict): data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
model input.
Dict: Data in the same format as the model input.
"""
inputs, batch_data_samples = self.collate_data(data)
data = self.cast_data(data) # type: ignore
inputs = data['inputs']
data_samples = data.get('data_samples', None)
# TODO: whether normalize should be after stack_batch
if self.channel_conversion and inputs[0].size(0) == 3:
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
@@ -113,20 +110,23 @@ class SegDataPreProcessor(BaseDataPreprocessor):
inputs = [_input.float() for _input in inputs]
if training:
batch_inputs, batch_data_samples = stack_batch(
assert data_samples is not None, (
'During training, `data_samples` must be defined.')
inputs, data_samples = stack_batch(
inputs=inputs,
batch_data_samples=batch_data_samples,
data_samples=data_samples,
size=self.size,
size_divisor=self.size_divisor,
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
if self.batch_augments is not None:
inputs, batch_data_samples = self.batch_augments(
inputs, batch_data_samples)
return batch_inputs, batch_data_samples
inputs, data_samples = self.batch_augments(
inputs, data_samples)
return dict(inputs=inputs, data_samples=data_samples)
else:
assert len(inputs) == 1, (
'Batch inference is not supported currently, '
'as the image size might be different in a batch')
return torch.stack(inputs, dim=0), batch_data_samples
return dict(
inputs=torch.stack(inputs, dim=0), data_samples=data_samples)
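The corresponding call pattern, adapted from the SegDataPreProcessor unit test updated below: a single dict in, a single dict out.

import torch
from mmengine.structures import PixelData
from mmseg.models import SegDataPreProcessor
from mmseg.structures import SegDataSample

data_samples = []
for _ in range(2):
    sample = SegDataSample()
    sample.gt_sem_seg = PixelData(data=torch.randint(0, 10, (1, 11, 10)))
    data_samples.append(sample)
processor = SegDataPreProcessor(mean=[0, 0, 0], std=[1, 1, 1], size=(20, 20))
data = dict(
    inputs=[torch.randint(0, 256, (3, 11, 10)) for _ in range(2)],
    data_samples=data_samples)
out = processor(data, training=True)
assert out['inputs'].shape == (2, 3, 20, 20)
assert len(out['data_samples']) == 2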

View File

@@ -47,20 +47,19 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
return hasattr(self, 'decode_head') and self.decode_head is not None
@abstractmethod
def extract_feat(self, batch_inputs: Tensor) -> bool:
def extract_feat(self, inputs: Tensor) -> bool:
"""Placeholder for extract features from images."""
pass
@abstractmethod
def encode_decode(self, batch_inputs: Tensor,
batch_data_samples: SampleList):
def encode_decode(self, inputs: Tensor, batch_data_samples: SampleList):
"""Placeholder for encode images with backbone and decode into a
semantic segmentation map of the same size as input."""
pass
def forward(self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None,
inputs: Tensor,
data_samples: OptSampleList = None,
mode: str = 'tensor') -> ForwardResults:
"""The unified entry for a forward process in both training and test.
@@ -77,10 +76,11 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
optimizer updating, which are done in the :meth:`train_step`.
Args:
batch_inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
batch_data_samples (list[:obj:`SegDataSample`], optional): The
annotation data of every samples. Defaults to None.
inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
general.
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
It usually includes information such as `metainfo` and
`gt_sem_seg`. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
@@ -91,33 +91,32 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'loss':
return self.loss(batch_inputs, batch_data_samples)
return self.loss(inputs, data_samples)
elif mode == 'predict':
return self.predict(batch_inputs, batch_data_samples)
return self.predict(inputs, data_samples)
elif mode == 'tensor':
return self._forward(batch_inputs, batch_data_samples)
return self._forward(inputs, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
@abstractmethod
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> dict:
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples."""
pass
@abstractmethod
def predict(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> SampleList:
def predict(self,
inputs: Tensor,
data_samples: OptSampleList = None) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing."""
pass
@abstractmethod
def _forward(
self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
def _forward(self,
inputs: Tensor,
data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
"""Network forward process.
Usually includes backbone, neck and head forward without any post-
@@ -130,13 +129,16 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
"""Placeholder for augmentation test."""
pass
def postprocess_result(self, seg_logits_list: List[dict],
batch_img_metas: List[dict]) -> list:
def postprocess_result(self,
seg_logits: Tensor,
data_samples: OptSampleList = None) -> list:
""" Convert results list to `SegDataSample`.
Args:
seg_logits_list (List[dict]): List of segmentation results,
seg_logits from model of each input image.
seg_logits (Tensor): The segmentation results, seg_logits from
model of each input image.
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
It usually includes information such as `metainfo` and
`gt_sem_seg`. Defaults to None.
Returns:
list[:obj:`SegDataSample`]: Segmentation results of the
input images. Each SegDataSample usually contain:
@@ -145,22 +147,50 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
- ``seg_logits``(PixelData): Predicted logits of semantic
segmentation before normalization.
"""
predictions = []
batch_size, C, H, W = seg_logits.shape
assert C > 1, ('This post process does not support binary segmentation, '
f'and the channels of `seg_logits` must be > 1 but got {C}')
for i in range(len(seg_logits_list)):
img_meta = batch_img_metas[i]
seg_logits = resize(
seg_logits_list[i][None],
size=img_meta['ori_shape'],
mode='bilinear',
align_corners=self.align_corners,
warning=False).squeeze(0)
# seg_logits shape is CHW
seg_pred = seg_logits.argmax(dim=0, keepdim=True)
prediction = SegDataSample(**{'metainfo': img_meta})
prediction.set_data({
'seg_logits': PixelData(**{'data': seg_logits}),
'pred_sem_seg': PixelData(**{'data': seg_pred})
})
predictions.append(prediction)
return predictions
if data_samples is None:
data_samples = []
only_prediction = True
else:
only_prediction = False
for i in range(batch_size):
if not only_prediction:
img_meta = data_samples[i].metainfo
# remove padding area
padding_left, padding_right, padding_top, padding_bottom = \
img_meta.get('padding_size', [0]*4)
# i_seg_logits shape is 1, C, H, W after removing padding
i_seg_logits = seg_logits[i:i + 1, :,
padding_top:H - padding_bottom,
padding_left:W - padding_right]
# resize to original shape
i_seg_logits = resize(
i_seg_logits,
size=img_meta['ori_shape'],
mode='bilinear',
align_corners=self.align_corners,
warning=False).squeeze(0)
# i_seg_logits shape is C, H, W with original shape
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
data_samples[i].set_data({
'seg_logits':
PixelData(**{'data': i_seg_logits}),
'pred_sem_seg':
PixelData(**{'data': i_seg_pred})
})
else:
i_seg_logits = seg_logits[i]
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
prediction = SegDataSample()
prediction.set_data({
'seg_logits':
PixelData(**{'data': i_seg_logits}),
'pred_sem_seg':
PixelData(**{'data': i_seg_pred})
})
data_samples.append(prediction)
return data_samples
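A self-contained sketch of the padding-removal step above, assuming `padding_size` stores [left, right, top, bottom] as written by `stack_batch`:

import torch

N, C, H, W = 1, 2, 8, 8
left, right, top, bottom = 1, 2, 3, 0
seg_logits = torch.randn(N, C, H, W)
# Slice off the padded border before resizing back to the original shape.
valid = seg_logits[0:1, :, top:H - bottom, left:W - right]
assert valid.shape == (1, C, H - top - bottom, W - left - right)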

View File

@@ -69,11 +69,11 @@ class CascadeEncoderDecoder(EncoderDecoder):
self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes
def encode_decode(self, batch_inputs: Tensor,
def encode_decode(self, inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
x = self.extract_feat(batch_inputs)
x = self.extract_feat(inputs)
out = self.decode_head[0].forward(x)
for i in range(1, self.num_stages - 1):
out = self.decode_head[i].forward(x, out)
@@ -82,53 +82,52 @@ class CascadeEncoderDecoder(EncoderDecoder):
return seg_logits_list
def _decode_head_forward_train(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> dict:
def _decode_head_forward_train(self, inputs: Tensor,
data_samples: SampleList) -> dict:
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head[0].loss(batch_inputs,
batch_data_samples,
loss_decode = self.decode_head[0].loss(inputs, data_samples,
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode_0'))
# get batch_img_metas
batch_size = len(batch_data_samples)
batch_size = len(data_samples)
batch_img_metas = []
for batch_index in range(batch_size):
metainfo = batch_data_samples[batch_index].metainfo
metainfo = data_samples[batch_index].metainfo
batch_img_metas.append(metainfo)
for i in range(1, self.num_stages):
# forward test again, maybe unnecessary for most methods.
if i == 1:
prev_outputs = self.decode_head[0].forward(batch_inputs)
prev_outputs = self.decode_head[0].forward(inputs)
else:
prev_outputs = self.decode_head[i - 1].forward(
batch_inputs, prev_outputs)
loss_decode = self.decode_head[i].loss(batch_inputs, prev_outputs,
batch_data_samples,
inputs, prev_outputs)
loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
data_samples,
self.train_cfg)
losses.update(add_prefix(loss_decode, f'decode_{i}'))
return losses
def _forward(self,
batch_inputs: Tensor,
inputs: Tensor,
data_samples: OptSampleList = None) -> Tensor:
"""Network forward process.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `img_metas` and `gt_semantic_seg`.
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (List[:obj:`SegDataSample`]): The seg data samples.
It usually includes information such as `metainfo` and
`gt_semantic_seg`.
Returns:
Tensor: Forward output of model without any post-processes.
"""
x = self.extract_feat(batch_inputs)
x = self.extract_feat(inputs)
out = self.decode_head[0].forward(x)
for i in range(1, self.num_stages):

View File

@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
@@ -112,93 +111,87 @@ class EncoderDecoder(BaseSegmentor):
else:
self.auxiliary_head = MODELS.build(auxiliary_head)
def extract_feat(self, batch_inputs: Tensor) -> List[Tensor]:
def extract_feat(self, inputs: Tensor) -> List[Tensor]:
"""Extract features from images."""
x = self.backbone(batch_inputs)
x = self.backbone(inputs)
if self.with_neck:
x = self.neck(x)
return x
def encode_decode(self, batch_inputs: Tensor,
def encode_decode(self, inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
x = self.extract_feat(batch_inputs)
x = self.extract_feat(inputs)
seg_logits = self.decode_head.predict(x, batch_img_metas,
self.test_cfg)
return list(seg_logits)
return seg_logits
def _decode_head_forward_train(self, batch_inputs: List[Tensor],
batch_data_samples: SampleList) -> dict:
def _decode_head_forward_train(self, inputs: List[Tensor],
data_samples: SampleList) -> dict:
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.loss(batch_inputs, batch_data_samples,
loss_decode = self.decode_head.loss(inputs, data_samples,
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode'))
return losses
def _auxiliary_head_forward_train(
self,
batch_inputs: List[Tensor],
batch_data_samples: SampleList,
) -> dict:
def _auxiliary_head_forward_train(self, inputs: List[Tensor],
data_samples: SampleList) -> dict:
"""Run forward function and calculate loss for auxiliary head in
training."""
losses = dict()
if isinstance(self.auxiliary_head, nn.ModuleList):
for idx, aux_head in enumerate(self.auxiliary_head):
loss_aux = aux_head.loss(batch_inputs, batch_data_samples,
self.train_cfg)
loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
losses.update(add_prefix(loss_aux, f'aux_{idx}'))
else:
loss_aux = self.auxiliary_head.loss(batch_inputs,
batch_data_samples,
loss_aux = self.auxiliary_head.loss(inputs, data_samples,
self.train_cfg)
losses.update(add_prefix(loss_aux, 'aux'))
return losses
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> dict:
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
img (Tensor): Input images.
batch_data_samples (list[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
inputs (Tensor): Input images.
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
It usually includes information such as `metainfo` and
`gt_sem_seg`.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(batch_inputs)
x = self.extract_feat(inputs)
losses = dict()
loss_decode = self._decode_head_forward_train(x, batch_data_samples)
loss_decode = self._decode_head_forward_train(x, data_samples)
losses.update(loss_decode)
if self.with_auxiliary_head:
loss_aux = self._auxiliary_head_forward_train(
x, batch_data_samples)
loss_aux = self._auxiliary_head_forward_train(x, data_samples)
losses.update(loss_aux)
return losses
def predict(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> SampleList:
def predict(self,
inputs: Tensor,
data_samples: OptSampleList = None) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (List[:obj:`SegDataSample`], optional): The seg data
samples. It usually includes information such as `metainfo`
and `gt_sem_seg`.
Returns:
list[:obj:`SegDataSample`]: Segmentation results of the
@@ -208,40 +201,49 @@ class EncoderDecoder(BaseSegmentor):
- ``seg_logits``(PixelData): Predicted logits of semantic
segmentation before normalization.
"""
batch_img_metas = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
if data_samples is not None:
batch_img_metas = [
data_sample.metainfo for data_sample in data_samples
]
else:
batch_img_metas = [
dict(
ori_shape=inputs.shape[2:],
img_shape=inputs.shape[2:],
pad_shape=inputs.shape[2:],
padding_size=[0, 0, 0, 0])
] * inputs.shape[0]
seg_logit_list = self.inference(batch_inputs, batch_img_metas)
seg_logits = self.inference(inputs, batch_img_metas)
return self.postprocess_result(seg_logit_list, batch_img_metas)
return self.postprocess_result(seg_logits, data_samples)
def _forward(self,
batch_inputs: Tensor,
inputs: Tensor,
data_samples: OptSampleList = None) -> Tensor:
"""Network forward process.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (List[:obj:`SegDataSample`]): The seg
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (List[:obj:`SegDataSample`]): The seg
data samples. It usually includes information such
as `metainfo` and `gt_sem_seg`.
Returns:
Tensor: Forward output of model without any post-processes.
"""
x = self.extract_feat(batch_inputs)
x = self.extract_feat(inputs)
return self.decode_head.forward(x)
def slide_inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
def slide_inference(self, inputs: Tensor,
batch_img_metas: List[dict]) -> Tensor:
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
decode without padding.
Args:
batch_inputs (tensor): the tensor should have a shape NxCxHxW,
inputs (tensor): the tensor should have a shape NxCxHxW,
which contains all images in the batch.
batch_img_metas (List[dict]): List of image metainfo where each may
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
@@ -250,18 +252,18 @@ class EncoderDecoder(BaseSegmentor):
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
Returns:
List[:obj:`Tensor`]: List of segmentation results, seg_logits from
model of each input image.
Tensor: The segmentation results, seg_logits from model of each
input image.
"""
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = batch_inputs.size()
batch_size, _, h_img, w_img = inputs.size()
num_classes = self.num_classes
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = batch_inputs.new_zeros((batch_size, num_classes, h_img, w_img))
count_mat = batch_inputs.new_zeros((batch_size, 1, h_img, w_img))
preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img))
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
y1 = h_idx * h_stride
@@ -270,30 +272,29 @@ class EncoderDecoder(BaseSegmentor):
x2 = min(x1 + w_crop, w_img)
y1 = max(y2 - h_crop, 0)
x1 = max(x2 - w_crop, 0)
crop_img = batch_inputs[:, :, y1:y2, x1:x2]
# change the img shape to patch shape
crop_img = inputs[:, :, y1:y2, x1:x2]
# change the image shape to patch shape
batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
# the output of encode_decode is list of seg logits map
# with shape [C, H, W]
crop_seg_logit = torch.stack(
self.encode_decode(crop_img, batch_img_metas), dim=0)
# the output of encode_decode is a seg logits tensor
# with shape [N, C, H, W]
crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
preds += F.pad(crop_seg_logit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
seg_logits_list = list(preds / count_mat)
seg_logits = preds / count_mat
return seg_logits_list
return seg_logits
def whole_inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
def whole_inference(self, inputs: Tensor,
batch_img_metas: List[dict]) -> Tensor:
"""Inference with full image.
Args:
batch_inputs (Tensor): The tensor should have a shape NxCxHxW,
which contains all images in the batch.
inputs (Tensor): The tensor should have a shape NxCxHxW, which
contains all images in the batch.
batch_img_metas (List[dict]): List of image metainfo where each may
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
@@ -301,44 +302,41 @@ class EncoderDecoder(BaseSegmentor):
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
Returns:
List[:obj:`Tensor`]: List of segmentation results, seg_logits from
model of each input image.
Tensor: The segmentation results, seg_logits from model of each
input image.
"""
seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas)
seg_logits = self.encode_decode(inputs, batch_img_metas)
return seg_logits_list
return seg_logits
def inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
"""Inference with slide/whole style.
Args:
batch_inputs (Tensor): The input image of shape (N, 3, H, W).
inputs (Tensor): The input image of shape (N, 3, H, W).
batch_img_metas (List[dict]): List of image metainfo where each may
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
'ori_shape', and 'pad_shape'.
'ori_shape', 'pad_shape', and 'padding_size'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
Returns:
List[:obj:`Tensor`]: List of segmentation results, seg_logits from
model of each input image.
Tensor: The segmentation results, seg_logits from model of each
input image.
"""
assert self.test_cfg.mode in ['slide', 'whole']
ori_shape = batch_img_metas[0]['ori_shape']
assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
if self.test_cfg.mode == 'slide':
seg_logit_list = self.slide_inference(batch_inputs,
batch_img_metas)
seg_logit = self.slide_inference(inputs, batch_img_metas)
else:
seg_logit_list = self.whole_inference(batch_inputs,
batch_img_metas)
seg_logit = self.whole_inference(inputs, batch_img_metas)
return seg_logit_list
return seg_logit
def aug_test(self, batch_inputs, batch_img_metas, rescale=True):
def aug_test(self, inputs, batch_img_metas, rescale=True):
"""Test with augmentations.
Only rescale=True is supported.
@@ -346,13 +344,12 @@ class EncoderDecoder(BaseSegmentor):
# aug_test rescale all imgs back to ori_shape for now
assert rescale
# to save memory, we get augmented seg logit inplace
seg_logit = self.inference(batch_inputs[0], batch_img_metas[0],
rescale)
for i in range(1, len(batch_inputs)):
cur_seg_logit = self.inference(batch_inputs[i], batch_img_metas[i],
seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
for i in range(1, len(inputs)):
cur_seg_logit = self.inference(inputs[i], batch_img_metas[i],
rescale)
seg_logit += cur_seg_logit
seg_logit /= len(batch_inputs)
seg_logit /= len(inputs)
seg_pred = seg_logit.argmax(dim=1)
# unravel batch dim
seg_pred = list(seg_pred)
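A toy, self-contained sketch of the sliding-window accumulation in `slide_inference` above; a random tensor stands in for `encode_decode`, which after this commit returns an [N, C, h, w] logit map per crop.

import torch
import torch.nn.functional as F

batch_size, num_classes = 1, 3
h_img = w_img = 8
h_crop = w_crop = 4
h_stride = w_stride = 2
inputs = torch.randn(batch_size, 3, h_img, w_img)

preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img))
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
for h_idx in range(h_grids):
    for w_idx in range(w_grids):
        y1, x1 = h_idx * h_stride, w_idx * w_stride
        y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
        y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
        # Stand-in for encode_decode(crop_img, batch_img_metas).
        crop_seg_logit = torch.randn(batch_size, num_classes, y2 - y1, x2 - x1)
        # Paste the crop logits back at their location in the full map.
        preds += F.pad(crop_seg_logit,
                       (x1, w_img - x2, y1, h_img - y2))
        count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
seg_logits = preds / count_mat  # average each pixel by its visit count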

View File

@@ -28,7 +28,7 @@ def add_prefix(inputs, prefix):
def stack_batch(inputs: List[torch.Tensor],
batch_data_samples: Optional[SampleList] = None,
data_samples: Optional[SampleList] = None,
size: Optional[tuple] = None,
size_divisor: Optional[int] = None,
pad_val: Union[int, float] = 0,
@@ -39,8 +39,8 @@ def stack_batch(inputs: List[torch.Tensor],
Args:
inputs (List[Tensor]): The input multiple tensors. each is a
CHW 3D-tensor.
batch_data_samples (list[:obj:`SegDataSample`]): The Data
Samples. It usually includes information such as `gt_sem_seg`.
data_samples (list[:obj:`SegDataSample`]): The list of data samples.
It usually includes information such as `gt_sem_seg`.
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
pad_val (int, float): The padding value. Defaults to 0
@@ -48,8 +48,7 @@ def stack_batch(inputs: List[torch.Tensor],
Returns:
Tensor: The 4D-tensor.
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
the gt_seg_map.
List[:obj:`SegDataSample`]: After the padding of the gt_seg_map.
"""
assert isinstance(inputs, list), \
f'Expected input type to be list, but got {type(inputs)}'
@@ -93,14 +92,17 @@ def stack_batch(inputs: List[torch.Tensor],
pad_img = F.pad(tensor, padding_size, value=pad_val)
padded_inputs.append(pad_img)
# pad gt_sem_seg
if batch_data_samples is not None:
data_sample = batch_data_samples[i]
if data_samples is not None:
data_sample = data_samples[i]
gt_sem_seg = data_sample.gt_sem_seg.data
del data_sample.gt_sem_seg.data
data_sample.gt_sem_seg.data = F.pad(
gt_sem_seg, padding_size, value=seg_pad_val)
data_sample.set_metainfo(
{'pad_shape': data_sample.gt_sem_seg.shape})
data_sample.set_metainfo({
'img_shape': tensor.shape[-2:],
'pad_shape': data_sample.gt_sem_seg.shape,
'padding_size': padding_size
})
padded_samples.append(data_sample)
else:
padded_samples = None
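A hedged usage sketch of the updated `stack_batch` (imported as in the data preprocessor diff above): variable-sized CHW tensors are padded to a common size, and the padding geometry is recorded in each sample's metainfo.

import torch
from mmengine.structures import PixelData
from mmseg.structures import SegDataSample
from mmseg.utils import stack_batch

sample = SegDataSample()
sample.gt_sem_seg = PixelData(data=torch.zeros(1, 11, 10, dtype=torch.long))
inputs, samples = stack_batch(
    inputs=[torch.rand(3, 11, 10)],
    data_samples=[sample],
    size=(20, 20),
    pad_val=0,
    seg_pad_val=255)
assert inputs.shape == (1, 3, 20, 20)
# padding_size is consumed later by postprocess_result as
# [left, right, top, bottom].
print(samples[0].metainfo['padding_size'])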

View File

@@ -102,8 +102,7 @@ class SegLocalVisualizer(Visualizer):
def add_datasample(self,
name: str,
image: np.ndarray,
gt_sample: Optional[SegDataSample] = None,
pred_sample: Optional[SegDataSample] = None,
data_sample: Optional[SegDataSample] = None,
draw_gt: bool = True,
draw_pred: bool = True,
show: bool = False,
@@ -137,27 +136,26 @@ class SegLocalVisualizer(Visualizer):
gt_img_data = None
pred_img_data = None
if draw_gt and gt_sample is not None:
if draw_gt and data_sample is not None and 'gt_sem_seg' in data_sample:
gt_img_data = image
if 'gt_sem_seg' in gt_sample:
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing semantic ' \
'segmentation results.'
gt_img_data = self._draw_sem_seg(gt_img_data,
gt_sample.gt_sem_seg, classes,
palette)
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing semantic ' \
'segmentation results.'
gt_img_data = self._draw_sem_seg(gt_img_data,
data_sample.gt_sem_seg, classes,
palette)
if draw_pred and pred_sample is not None:
if (draw_pred and data_sample is not None
and 'pred_sem_seg' in data_sample):
pred_img_data = image
if 'pred_sem_seg' in pred_sample:
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing semantic ' \
'segmentation results.'
pred_img_data = self._draw_sem_seg(pred_img_data,
pred_sample.pred_sem_seg,
classes, palette)
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing semantic ' \
'segmentation results.'
pred_img_data = self._draw_sem_seg(pred_img_data,
data_sample.pred_sem_seg,
classes, palette)
if gt_img_data is not None and pred_img_data is not None:
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
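A hedged sketch of the unified visualizer API after this change: ground truth and prediction now travel in a single SegDataSample passed as `data_sample` (the save_dir is a throwaway path).

import numpy as np
import torch
from mmengine.structures import PixelData
from mmseg.structures import SegDataSample
from mmseg.visualization import SegLocalVisualizer

image = np.random.randint(0, 256, (10, 12, 3), dtype=np.uint8)
data_sample = SegDataSample()
data_sample.gt_sem_seg = PixelData(data=torch.randint(0, 2, (1, 10, 12)))
data_sample.pred_sem_seg = PixelData(data=torch.randint(0, 2, (1, 10, 12)))

visualizer = SegLocalVisualizer(
    vis_backends=[dict(type='LocalVisBackend')], save_dir='tmp_vis')
visualizer.dataset_meta = dict(
    classes=('background', 'foreground'),
    palette=[[120, 120, 120], [6, 230, 230]])
# One call draws GT and prediction side by side when both are present.
visualizer.add_datasample('demo', image, data_sample, show=False)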

View File

@@ -3,14 +3,13 @@ import glob
import os
from os.path import dirname, exists, isdir, join, relpath
# import numpy as np
import numpy as np
from mmengine import Config
# from mmengine.dataset import Compose
from mmengine.dataset import Compose
from torch import nn
from mmseg.models import build_segmentor
# from mmseg.utils import register_all_modules
from mmseg.utils import register_all_modules
def _get_config_directory():
@@ -64,70 +63,69 @@ def test_config_build_segmentor():
_check_decode_head(head_config, segmentor.decode_head)
# def test_config_data_pipeline():
# """Test whether the data pipeline is valid and can process corner cases.
def test_config_data_pipeline():
"""Test whether the data pipeline is valid and can process corner cases.
# CommandLine:
# xdoctest -m tests/test_config.py test_config_build_data_pipeline
# """
CommandLine:
xdoctest -m tests/test_config.py test_config_data_pipeline
"""
# register_all_modules()
# config_dpath = _get_config_directory()
# print('Found config_dpath = {!r}'.format(config_dpath))
register_all_modules()
config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath))
# import glob
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
import glob
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
config_names = [relpath(p, config_dpath) for p in config_fpaths]
# print('Using {} config files'.format(len(config_names)))
print('Using {} config files'.format(len(config_names)))
# for config_fname in config_names:
# config_fpath = join(config_dpath, config_fname)
# print(
# 'Building data pipeline, config_fpath = {!r}'.
# format(config_fpath))
# config_mod = Config.fromfile(config_fpath)
for config_fname in config_names:
config_fpath = join(config_dpath, config_fname)
print(
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
config_mod = Config.fromfile(config_fpath)
# # remove loading pipeline
# load_img_pipeline = config_mod.train_pipeline.pop(0)
# to_float32 = load_img_pipeline.get('to_float32', False)
# config_mod.train_pipeline.pop(0)
# config_mod.test_pipeline.pop(0)
# # remove loading annotation in test pipeline
# config_mod.test_pipeline.pop(1)
# remove loading pipeline
load_img_pipeline = config_mod.train_pipeline.pop(0)
to_float32 = load_img_pipeline.get('to_float32', False)
config_mod.train_pipeline.pop(0)
config_mod.test_pipeline.pop(0)
# remove loading annotation in test pipeline
config_mod.test_pipeline.pop(1)
# train_pipeline = Compose(config_mod.train_pipeline)
# test_pipeline = Compose(config_mod.test_pipeline)
train_pipeline = Compose(config_mod.train_pipeline)
test_pipeline = Compose(config_mod.test_pipeline)
# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
# if to_float32:
# img = img.astype(np.float32)
# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
if to_float32:
img = img.astype(np.float32)
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
# results = dict(
# filename='test_img.png',
# ori_filename='test_img.png',
# img=img,
# img_shape=img.shape,
# ori_shape=img.shape,
# gt_seg_map=seg)
# results['seg_fields'] = ['gt_seg_map']
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_seg_map=seg)
results['seg_fields'] = ['gt_seg_map']
# print('Test training data pipeline: \n{!r}'.format(train_pipeline))
# output_results = train_pipeline(results)
# assert output_results is not None
print('Test training data pipeline: \n{!r}'.format(train_pipeline))
output_results = train_pipeline(results)
assert output_results is not None
# results = dict(
# filename='test_img.png',
# ori_filename='test_img.png',
# img=img,
# img_shape=img.shape,
# ori_shape=img.shape,
# )
# print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
# output_results = test_pipeline(results)
# assert output_results is not None
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
)
print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
output_results = test_pipeline(results)
assert output_results is not None
def _check_decode_head(decode_head_cfg, decode_head):

View File

@@ -39,12 +39,12 @@ class TestPackSegInputs(unittest.TestCase):
def test_transform(self):
transform = PackSegInputs(meta_keys=self.meta_keys)
results = transform(copy.deepcopy(self.results))
self.assertIn('data_sample', results)
self.assertIsInstance(results['data_sample'], SegDataSample)
self.assertIsInstance(results['data_sample'].gt_sem_seg,
self.assertIn('data_samples', results)
self.assertIsInstance(results['data_samples'], SegDataSample)
self.assertIsInstance(results['data_samples'].gt_sem_seg,
BaseDataElement)
self.assertEqual(results['data_sample'].ori_shape,
results['data_sample'].gt_sem_seg.shape)
self.assertEqual(results['data_samples'].ori_shape,
results['data_samples'].gt_sem_seg.shape)
def test_repr(self):
transform = PackSegInputs(meta_keys=self.meta_keys)

View File

@@ -1,151 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
# import os.path as osp
import mmcv
import pytest
# import mmcv
# import pytest
from mmseg.datasets.transforms import * # noqa
from mmseg.registry import TRANSFORMS
# from mmseg.datasets.transforms import * # noqa
# from mmseg.registry import TRANSFORMS
# TODO
# def test_multi_scale_flip_aug():
# # test assertion if scales=None, scale_factor=1 (not float).
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=None,
# scale_factor=1,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
def test_multi_scale_flip_aug():
# test assertion if scales=None, scale_factor=1 (not float).
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
scales=None,
scale_factor=1,
transforms=[dict(type='Resize', keep_ratio=False)],
)
TRANSFORMS.build(tta_transform)
# # test assertion if scales=None, scale_factor=None.
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=None,
# scale_factor=None,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
# test assertion if scales=None, scale_factor=None.
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
scales=None,
scale_factor=None,
transforms=[dict(type='Resize', keep_ratio=False)],
)
TRANSFORMS.build(tta_transform)
# # test assertion if scales=(512, 512), scale_factor=1 (not float).
# with pytest.raises(AssertionError):
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=(512, 512),
# scale_factor=1,
# transforms=[dict(type='Resize', keep_ratio=False)],
# )
# TRANSFORMS.build(tta_transform)
# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
# 'scale_factor', 'scale', 'flip')
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(256, 256), (512, 512), (1024, 1024)],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# test assertion if scales=(512, 512), scale_factor=1 (not float).
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
scales=(512, 512),
scale_factor=1,
transforms=[dict(type='Resize', keep_ratio=False)],
)
TRANSFORMS.build(tta_transform)
meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
'scale_factor', 'scale', 'flip')
tta_transform = dict(
type='MultiScaleFlipAug',
scales=[(256, 256), (512, 512), (1024, 1024)],
allow_flip=False,
resize_cfg=dict(type='Resize', keep_ratio=False),
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
)
tta_module = TRANSFORMS.build(tta_transform)
# results = dict()
# # (288, 512, 3)
# img = mmcv.imread(
# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
# results['img'] = img
# results['ori_shape'] = img.shape
# results['ori_height'] = img.shape[0]
# results['ori_width'] = img.shape[1]
# # Set initial values for default meta_keys
# results['pad_shape'] = img.shape
# results['scale_factor'] = 1.0
results = dict()
# (288, 512, 3)
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
results['img'] = img
results['ori_shape'] = img.shape
results['ori_height'] = img.shape[0]
results['ori_width'] = img.shape[1]
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 256),
# (512, 512),
# (1024, 1024)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
tta_results = tta_module(results.copy())
assert [data_sample.scale
for data_sample in tta_results['data_sample']] == [(256, 256),
(512, 512),
(1024, 1024)]
assert [data_sample.flip for data_sample in tta_results['data_sample']
] == [False, False, False]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(256, 256), (512, 512), (1024, 1024)],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 256),
# (256, 256),
# (512, 512),
# (512, 512),
# (1024, 1024),
# (1024, 1024)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, True, False, True, False, True]
tta_transform = dict(
type='MultiScaleFlipAug',
scales=[(256, 256), (512, 512), (1024, 1024)],
allow_flip=True,
resize_cfg=dict(type='Resize', keep_ratio=False),
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
)
tta_module = TRANSFORMS.build(tta_transform)
tta_results = tta_module(results.copy())
assert [data_sample.scale
for data_sample in tta_results['data_sample']] == [(256, 256),
(256, 256),
(512, 512),
(512, 512),
(1024, 1024),
(1024, 1024)]
assert [data_sample.flip for data_sample in tta_results['data_sample']
] == [False, True, False, True, False, True]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(512, 512)],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [tta_results['data_sample'][0].scale] == [(512, 512)]
# assert [tta_results['data_sample'][0].flip] == [False]
tta_transform = dict(
type='MultiScaleFlipAug',
scales=[(512, 512)],
allow_flip=False,
resize_cfg=dict(type='Resize', keep_ratio=False),
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
)
tta_module = TRANSFORMS.build(tta_transform)
tta_results = tta_module(results.copy())
assert [tta_results['data_sample'][0].scale] == [(512, 512)]
assert [tta_results['data_sample'][0].flip] == [False]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scales=[(512, 512)],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(512, 512),
# (512, 512)]
# assert [data_sample.flip
# for data_sample in tta_results['data_sample']] == [False, True]
tta_transform = dict(
type='MultiScaleFlipAug',
scales=[(512, 512)],
allow_flip=True,
resize_cfg=dict(type='Resize', keep_ratio=False),
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
)
tta_module = TRANSFORMS.build(tta_transform)
tta_results = tta_module(results.copy())
assert [data_sample.scale
for data_sample in tta_results['data_sample']] == [(512, 512),
(512, 512)]
assert [data_sample.flip
for data_sample in tta_results['data_sample']] == [False, True]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scale_factor=[0.5, 1.0, 2.0],
# allow_flip=False,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 144),
# (512, 288),
# (1024, 576)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, False, False]
tta_transform = dict(
type='MultiScaleFlipAug',
scale_factor=[0.5, 1.0, 2.0],
allow_flip=False,
resize_cfg=dict(type='Resize', keep_ratio=False),
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
)
tta_module = TRANSFORMS.build(tta_transform)
tta_results = tta_module(results.copy())
assert [data_sample.scale
for data_sample in tta_results['data_sample']] == [(256, 144),
(512, 288),
(1024, 576)]
assert [data_sample.flip for data_sample in tta_results['data_sample']
] == [False, False, False]
tta_transform = dict(
type='MultiScaleFlipAug',
scale_factor=[0.5, 1.0, 2.0],
allow_flip=True,
resize_cfg=dict(type='Resize', keep_ratio=False),
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
)
tta_module = TRANSFORMS.build(tta_transform)
tta_results = tta_module(results.copy())
assert [data_sample.scale
for data_sample in tta_results['data_sample']] == [(256, 144),
(256, 144),
(512, 288),
(512, 288),
(1024, 576),
(1024, 576)]
assert [data_sample.flip for data_sample in tta_results['data_sample']
] == [False, True, False, True, False, True]
# tta_transform = dict(
# type='MultiScaleFlipAug',
# scale_factor=[0.5, 1.0, 2.0],
# allow_flip=True,
# resize_cfg=dict(type='Resize', keep_ratio=False),
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
# )
# tta_module = TRANSFORMS.build(tta_transform)
# tta_results = tta_module(results.copy())
# assert [data_sample.scale
# for data_sample in tta_results['data_sample']] == [(256, 144),
# (256, 144),
# (512, 288),
# (512, 288),
# (1024, 576),
# (1024, 576)]
# assert [data_sample.flip for data_sample in tta_results['data_sample']
# ] == [False, True, False, True, False, True]

View File

@@ -30,6 +30,7 @@ class TestVisualizationHook(TestCase):
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
self.outputs = [pred_seg_data_sample] * 2

View File

@@ -3,7 +3,7 @@ from unittest import TestCase
import numpy as np
import torch
from mmengine.structures import BaseDataElement, PixelData
from mmengine.structures import PixelData
from mmseg.evaluation import IoUMetric
from mmseg.structures import SegDataSample
@@ -29,72 +29,48 @@ class TestIoUMetric(TestCase):
else:
image_shapes = [image_shapes] * batch_size
packed_inputs = []
data_samples = []
for idx in range(batch_size):
image_shape = image_shapes[idx]
_, h, w = image_shape
mm_inputs = dict()
data_sample = SegDataSample()
gt_semantic_seg = np.random.randint(
0, num_classes, (1, h, w), dtype=np.uint8)
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
gt_sem_seg_data = dict(data=gt_semantic_seg)
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
mm_inputs['data_sample'] = data_sample.to_dict()
packed_inputs.append(mm_inputs)
return packed_inputs
data_samples.append(data_sample.to_dict())
return data_samples
def _demo_mm_model_output(self,
data_samples,
batch_size=2,
image_shapes=(3, 64, 64),
num_classes=5):
"""Create a superset of inputs needed to run test or train batches.
Args:
batch_size (int): batch size. Default to 2.
image_shapes (List[tuple], Optional): image shape.
Default to (3, 64, 64)
num_classes (int): number of different classes.
Default to 5.
"""
results_dict = dict()
_, h, w = image_shapes
seg_logit = torch.randn(batch_size, num_classes, h, w)
results_dict['seg_logits'] = seg_logit
seg_pred = np.random.randint(
0, num_classes, (batch_size, h, w), dtype=np.uint8)
seg_pred = torch.LongTensor(seg_pred)
results_dict['pred_sem_seg'] = seg_pred
batch_datasampes = [
SegDataSample()
for _ in range(results_dict['pred_sem_seg'].shape[0])
]
for key, value in results_dict.items():
for i in range(value.shape[0]):
setattr(batch_datasampes[i], key, PixelData(data=value[i]))
_predictions = []
for pred in batch_datasampes:
if isinstance(pred, BaseDataElement):
_predictions.append(pred.to_dict())
else:
_predictions.append(pred)
return _predictions
for data_sample in data_samples:
data_sample['seg_logits'] = dict(
data=torch.randn(num_classes, h, w))
data_sample['pred_sem_seg'] = dict(
data=torch.randint(0, num_classes, (1, h, w)))
return data_samples
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
data_batch = self._demo_mm_inputs()
predictions = self._demo_mm_model_output()
data_samples = self._demo_mm_inputs()
data_samples = self._demo_mm_model_output(data_samples)
iou_metric = IoUMetric(iou_metrics=['mIoU'])
iou_metric.dataset_meta = dict(
classes=['wall', 'building', 'sky', 'floor', 'tree'],
label_map=dict(),
reduce_zero_label=False)
iou_metric.process(data_batch, predictions)
iou_metric.process([0] * len(data_samples), data_samples)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)

View File

@@ -6,8 +6,11 @@ from mmcv.cnn import ConvModule
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock)
from mmseg.models.utils import Upsample
from mmseg.utils import register_all_modules
from .utils import check_norm_state
register_all_modules()
def test_unet_basic_conv_block():
with pytest.raises(AssertionError):

View File

@@ -3,10 +3,10 @@ from unittest import TestCase
from mmseg.models import SegDataPreProcessor
# import torch
# from mmengine.structures import PixelData
import torch
from mmengine.structures import PixelData
# from mmseg.structures import SegDataSample
from mmseg.structures import SegDataSample
class TestSegDataPreProcessor(TestCase):
@@ -31,16 +31,19 @@ class TestSegDataPreProcessor(TestCase):
with self.assertRaises(AssertionError):
SegDataPreProcessor(bgr_to_rgb=True, rgb_to_bgr=True)
# def test_forward(self):
# data_sample = SegDataSample()
# data_sample.gt_sem_seg = PixelData(
# **{'data': torch.randint(0, 10, (1, 11, 10))})
# processor = SegDataPreProcessor(
# mean=[0, 0, 0], std=[1, 1, 1], size=(20, 20))
# data = {
# 'inputs': torch.randint(0, 256, (3, 11, 10)),
# 'data_sample': data_sample
# }
# inputs, data_samples = processor([data, data], training=True)
# self.assertEqual(inputs.shape, (2, 3, 20, 20))
# self.assertEqual(len(data_samples), 2)
def test_forward(self):
data_sample = SegDataSample()
data_sample.gt_sem_seg = PixelData(
**{'data': torch.randint(0, 10, (1, 11, 10))})
processor = SegDataPreProcessor(
mean=[0, 0, 0], std=[1, 1, 1], size=(20, 20))
data = {
'inputs': [
torch.randint(0, 256, (3, 11, 10)),
torch.randint(0, 256, (3, 11, 10))
],
'data_samples': [data_sample, data_sample]
}
out = processor(data, training=True)
self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
self.assertEqual(len(out['data_samples']), 2)

View File

@@ -34,14 +34,14 @@ def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):
else:
image_shapes = [image_shapes] * batch_size
packed_inputs = []
inputs = []
data_samples = []
for idx in range(batch_size):
image_shape = image_shapes[idx]
c, h, w = image_shape
image = np.random.randint(0, 255, size=image_shape, dtype=np.uint8)
mm_inputs = dict()
mm_inputs['inputs'] = torch.from_numpy(image)
mm_input = torch.from_numpy(image)
img_meta = {
'img_id': idx,
@@ -62,10 +62,9 @@ def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
gt_sem_seg_data = dict(data=gt_semantic_seg)
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
mm_inputs['data_sample'] = data_sample
packed_inputs.append(mm_inputs)
return packed_inputs
inputs.append(mm_input)
data_samples.append(data_sample)
return dict(inputs=inputs, data_samples=data_samples)
def _get_config_directory():
@@ -226,27 +225,36 @@ def _test_encoder_decoder_forward(cfg_file):
segmentor = revert_sync_batchnorm(segmentor)
# Test forward train
batch_inputs, data_samples = segmentor.data_preprocessor(
packed_inputs, True)
losses = segmentor.forward(batch_inputs, data_samples, mode='loss')
data = segmentor.data_preprocessor(packed_inputs, True)
losses = segmentor.forward(**data, mode='loss')
assert isinstance(losses, dict)
packed_inputs = _demo_mm_inputs(
batch_size=1, image_shapes=(3, 32, 32), num_classes=num_classes)
batch_inputs, data_samples = segmentor.data_preprocessor(
packed_inputs, False)
data = segmentor.data_preprocessor(packed_inputs, False)
with torch.no_grad():
segmentor.eval()
# Test forward predict
batch_results = segmentor.forward(
batch_inputs, data_samples, mode='predict')
batch_results = segmentor.forward(**data, mode='predict')
assert len(batch_results) == 1
assert is_list_of(batch_results, SegDataSample)
assert batch_results[0].pred_sem_seg.shape == (32, 32)
assert batch_results[0].seg_logits.data.shape == (num_classes, 32, 32)
assert batch_results[0].gt_sem_seg.shape == (32, 32)
# Test forward tensor
batch_results = segmentor.forward(
batch_inputs, data_samples, mode='tensor')
batch_results = segmentor.forward(**data, mode='tensor')
assert isinstance(batch_results, Tensor) or is_tuple_of(
batch_results, Tensor)
# Test forward predict without ground truth
data.pop('data_samples')
batch_results = segmentor.forward(**data, mode='predict')
assert len(batch_results) == 1
assert is_list_of(batch_results, SegDataSample)
assert batch_results[0].pred_sem_seg.shape == (32, 32)
# Test forward tensor without ground truth
batch_results = segmentor.forward(**data, mode='tensor')
assert isinstance(batch_results, Tensor) or is_tuple_of(
batch_results, Tensor)

View File

@@ -29,8 +29,8 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg = PixelData(**gt_sem_seg_data)
def test_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
data_sample = SegDataSample()
data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory() as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
@@ -42,7 +42,7 @@ class TestSegLocalVisualizer(TestCase):
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
data_sample)
assert os.path.exists(
osp.join(tmp_dir, 'vis_data', 'vis_image',
@@ -57,22 +57,16 @@ class TestSegLocalVisualizer(TestCase):
data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
out_file, image, data_sample, draw_gt=False)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))
@@ -104,8 +98,8 @@ class TestSegLocalVisualizer(TestCase):
gt_sem_seg = PixelData(**gt_sem_seg_data)
def test_cityscapes_add_datasample_forward(gt_sem_seg):
gt_seg_data_sample = SegDataSample()
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
data_sample = SegDataSample()
data_sample.gt_sem_seg = gt_sem_seg
with tempfile.TemporaryDirectory() as tmp_dir:
seg_local_visualizer = SegLocalVisualizer(
@@ -125,11 +119,11 @@ class TestSegLocalVisualizer(TestCase):
[0, 60, 100], [0, 80, 100], [0, 0, 230],
[119, 11, 32]])
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
data_sample)
# test out_file
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample)
data_sample)
assert os.path.exists(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'))
@@ -143,22 +137,16 @@ class TestSegLocalVisualizer(TestCase):
data=torch.randint(0, num_class, (1, h, w)))
pred_sem_seg = PixelData(**pred_sem_seg_data)
pred_seg_data_sample = SegDataSample()
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
data_sample.pred_sem_seg = pred_sem_seg
seg_local_visualizer.add_datasample(out_file, image,
gt_seg_data_sample,
pred_seg_data_sample)
data_sample)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w * 2, 3))
seg_local_visualizer.add_datasample(
out_file,
image,
gt_seg_data_sample,
pred_seg_data_sample,
draw_gt=False)
out_file, image, data_sample, draw_gt=False)
self._assert_image_and_shape(
osp.join(tmp_dir, 'vis_data', 'vis_image',
out_file + '_0.png'), (h, w, 3))

View File

@@ -79,14 +79,15 @@ def main():
# benchmark with 200 batches and take the average
for i, data in enumerate(data_loader):
batch_inputs, data_samples = model.data_preprocessor(data, True)
data = model.data_preprocessor(data, True)
inputs = data['inputs']
data_samples = data['data_samples']
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.perf_counter()
with torch.no_grad():
model(batch_inputs, data_samples, mode='predict')
model(inputs, data_samples, mode='predict')
if torch.cuda.is_available():
torch.cuda.synchronize()