mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Refactor] data flow (#1956)
* [WIP] Refactor data flow * model return * [WIP] Refactor data flow * support data_samples is optional * fix benchmark * fix base * minors * rebase * fix api * ut * fix api inference * comments * docstring * docstring * docstring * fix bug of slide inference * add assert c > 1
This commit is contained in:
parent
50546da85c
commit
8de0050f25
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
@ -11,9 +12,9 @@ from mmengine.dataset import Compose
|
||||
from mmengine.runner import load_checkpoint
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
from mmseg.data import SegDataSample
|
||||
from mmseg.models import BaseSegmentor
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
@ -50,7 +51,6 @@ def init_model(config: Union[str, Path, Config],
|
||||
model = MODELS.build(config.model)
|
||||
if checkpoint is not None:
|
||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
|
||||
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
|
||||
# save the dataset_meta in the model for convenience
|
||||
if 'dataset_meta' in checkpoint.get('meta', {}):
|
||||
@ -108,14 +108,15 @@ def _preprare_data(imgs: ImageType, model: BaseSegmentor):
|
||||
# a pipeline for each inference
|
||||
pipeline = Compose(cfg.test_pipeline)
|
||||
|
||||
data = []
|
||||
data = defaultdict(list)
|
||||
for img in imgs:
|
||||
if isinstance(img, np.ndarray):
|
||||
data_ = dict(img=img)
|
||||
else:
|
||||
data_ = dict(img_path=img)
|
||||
data_ = pipeline(data_)
|
||||
data.append(data_)
|
||||
data['inputs'].append(data_['inputs'])
|
||||
data['data_samples'].append(data_['data_samples'])
|
||||
|
||||
return data, is_batch
|
||||
|
||||
@ -187,11 +188,12 @@ def show_result_pyplot(model: BaseSegmentor,
|
||||
save_dir=save_dir,
|
||||
alpha=opacity)
|
||||
visualizer.dataset_meta = dict(
|
||||
classes=model.CLASSES, palette=model.PALETTE)
|
||||
classes=model.dataset_meta['classes'],
|
||||
palette=model.dataset_meta['palette'])
|
||||
visualizer.add_datasample(
|
||||
name=title,
|
||||
image=image,
|
||||
pred_sample=result[0],
|
||||
data_sample=result[0],
|
||||
draw_gt=draw_gt,
|
||||
draw_pred=draw_pred,
|
||||
wait_time=wait_time,
|
||||
|
@ -78,7 +78,7 @@ class PackSegInputs(BaseTransform):
|
||||
if key in results:
|
||||
img_meta[key] = results[key]
|
||||
data_sample.set_metainfo(img_meta)
|
||||
packed_results['data_sample'] = data_sample
|
||||
packed_results['data_samples'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
|
@ -66,7 +66,7 @@ class SegVisualizationHook(Hook):
|
||||
def _after_iter(self,
|
||||
runner: Runner,
|
||||
batch_idx: int,
|
||||
data_batch: Sequence[dict],
|
||||
data_batch: dict,
|
||||
outputs: Sequence[SegDataSample],
|
||||
mode: str = 'val') -> None:
|
||||
"""Run after every ``self.interval`` validation iterations.
|
||||
@ -74,7 +74,7 @@ class SegVisualizationHook(Hook):
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[dict]): Data from dataloader.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`SegDataSample`]): Outputs from model.
|
||||
mode (str): mode (str): Current mode of runner. Defaults to 'val'.
|
||||
"""
|
||||
@ -85,18 +85,16 @@ class SegVisualizationHook(Hook):
|
||||
self.file_client = FileClient(**self.file_client_args)
|
||||
|
||||
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||
for input_data, output in zip(data_batch, outputs):
|
||||
img_path = input_data['data_sample'].img_path
|
||||
for output in outputs:
|
||||
img_path = output.img_path
|
||||
img_bytes = self.file_client.get(img_path)
|
||||
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
|
||||
window_name = f'{mode}_{osp.basename(img_path)}'
|
||||
|
||||
gt_sample = input_data['data_sample']
|
||||
self._visualizer.add_datasample(
|
||||
window_name,
|
||||
img,
|
||||
gt_sample=gt_sample,
|
||||
pred_sample=output,
|
||||
data_sample=output,
|
||||
show=self.show,
|
||||
wait_time=self.wait_time,
|
||||
step=runner.iter)
|
||||
|
@ -49,25 +49,24 @@ class CitysMetric(BaseMetric):
|
||||
self.to_label_id = to_label_id
|
||||
self.suffix = suffix
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and predictions.
|
||||
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and data_samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
data_batch (dict): A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
mkdir_or_exist(self.suffix)
|
||||
|
||||
for pred in predictions:
|
||||
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
|
||||
# results2img
|
||||
if self.to_label_id:
|
||||
pred_label = self._convert_to_label_id(pred_label)
|
||||
basename = osp.splitext(osp.basename(pred['img_path']))[0]
|
||||
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
|
||||
png_filename = osp.join(self.suffix, f'{basename}.png')
|
||||
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
|
@ -47,22 +47,20 @@ class IoUMetric(BaseMetric):
|
||||
self.nan_to_num = nan_to_num
|
||||
self.beta = beta
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and predictions.
|
||||
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and data_samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
data_batch (dict): A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
num_classes = len(self.dataset_meta['classes'])
|
||||
for data, pred in zip(data_batch, predictions):
|
||||
pred_label = pred['pred_sem_seg']['data'].squeeze()
|
||||
label = data['data_sample']['gt_sem_seg']['data'].squeeze().to(
|
||||
pred_label)
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_sem_seg']['data'].squeeze()
|
||||
label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label)
|
||||
self.results.append(
|
||||
self.intersect_and_union(pred_label, label, num_classes,
|
||||
self.ignore_index))
|
||||
|
@ -1,13 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from numbers import Number
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseDataPreprocessor
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptSampleList, stack_batch
|
||||
from mmseg.utils import stack_batch
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
@ -87,22 +86,20 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
||||
# TODO: support batch augmentations.
|
||||
self.batch_augments = batch_augments
|
||||
|
||||
def forward(self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False) -> Tuple[Tensor, OptSampleList]:
|
||||
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
|
||||
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
data (dict): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
Dict: Data in the same format as the model input.
|
||||
"""
|
||||
inputs, batch_data_samples = self.collate_data(data)
|
||||
|
||||
data = self.cast_data(data) # type: ignore
|
||||
inputs = data['inputs']
|
||||
data_samples = data.get('data_samples', None)
|
||||
# TODO: whether normalize should be after stack_batch
|
||||
if self.channel_conversion and inputs[0].size(0) == 3:
|
||||
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
|
||||
@ -113,20 +110,23 @@ class SegDataPreProcessor(BaseDataPreprocessor):
|
||||
inputs = [_input.float() for _input in inputs]
|
||||
|
||||
if training:
|
||||
batch_inputs, batch_data_samples = stack_batch(
|
||||
assert data_samples is not None, ('During training, ',
|
||||
'`data_samples` must be define.')
|
||||
inputs, data_samples = stack_batch(
|
||||
inputs=inputs,
|
||||
batch_data_samples=batch_data_samples,
|
||||
data_samples=data_samples,
|
||||
size=self.size,
|
||||
size_divisor=self.size_divisor,
|
||||
pad_val=self.pad_val,
|
||||
seg_pad_val=self.seg_pad_val)
|
||||
|
||||
if self.batch_augments is not None:
|
||||
inputs, batch_data_samples = self.batch_augments(
|
||||
inputs, batch_data_samples)
|
||||
return batch_inputs, batch_data_samples
|
||||
inputs, data_samples = self.batch_augments(
|
||||
inputs, data_samples)
|
||||
return dict(inputs=inputs, data_samples=data_samples)
|
||||
else:
|
||||
assert len(inputs) == 1, (
|
||||
'Batch inference is not support currently, '
|
||||
'as the image size might be different in a batch')
|
||||
return torch.stack(inputs, dim=0), batch_data_samples
|
||||
return dict(
|
||||
inputs=torch.stack(inputs, dim=0), data_samples=data_samples)
|
||||
|
@ -47,20 +47,19 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
return hasattr(self, 'decode_head') and self.decode_head is not None
|
||||
|
||||
@abstractmethod
|
||||
def extract_feat(self, batch_inputs: Tensor) -> bool:
|
||||
def extract_feat(self, inputs: Tensor) -> bool:
|
||||
"""Placeholder for extract features from images."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encode_decode(self, batch_inputs: Tensor,
|
||||
batch_data_samples: SampleList):
|
||||
def encode_decode(self, inputs: Tensor, batch_data_samples: SampleList):
|
||||
"""Placeholder for encode images with backbone and decode into a
|
||||
semantic segmentation map of the same size as input."""
|
||||
pass
|
||||
|
||||
def forward(self,
|
||||
batch_inputs: Tensor,
|
||||
batch_data_samples: OptSampleList = None,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor') -> ForwardResults:
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
@ -77,10 +76,11 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
batch_data_samples (list[:obj:`SegDataSample`], optional): The
|
||||
annotation data of every samples. Defaults to None.
|
||||
inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
|
||||
general.
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_sem_seg`. Default to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
Returns:
|
||||
@ -91,33 +91,32 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'loss':
|
||||
return self.loss(batch_inputs, batch_data_samples)
|
||||
return self.loss(inputs, data_samples)
|
||||
elif mode == 'predict':
|
||||
return self.predict(batch_inputs, batch_data_samples)
|
||||
return self.predict(inputs, data_samples)
|
||||
elif mode == 'tensor':
|
||||
return self._forward(batch_inputs, batch_data_samples)
|
||||
return self._forward(inputs, data_samples)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}". '
|
||||
'Only supports loss, predict and tensor mode')
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, batch_inputs: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, batch_inputs: Tensor,
|
||||
batch_data_samples: SampleList) -> SampleList:
|
||||
def predict(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> SampleList:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _forward(
|
||||
self,
|
||||
batch_inputs: Tensor,
|
||||
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
|
||||
def _forward(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
|
||||
"""Network forward process.
|
||||
|
||||
Usually includes backbone, neck and head forward without any post-
|
||||
@ -130,13 +129,16 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
"""Placeholder for augmentation test."""
|
||||
pass
|
||||
|
||||
def postprocess_result(self, seg_logits_list: List[dict],
|
||||
batch_img_metas: List[dict]) -> list:
|
||||
def postprocess_result(self,
|
||||
seg_logits: Tensor,
|
||||
data_samples: OptSampleList = None) -> list:
|
||||
""" Convert results list to `SegDataSample`.
|
||||
Args:
|
||||
seg_logits_list (List[dict]): List of segmentation results,
|
||||
seg_logits from model of each input image.
|
||||
|
||||
seg_logits (Tensor): The segmentation results, seg_logits from
|
||||
model of each input image.
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_sem_seg`. Default to None.
|
||||
Returns:
|
||||
list[:obj:`SegDataSample`]: Segmentation results of the
|
||||
input images. Each SegDataSample usually contain:
|
||||
@ -145,22 +147,50 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
- ``seg_logits``(PixelData): Predicted logits of semantic
|
||||
segmentation before normalization.
|
||||
"""
|
||||
predictions = []
|
||||
batch_size, C, H, W = seg_logits.shape
|
||||
assert C > 1, ('This post processes does not binary segmentation, and '
|
||||
f'channels `seg_logtis` must be > 1 but got {C}')
|
||||
|
||||
for i in range(len(seg_logits_list)):
|
||||
img_meta = batch_img_metas[i]
|
||||
seg_logits = resize(
|
||||
seg_logits_list[i][None],
|
||||
size=img_meta['ori_shape'],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners,
|
||||
warning=False).squeeze(0)
|
||||
# seg_logits shape is CHW
|
||||
seg_pred = seg_logits.argmax(dim=0, keepdim=True)
|
||||
prediction = SegDataSample(**{'metainfo': img_meta})
|
||||
prediction.set_data({
|
||||
'seg_logits': PixelData(**{'data': seg_logits}),
|
||||
'pred_sem_seg': PixelData(**{'data': seg_pred})
|
||||
})
|
||||
predictions.append(prediction)
|
||||
return predictions
|
||||
if data_samples is None:
|
||||
data_samples = []
|
||||
only_prediction = True
|
||||
else:
|
||||
only_prediction = False
|
||||
|
||||
for i in range(batch_size):
|
||||
if not only_prediction:
|
||||
img_meta = data_samples[i].metainfo
|
||||
# remove padding area
|
||||
padding_left, padding_right, padding_top, padding_bottom = \
|
||||
img_meta.get('padding_size', [0]*4)
|
||||
# i_seg_logits shape is 1, C, H, W after remove padding
|
||||
i_seg_logits = seg_logits[i:i + 1, :,
|
||||
padding_top:H - padding_bottom,
|
||||
padding_left:W - padding_right]
|
||||
# resize as original shape
|
||||
i_seg_logits = resize(
|
||||
i_seg_logits,
|
||||
size=img_meta['ori_shape'],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners,
|
||||
warning=False).squeeze(0)
|
||||
# i_seg_logits shape is C, H, W with original shape
|
||||
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
|
||||
data_samples[i].set_data({
|
||||
'seg_logits':
|
||||
PixelData(**{'data': i_seg_logits}),
|
||||
'pred_sem_seg':
|
||||
PixelData(**{'data': i_seg_pred})
|
||||
})
|
||||
else:
|
||||
i_seg_logits = seg_logits[i]
|
||||
i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True)
|
||||
prediction = SegDataSample()
|
||||
prediction.set_data({
|
||||
'seg_logits':
|
||||
PixelData(**{'data': i_seg_logits}),
|
||||
'pred_sem_seg':
|
||||
PixelData(**{'data': i_seg_pred})
|
||||
})
|
||||
data_samples.append(prediction)
|
||||
return data_samples
|
||||
|
@ -69,11 +69,11 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
||||
self.align_corners = self.decode_head[-1].align_corners
|
||||
self.num_classes = self.decode_head[-1].num_classes
|
||||
|
||||
def encode_decode(self, batch_inputs: Tensor,
|
||||
def encode_decode(self, inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> List[Tensor]:
|
||||
"""Encode images with backbone and decode into a semantic segmentation
|
||||
map of the same size as input."""
|
||||
x = self.extract_feat(batch_inputs)
|
||||
x = self.extract_feat(inputs)
|
||||
out = self.decode_head[0].forward(x)
|
||||
for i in range(1, self.num_stages - 1):
|
||||
out = self.decode_head[i].forward(x, out)
|
||||
@ -82,53 +82,52 @@ class CascadeEncoderDecoder(EncoderDecoder):
|
||||
|
||||
return seg_logits_list
|
||||
|
||||
def _decode_head_forward_train(self, batch_inputs: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
def _decode_head_forward_train(self, inputs: Tensor,
|
||||
data_samples: SampleList) -> dict:
|
||||
"""Run forward function and calculate loss for decode head in
|
||||
training."""
|
||||
losses = dict()
|
||||
|
||||
loss_decode = self.decode_head[0].loss(batch_inputs,
|
||||
batch_data_samples,
|
||||
loss_decode = self.decode_head[0].loss(inputs, data_samples,
|
||||
self.train_cfg)
|
||||
|
||||
losses.update(add_prefix(loss_decode, 'decode_0'))
|
||||
# get batch_img_metas
|
||||
batch_size = len(batch_data_samples)
|
||||
batch_size = len(data_samples)
|
||||
batch_img_metas = []
|
||||
for batch_index in range(batch_size):
|
||||
metainfo = batch_data_samples[batch_index].metainfo
|
||||
metainfo = data_samples[batch_index].metainfo
|
||||
batch_img_metas.append(metainfo)
|
||||
|
||||
for i in range(1, self.num_stages):
|
||||
# forward test again, maybe unnecessary for most methods.
|
||||
if i == 1:
|
||||
prev_outputs = self.decode_head[0].forward(batch_inputs)
|
||||
prev_outputs = self.decode_head[0].forward(inputs)
|
||||
else:
|
||||
prev_outputs = self.decode_head[i - 1].forward(
|
||||
batch_inputs, prev_outputs)
|
||||
loss_decode = self.decode_head[i].loss(batch_inputs, prev_outputs,
|
||||
batch_data_samples,
|
||||
inputs, prev_outputs)
|
||||
loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
|
||||
data_samples,
|
||||
self.train_cfg)
|
||||
losses.update(add_prefix(loss_decode, f'decode_{i}'))
|
||||
|
||||
return losses
|
||||
|
||||
def _forward(self,
|
||||
batch_inputs: Tensor,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> Tensor:
|
||||
"""Network forward process.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `img_metas` and `gt_semantic_seg`.
|
||||
inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
data_samples (List[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_semantic_seg`.
|
||||
|
||||
Returns:
|
||||
Tensor: Forward output of model without any post-processes.
|
||||
"""
|
||||
x = self.extract_feat(batch_inputs)
|
||||
x = self.extract_feat(inputs)
|
||||
|
||||
out = self.decode_head[0].forward(x)
|
||||
for i in range(1, self.num_stages):
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
@ -112,93 +111,87 @@ class EncoderDecoder(BaseSegmentor):
|
||||
else:
|
||||
self.auxiliary_head = MODELS.build(auxiliary_head)
|
||||
|
||||
def extract_feat(self, batch_inputs: Tensor) -> List[Tensor]:
|
||||
def extract_feat(self, inputs: Tensor) -> List[Tensor]:
|
||||
"""Extract features from images."""
|
||||
x = self.backbone(batch_inputs)
|
||||
x = self.backbone(inputs)
|
||||
if self.with_neck:
|
||||
x = self.neck(x)
|
||||
return x
|
||||
|
||||
def encode_decode(self, batch_inputs: Tensor,
|
||||
def encode_decode(self, inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> List[Tensor]:
|
||||
"""Encode images with backbone and decode into a semantic segmentation
|
||||
map of the same size as input."""
|
||||
x = self.extract_feat(batch_inputs)
|
||||
x = self.extract_feat(inputs)
|
||||
seg_logits = self.decode_head.predict(x, batch_img_metas,
|
||||
self.test_cfg)
|
||||
|
||||
return list(seg_logits)
|
||||
return seg_logits
|
||||
|
||||
def _decode_head_forward_train(self, batch_inputs: List[Tensor],
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
def _decode_head_forward_train(self, inputs: List[Tensor],
|
||||
data_samples: SampleList) -> dict:
|
||||
"""Run forward function and calculate loss for decode head in
|
||||
training."""
|
||||
losses = dict()
|
||||
loss_decode = self.decode_head.loss(batch_inputs, batch_data_samples,
|
||||
loss_decode = self.decode_head.loss(inputs, data_samples,
|
||||
self.train_cfg)
|
||||
|
||||
losses.update(add_prefix(loss_decode, 'decode'))
|
||||
return losses
|
||||
|
||||
def _auxiliary_head_forward_train(
|
||||
self,
|
||||
batch_inputs: List[Tensor],
|
||||
batch_data_samples: SampleList,
|
||||
) -> dict:
|
||||
def _auxiliary_head_forward_train(self, inputs: List[Tensor],
|
||||
data_samples: SampleList) -> dict:
|
||||
"""Run forward function and calculate loss for auxiliary head in
|
||||
training."""
|
||||
losses = dict()
|
||||
if isinstance(self.auxiliary_head, nn.ModuleList):
|
||||
for idx, aux_head in enumerate(self.auxiliary_head):
|
||||
loss_aux = aux_head.loss(batch_inputs, batch_data_samples,
|
||||
self.train_cfg)
|
||||
loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg)
|
||||
losses.update(add_prefix(loss_aux, f'aux_{idx}'))
|
||||
else:
|
||||
loss_aux = self.auxiliary_head.loss(batch_inputs,
|
||||
batch_data_samples,
|
||||
loss_aux = self.auxiliary_head.loss(inputs, data_samples,
|
||||
self.train_cfg)
|
||||
losses.update(add_prefix(loss_aux, 'aux'))
|
||||
|
||||
return losses
|
||||
|
||||
def loss(self, batch_inputs: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
Args:
|
||||
img (Tensor): Input images.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_sem_seg`.
|
||||
inputs (Tensor): Input images.
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data samples.
|
||||
It usually includes information such as `metainfo` and
|
||||
`gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
|
||||
x = self.extract_feat(batch_inputs)
|
||||
x = self.extract_feat(inputs)
|
||||
|
||||
losses = dict()
|
||||
|
||||
loss_decode = self._decode_head_forward_train(x, batch_data_samples)
|
||||
loss_decode = self._decode_head_forward_train(x, data_samples)
|
||||
losses.update(loss_decode)
|
||||
|
||||
if self.with_auxiliary_head:
|
||||
loss_aux = self._auxiliary_head_forward_train(
|
||||
x, batch_data_samples)
|
||||
loss_aux = self._auxiliary_head_forward_train(x, data_samples)
|
||||
losses.update(loss_aux)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, batch_inputs: Tensor,
|
||||
batch_data_samples: SampleList) -> SampleList:
|
||||
def predict(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> SampleList:
|
||||
"""Predict results from a batch of inputs and data samples with post-
|
||||
processing.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_sem_seg`.
|
||||
inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
data_samples (List[:obj:`SegDataSample`], optional): The seg data
|
||||
samples. It usually includes information such as `metainfo`
|
||||
and `gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
list[:obj:`SegDataSample`]: Segmentation results of the
|
||||
@ -208,40 +201,49 @@ class EncoderDecoder(BaseSegmentor):
|
||||
- ``seg_logits``(PixelData): Predicted logits of semantic
|
||||
segmentation before normalization.
|
||||
"""
|
||||
batch_img_metas = []
|
||||
for data_sample in batch_data_samples:
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
if data_samples is not None:
|
||||
batch_img_metas = [
|
||||
data_sample.metainfo for data_sample in data_samples
|
||||
]
|
||||
else:
|
||||
batch_img_metas = [
|
||||
dict(
|
||||
ori_shape=inputs.shape[2:],
|
||||
img_shape=inputs.shape[2:],
|
||||
pad_shape=inputs.shape[2:],
|
||||
padding_size=[0, 0, 0, 0])
|
||||
] * inputs.shape[0]
|
||||
|
||||
seg_logit_list = self.inference(batch_inputs, batch_img_metas)
|
||||
seg_logits = self.inference(inputs, batch_img_metas)
|
||||
|
||||
return self.postprocess_result(seg_logit_list, batch_img_metas)
|
||||
return self.postprocess_result(seg_logits, data_samples)
|
||||
|
||||
def _forward(self,
|
||||
batch_inputs: Tensor,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None) -> Tensor:
|
||||
"""Network forward process.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
Tensor: Forward output of model without any post-processes.
|
||||
"""
|
||||
x = self.extract_feat(batch_inputs)
|
||||
x = self.extract_feat(inputs)
|
||||
return self.decode_head.forward(x)
|
||||
|
||||
def slide_inference(self, batch_inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> List[Tensor]:
|
||||
def slide_inference(self, inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> Tensor:
|
||||
"""Inference by sliding-window with overlap.
|
||||
|
||||
If h_crop > h_img or w_crop > w_img, the small patch will be used to
|
||||
decode without padding.
|
||||
|
||||
Args:
|
||||
batch_inputs (tensor): the tensor should have a shape NxCxHxW,
|
||||
inputs (tensor): the tensor should have a shape NxCxHxW,
|
||||
which contains all images in the batch.
|
||||
batch_img_metas (List[dict]): List of image metainfo where each may
|
||||
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
|
||||
@ -250,18 +252,18 @@ class EncoderDecoder(BaseSegmentor):
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
|
||||
Returns:
|
||||
List[:obj:`Tensor`]: List of segmentation results, seg_logits from
|
||||
model of each input image.
|
||||
Tensor: The segmentation results, seg_logits from model of each
|
||||
input image.
|
||||
"""
|
||||
|
||||
h_stride, w_stride = self.test_cfg.stride
|
||||
h_crop, w_crop = self.test_cfg.crop_size
|
||||
batch_size, _, h_img, w_img = batch_inputs.size()
|
||||
batch_size, _, h_img, w_img = inputs.size()
|
||||
num_classes = self.num_classes
|
||||
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
||||
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
||||
preds = batch_inputs.new_zeros((batch_size, num_classes, h_img, w_img))
|
||||
count_mat = batch_inputs.new_zeros((batch_size, 1, h_img, w_img))
|
||||
preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img))
|
||||
count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
|
||||
for h_idx in range(h_grids):
|
||||
for w_idx in range(w_grids):
|
||||
y1 = h_idx * h_stride
|
||||
@ -270,30 +272,29 @@ class EncoderDecoder(BaseSegmentor):
|
||||
x2 = min(x1 + w_crop, w_img)
|
||||
y1 = max(y2 - h_crop, 0)
|
||||
x1 = max(x2 - w_crop, 0)
|
||||
crop_img = batch_inputs[:, :, y1:y2, x1:x2]
|
||||
# change the img shape to patch shape
|
||||
crop_img = inputs[:, :, y1:y2, x1:x2]
|
||||
# change the image shape to patch shape
|
||||
batch_img_metas[0]['img_shape'] = crop_img.shape[2:]
|
||||
# the output of encode_decode is list of seg logits map
|
||||
# with shape [C, H, W]
|
||||
crop_seg_logit = torch.stack(
|
||||
self.encode_decode(crop_img, batch_img_metas), dim=0)
|
||||
# the output of encode_decode is seg logits tensor map
|
||||
# with shape [N, C, H, W]
|
||||
crop_seg_logit = self.encode_decode(crop_img, batch_img_metas)
|
||||
preds += F.pad(crop_seg_logit,
|
||||
(int(x1), int(preds.shape[3] - x2), int(y1),
|
||||
int(preds.shape[2] - y2)))
|
||||
|
||||
count_mat[:, :, y1:y2, x1:x2] += 1
|
||||
assert (count_mat == 0).sum() == 0
|
||||
seg_logits_list = list(preds / count_mat)
|
||||
seg_logits = preds / count_mat
|
||||
|
||||
return seg_logits_list
|
||||
return seg_logits
|
||||
|
||||
def whole_inference(self, batch_inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> List[Tensor]:
|
||||
def whole_inference(self, inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> Tensor:
|
||||
"""Inference with full image.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): The tensor should have a shape NxCxHxW,
|
||||
which contains all images in the batch.
|
||||
inputs (Tensor): The tensor should have a shape NxCxHxW, which
|
||||
contains all images in the batch.
|
||||
batch_img_metas (List[dict]): List of image metainfo where each may
|
||||
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
|
||||
'ori_shape', and 'pad_shape'.
|
||||
@ -301,44 +302,41 @@ class EncoderDecoder(BaseSegmentor):
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
|
||||
Returns:
|
||||
List[:obj:`Tensor`]: List of segmentation results, seg_logits from
|
||||
model of each input image.
|
||||
Tensor: The segmentation results, seg_logits from model of each
|
||||
input image.
|
||||
"""
|
||||
|
||||
seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas)
|
||||
seg_logits = self.encode_decode(inputs, batch_img_metas)
|
||||
|
||||
return seg_logits_list
|
||||
return seg_logits
|
||||
|
||||
def inference(self, batch_inputs: Tensor,
|
||||
batch_img_metas: List[dict]) -> List[Tensor]:
|
||||
def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor:
|
||||
"""Inference with slide/whole style.
|
||||
|
||||
Args:
|
||||
batch_inputs (Tensor): The input image of shape (N, 3, H, W).
|
||||
inputs (Tensor): The input image of shape (N, 3, H, W).
|
||||
batch_img_metas (List[dict]): List of image metainfo where each may
|
||||
also contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
|
||||
'ori_shape', and 'pad_shape'.
|
||||
'ori_shape', 'pad_shape', and 'padding_size'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
|
||||
Returns:
|
||||
List[:obj:`Tensor`]: List of segmentation results, seg_logits from
|
||||
model of each input image.
|
||||
Tensor: The segmentation results, seg_logits from model of each
|
||||
input image.
|
||||
"""
|
||||
|
||||
assert self.test_cfg.mode in ['slide', 'whole']
|
||||
ori_shape = batch_img_metas[0]['ori_shape']
|
||||
assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
|
||||
if self.test_cfg.mode == 'slide':
|
||||
seg_logit_list = self.slide_inference(batch_inputs,
|
||||
batch_img_metas)
|
||||
seg_logit = self.slide_inference(inputs, batch_img_metas)
|
||||
else:
|
||||
seg_logit_list = self.whole_inference(batch_inputs,
|
||||
batch_img_metas)
|
||||
seg_logit = self.whole_inference(inputs, batch_img_metas)
|
||||
|
||||
return seg_logit_list
|
||||
return seg_logit
|
||||
|
||||
def aug_test(self, batch_inputs, batch_img_metas, rescale=True):
|
||||
def aug_test(self, inputs, batch_img_metas, rescale=True):
|
||||
"""Test with augmentations.
|
||||
|
||||
Only rescale=True is supported.
|
||||
@ -346,13 +344,12 @@ class EncoderDecoder(BaseSegmentor):
|
||||
# aug_test rescale all imgs back to ori_shape for now
|
||||
assert rescale
|
||||
# to save memory, we get augmented seg logit inplace
|
||||
seg_logit = self.inference(batch_inputs[0], batch_img_metas[0],
|
||||
rescale)
|
||||
for i in range(1, len(batch_inputs)):
|
||||
cur_seg_logit = self.inference(batch_inputs[i], batch_img_metas[i],
|
||||
seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale)
|
||||
for i in range(1, len(inputs)):
|
||||
cur_seg_logit = self.inference(inputs[i], batch_img_metas[i],
|
||||
rescale)
|
||||
seg_logit += cur_seg_logit
|
||||
seg_logit /= len(batch_inputs)
|
||||
seg_logit /= len(inputs)
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
# unravel batch dim
|
||||
seg_pred = list(seg_pred)
|
||||
|
@ -28,7 +28,7 @@ def add_prefix(inputs, prefix):
|
||||
|
||||
|
||||
def stack_batch(inputs: List[torch.Tensor],
|
||||
batch_data_samples: Optional[SampleList] = None,
|
||||
data_samples: Optional[SampleList] = None,
|
||||
size: Optional[tuple] = None,
|
||||
size_divisor: Optional[int] = None,
|
||||
pad_val: Union[int, float] = 0,
|
||||
@ -39,8 +39,8 @@ def stack_batch(inputs: List[torch.Tensor],
|
||||
Args:
|
||||
inputs (List[Tensor]): The input multiple tensors. each is a
|
||||
CHW 3D-tensor.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as `gt_sem_seg`.
|
||||
data_samples (list[:obj:`SegDataSample`]): The list of data samples.
|
||||
It usually includes information such as `gt_sem_seg`.
|
||||
size (tuple, optional): Fixed padding size.
|
||||
size_divisor (int, optional): The divisor of padded size.
|
||||
pad_val (int, float): The padding value. Defaults to 0
|
||||
@ -48,8 +48,7 @@ def stack_batch(inputs: List[torch.Tensor],
|
||||
|
||||
Returns:
|
||||
Tensor: The 4D-tensor.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): After the padding of
|
||||
the gt_seg_map.
|
||||
List[:obj:`SegDataSample`]: After the padding of the gt_seg_map.
|
||||
"""
|
||||
assert isinstance(inputs, list), \
|
||||
f'Expected input type to be list, but got {type(inputs)}'
|
||||
@ -93,14 +92,17 @@ def stack_batch(inputs: List[torch.Tensor],
|
||||
pad_img = F.pad(tensor, padding_size, value=pad_val)
|
||||
padded_inputs.append(pad_img)
|
||||
# pad gt_sem_seg
|
||||
if batch_data_samples is not None:
|
||||
data_sample = batch_data_samples[i]
|
||||
if data_samples is not None:
|
||||
data_sample = data_samples[i]
|
||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||
del data_sample.gt_sem_seg.data
|
||||
data_sample.gt_sem_seg.data = F.pad(
|
||||
gt_sem_seg, padding_size, value=seg_pad_val)
|
||||
data_sample.set_metainfo(
|
||||
{'pad_shape': data_sample.gt_sem_seg.shape})
|
||||
data_sample.set_metainfo({
|
||||
'img_shape': tensor.shape[-2:],
|
||||
'pad_shape': data_sample.gt_sem_seg.shape,
|
||||
'padding_size': padding_size
|
||||
})
|
||||
padded_samples.append(data_sample)
|
||||
else:
|
||||
padded_samples = None
|
||||
|
@ -102,8 +102,7 @@ class SegLocalVisualizer(Visualizer):
|
||||
def add_datasample(self,
|
||||
name: str,
|
||||
image: np.ndarray,
|
||||
gt_sample: Optional[SegDataSample] = None,
|
||||
pred_sample: Optional[SegDataSample] = None,
|
||||
data_sample: Optional[SegDataSample] = None,
|
||||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
show: bool = False,
|
||||
@ -137,27 +136,26 @@ class SegLocalVisualizer(Visualizer):
|
||||
gt_img_data = None
|
||||
pred_img_data = None
|
||||
|
||||
if draw_gt and gt_sample is not None:
|
||||
if draw_gt and data_sample is not None and 'gt_sem_seg' in data_sample:
|
||||
gt_img_data = image
|
||||
if 'gt_sem_seg' in gt_sample:
|
||||
assert classes is not None, 'class information is ' \
|
||||
'not provided when ' \
|
||||
'visualizing semantic ' \
|
||||
'segmentation results.'
|
||||
gt_img_data = self._draw_sem_seg(gt_img_data,
|
||||
gt_sample.gt_sem_seg, classes,
|
||||
palette)
|
||||
assert classes is not None, 'class information is ' \
|
||||
'not provided when ' \
|
||||
'visualizing semantic ' \
|
||||
'segmentation results.'
|
||||
gt_img_data = self._draw_sem_seg(gt_img_data,
|
||||
data_sample.gt_sem_seg, classes,
|
||||
palette)
|
||||
|
||||
if draw_pred and pred_sample is not None:
|
||||
if (draw_pred and data_sample is not None
|
||||
and 'pred_sem_seg' in data_sample):
|
||||
pred_img_data = image
|
||||
if 'pred_sem_seg' in pred_sample:
|
||||
assert classes is not None, 'class information is ' \
|
||||
'not provided when ' \
|
||||
'visualizing semantic ' \
|
||||
'segmentation results.'
|
||||
pred_img_data = self._draw_sem_seg(pred_img_data,
|
||||
pred_sample.pred_sem_seg,
|
||||
classes, palette)
|
||||
assert classes is not None, 'class information is ' \
|
||||
'not provided when ' \
|
||||
'visualizing semantic ' \
|
||||
'segmentation results.'
|
||||
pred_img_data = self._draw_sem_seg(pred_img_data,
|
||||
data_sample.pred_sem_seg,
|
||||
classes, palette)
|
||||
|
||||
if gt_img_data is not None and pred_img_data is not None:
|
||||
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1)
|
||||
|
@ -3,14 +3,13 @@ import glob
|
||||
import os
|
||||
from os.path import dirname, exists, isdir, join, relpath
|
||||
|
||||
# import numpy as np
|
||||
import numpy as np
|
||||
from mmengine import Config
|
||||
# from mmengine.dataset import Compose
|
||||
from mmengine.dataset import Compose
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
|
||||
# from mmseg.utils import register_all_modules
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
@ -64,70 +63,69 @@ def test_config_build_segmentor():
|
||||
_check_decode_head(head_config, segmentor.decode_head)
|
||||
|
||||
|
||||
# def test_config_data_pipeline():
|
||||
# """Test whether the data pipeline is valid and can process corner cases.
|
||||
def test_config_data_pipeline():
|
||||
"""Test whether the data pipeline is valid and can process corner cases.
|
||||
|
||||
# CommandLine:
|
||||
# xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
# """
|
||||
CommandLine:
|
||||
xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
"""
|
||||
|
||||
# register_all_modules()
|
||||
# config_dpath = _get_config_directory()
|
||||
# print('Found config_dpath = {!r}'.format(config_dpath))
|
||||
register_all_modules()
|
||||
config_dpath = _get_config_directory()
|
||||
print('Found config_dpath = {!r}'.format(config_dpath))
|
||||
|
||||
# import glob
|
||||
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||
# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
import glob
|
||||
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
|
||||
# print('Using {} config files'.format(len(config_names)))
|
||||
print('Using {} config files'.format(len(config_names)))
|
||||
|
||||
# for config_fname in config_names:
|
||||
# config_fpath = join(config_dpath, config_fname)
|
||||
# print(
|
||||
# 'Building data pipeline, config_fpath = {!r}'.
|
||||
# format(config_fpath))
|
||||
# config_mod = Config.fromfile(config_fpath)
|
||||
for config_fname in config_names:
|
||||
config_fpath = join(config_dpath, config_fname)
|
||||
print(
|
||||
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
|
||||
config_mod = Config.fromfile(config_fpath)
|
||||
|
||||
# # remove loading pipeline
|
||||
# load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||
# to_float32 = load_img_pipeline.get('to_float32', False)
|
||||
# config_mod.train_pipeline.pop(0)
|
||||
# config_mod.test_pipeline.pop(0)
|
||||
# # remove loading annotation in test pipeline
|
||||
# config_mod.test_pipeline.pop(1)
|
||||
# remove loading pipeline
|
||||
load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||
to_float32 = load_img_pipeline.get('to_float32', False)
|
||||
config_mod.train_pipeline.pop(0)
|
||||
config_mod.test_pipeline.pop(0)
|
||||
# remove loading annotation in test pipeline
|
||||
config_mod.test_pipeline.pop(1)
|
||||
|
||||
# train_pipeline = Compose(config_mod.train_pipeline)
|
||||
# test_pipeline = Compose(config_mod.test_pipeline)
|
||||
train_pipeline = Compose(config_mod.train_pipeline)
|
||||
test_pipeline = Compose(config_mod.test_pipeline)
|
||||
|
||||
# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||
# if to_float32:
|
||||
# img = img.astype(np.float32)
|
||||
# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||
if to_float32:
|
||||
img = img.astype(np.float32)
|
||||
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||
|
||||
# results = dict(
|
||||
# filename='test_img.png',
|
||||
# ori_filename='test_img.png',
|
||||
# img=img,
|
||||
# img_shape=img.shape,
|
||||
# ori_shape=img.shape,
|
||||
# gt_seg_map=seg)
|
||||
# results['seg_fields'] = ['gt_seg_map']
|
||||
results = dict(
|
||||
filename='test_img.png',
|
||||
ori_filename='test_img.png',
|
||||
img=img,
|
||||
img_shape=img.shape,
|
||||
ori_shape=img.shape,
|
||||
gt_seg_map=seg)
|
||||
results['seg_fields'] = ['gt_seg_map']
|
||||
|
||||
# print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
||||
# output_results = train_pipeline(results)
|
||||
# assert output_results is not None
|
||||
print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
||||
output_results = train_pipeline(results)
|
||||
assert output_results is not None
|
||||
|
||||
# results = dict(
|
||||
# filename='test_img.png',
|
||||
# ori_filename='test_img.png',
|
||||
# img=img,
|
||||
# img_shape=img.shape,
|
||||
# ori_shape=img.shape,
|
||||
# )
|
||||
# print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
||||
# output_results = test_pipeline(results)
|
||||
# assert output_results is not None
|
||||
results = dict(
|
||||
filename='test_img.png',
|
||||
ori_filename='test_img.png',
|
||||
img=img,
|
||||
img_shape=img.shape,
|
||||
ori_shape=img.shape,
|
||||
)
|
||||
print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
||||
output_results = test_pipeline(results)
|
||||
assert output_results is not None
|
||||
|
||||
|
||||
def _check_decode_head(decode_head_cfg, decode_head):
|
||||
|
@ -39,12 +39,12 @@ class TestPackSegInputs(unittest.TestCase):
|
||||
def test_transform(self):
|
||||
transform = PackSegInputs(meta_keys=self.meta_keys)
|
||||
results = transform(copy.deepcopy(self.results))
|
||||
self.assertIn('data_sample', results)
|
||||
self.assertIsInstance(results['data_sample'], SegDataSample)
|
||||
self.assertIsInstance(results['data_sample'].gt_sem_seg,
|
||||
self.assertIn('data_samples', results)
|
||||
self.assertIsInstance(results['data_samples'], SegDataSample)
|
||||
self.assertIsInstance(results['data_samples'].gt_sem_seg,
|
||||
BaseDataElement)
|
||||
self.assertEqual(results['data_sample'].ori_shape,
|
||||
results['data_sample'].gt_sem_seg.shape)
|
||||
self.assertEqual(results['data_samples'].ori_shape,
|
||||
results['data_samples'].gt_sem_seg.shape)
|
||||
|
||||
def test_repr(self):
|
||||
transform = PackSegInputs(meta_keys=self.meta_keys)
|
||||
|
@ -1,151 +1,151 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
# import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import pytest
|
||||
# import mmcv
|
||||
# import pytest
|
||||
|
||||
from mmseg.datasets.transforms import * # noqa
|
||||
from mmseg.registry import TRANSFORMS
|
||||
# from mmseg.datasets.transforms import * # noqa
|
||||
# from mmseg.registry import TRANSFORMS
|
||||
|
||||
# TODO
|
||||
# def test_multi_scale_flip_aug():
|
||||
# # test assertion if scales=None, scale_factor=1 (not float).
|
||||
# with pytest.raises(AssertionError):
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scales=None,
|
||||
# scale_factor=1,
|
||||
# transforms=[dict(type='Resize', keep_ratio=False)],
|
||||
# )
|
||||
# TRANSFORMS.build(tta_transform)
|
||||
|
||||
def test_multi_scale_flip_aug():
|
||||
# test assertion if scales=None, scale_factor=1 (not float).
|
||||
with pytest.raises(AssertionError):
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scales=None,
|
||||
scale_factor=1,
|
||||
transforms=[dict(type='Resize', keep_ratio=False)],
|
||||
)
|
||||
TRANSFORMS.build(tta_transform)
|
||||
# # test assertion if scales=None, scale_factor=None.
|
||||
# with pytest.raises(AssertionError):
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scales=None,
|
||||
# scale_factor=None,
|
||||
# transforms=[dict(type='Resize', keep_ratio=False)],
|
||||
# )
|
||||
# TRANSFORMS.build(tta_transform)
|
||||
|
||||
# test assertion if scales=None, scale_factor=None.
|
||||
with pytest.raises(AssertionError):
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scales=None,
|
||||
scale_factor=None,
|
||||
transforms=[dict(type='Resize', keep_ratio=False)],
|
||||
)
|
||||
TRANSFORMS.build(tta_transform)
|
||||
# # test assertion if scales=(512, 512), scale_factor=1 (not float).
|
||||
# with pytest.raises(AssertionError):
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scales=(512, 512),
|
||||
# scale_factor=1,
|
||||
# transforms=[dict(type='Resize', keep_ratio=False)],
|
||||
# )
|
||||
# TRANSFORMS.build(tta_transform)
|
||||
# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
|
||||
# 'scale_factor', 'scale', 'flip')
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scales=[(256, 256), (512, 512), (1024, 1024)],
|
||||
# allow_flip=False,
|
||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
# )
|
||||
# tta_module = TRANSFORMS.build(tta_transform)
|
||||
|
||||
# test assertion if scales=(512, 512), scale_factor=1 (not float).
|
||||
with pytest.raises(AssertionError):
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scales=(512, 512),
|
||||
scale_factor=1,
|
||||
transforms=[dict(type='Resize', keep_ratio=False)],
|
||||
)
|
||||
TRANSFORMS.build(tta_transform)
|
||||
meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape',
|
||||
'scale_factor', 'scale', 'flip')
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scales=[(256, 256), (512, 512), (1024, 1024)],
|
||||
allow_flip=False,
|
||||
resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
)
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
# results = dict()
|
||||
# # (288, 512, 3)
|
||||
# img = mmcv.imread(
|
||||
# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
# results['img'] = img
|
||||
# results['ori_shape'] = img.shape
|
||||
# results['ori_height'] = img.shape[0]
|
||||
# results['ori_width'] = img.shape[1]
|
||||
# # Set initial values for default meta_keys
|
||||
# results['pad_shape'] = img.shape
|
||||
# results['scale_factor'] = 1.0
|
||||
|
||||
results = dict()
|
||||
# (288, 512, 3)
|
||||
img = mmcv.imread(
|
||||
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
|
||||
results['img'] = img
|
||||
results['ori_shape'] = img.shape
|
||||
results['ori_height'] = img.shape[0]
|
||||
results['ori_width'] = img.shape[1]
|
||||
# Set initial values for default meta_keys
|
||||
results['pad_shape'] = img.shape
|
||||
results['scale_factor'] = 1.0
|
||||
# tta_results = tta_module(results.copy())
|
||||
# assert [data_sample.scale
|
||||
# for data_sample in tta_results['data_sample']] == [(256, 256),
|
||||
# (512, 512),
|
||||
# (1024, 1024)]
|
||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
# ] == [False, False, False]
|
||||
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [data_sample.scale
|
||||
for data_sample in tta_results['data_sample']] == [(256, 256),
|
||||
(512, 512),
|
||||
(1024, 1024)]
|
||||
assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
] == [False, False, False]
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scales=[(256, 256), (512, 512), (1024, 1024)],
|
||||
# allow_flip=True,
|
||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
# )
|
||||
# tta_module = TRANSFORMS.build(tta_transform)
|
||||
# tta_results = tta_module(results.copy())
|
||||
# assert [data_sample.scale
|
||||
# for data_sample in tta_results['data_sample']] == [(256, 256),
|
||||
# (256, 256),
|
||||
# (512, 512),
|
||||
# (512, 512),
|
||||
# (1024, 1024),
|
||||
# (1024, 1024)]
|
||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
# ] == [False, True, False, True, False, True]
|
||||
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scales=[(256, 256), (512, 512), (1024, 1024)],
|
||||
allow_flip=True,
|
||||
resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
)
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [data_sample.scale
|
||||
for data_sample in tta_results['data_sample']] == [(256, 256),
|
||||
(256, 256),
|
||||
(512, 512),
|
||||
(512, 512),
|
||||
(1024, 1024),
|
||||
(1024, 1024)]
|
||||
assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
] == [False, True, False, True, False, True]
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scales=[(512, 512)],
|
||||
# allow_flip=False,
|
||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
# )
|
||||
# tta_module = TRANSFORMS.build(tta_transform)
|
||||
# tta_results = tta_module(results.copy())
|
||||
# assert [tta_results['data_sample'][0].scale] == [(512, 512)]
|
||||
# assert [tta_results['data_sample'][0].flip] == [False]
|
||||
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scales=[(512, 512)],
|
||||
allow_flip=False,
|
||||
resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
)
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [tta_results['data_sample'][0].scale] == [(512, 512)]
|
||||
assert [tta_results['data_sample'][0].flip] == [False]
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scales=[(512, 512)],
|
||||
# allow_flip=True,
|
||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
# )
|
||||
# tta_module = TRANSFORMS.build(tta_transform)
|
||||
# tta_results = tta_module(results.copy())
|
||||
# assert [data_sample.scale
|
||||
# for data_sample in tta_results['data_sample']] == [(512, 512),
|
||||
# (512, 512)]
|
||||
# assert [data_sample.flip
|
||||
# for data_sample in tta_results['data_sample']] == [False, True]
|
||||
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scales=[(512, 512)],
|
||||
allow_flip=True,
|
||||
resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
)
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [data_sample.scale
|
||||
for data_sample in tta_results['data_sample']] == [(512, 512),
|
||||
(512, 512)]
|
||||
assert [data_sample.flip
|
||||
for data_sample in tta_results['data_sample']] == [False, True]
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scale_factor=[0.5, 1.0, 2.0],
|
||||
# allow_flip=False,
|
||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
# )
|
||||
# tta_module = TRANSFORMS.build(tta_transform)
|
||||
# tta_results = tta_module(results.copy())
|
||||
# assert [data_sample.scale
|
||||
# for data_sample in tta_results['data_sample']] == [(256, 144),
|
||||
# (512, 288),
|
||||
# (1024, 576)]
|
||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
# ] == [False, False, False]
|
||||
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scale_factor=[0.5, 1.0, 2.0],
|
||||
allow_flip=False,
|
||||
resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
)
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [data_sample.scale
|
||||
for data_sample in tta_results['data_sample']] == [(256, 144),
|
||||
(512, 288),
|
||||
(1024, 576)]
|
||||
assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
] == [False, False, False]
|
||||
|
||||
tta_transform = dict(
|
||||
type='MultiScaleFlipAug',
|
||||
scale_factor=[0.5, 1.0, 2.0],
|
||||
allow_flip=True,
|
||||
resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
)
|
||||
tta_module = TRANSFORMS.build(tta_transform)
|
||||
tta_results = tta_module(results.copy())
|
||||
assert [data_sample.scale
|
||||
for data_sample in tta_results['data_sample']] == [(256, 144),
|
||||
(256, 144),
|
||||
(512, 288),
|
||||
(512, 288),
|
||||
(1024, 576),
|
||||
(1024, 576)]
|
||||
assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
] == [False, True, False, True, False, True]
|
||||
# tta_transform = dict(
|
||||
# type='MultiScaleFlipAug',
|
||||
# scale_factor=[0.5, 1.0, 2.0],
|
||||
# allow_flip=True,
|
||||
# resize_cfg=dict(type='Resize', keep_ratio=False),
|
||||
# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)],
|
||||
# )
|
||||
# tta_module = TRANSFORMS.build(tta_transform)
|
||||
# tta_results = tta_module(results.copy())
|
||||
# assert [data_sample.scale
|
||||
# for data_sample in tta_results['data_sample']] == [(256, 144),
|
||||
# (256, 144),
|
||||
# (512, 288),
|
||||
# (512, 288),
|
||||
# (1024, 576),
|
||||
# (1024, 576)]
|
||||
# assert [data_sample.flip for data_sample in tta_results['data_sample']
|
||||
# ] == [False, True, False, True, False, True]
|
||||
|
@ -30,6 +30,7 @@ class TestVisualizationHook(TestCase):
|
||||
pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
self.outputs = [pred_seg_data_sample] * 2
|
||||
|
||||
|
@ -3,7 +3,7 @@ from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.structures import BaseDataElement, PixelData
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.evaluation import IoUMetric
|
||||
from mmseg.structures import SegDataSample
|
||||
@ -29,72 +29,48 @@ class TestIoUMetric(TestCase):
|
||||
else:
|
||||
image_shapes = [image_shapes] * batch_size
|
||||
|
||||
packed_inputs = []
|
||||
data_samples = []
|
||||
for idx in range(batch_size):
|
||||
image_shape = image_shapes[idx]
|
||||
_, h, w = image_shape
|
||||
|
||||
mm_inputs = dict()
|
||||
data_sample = SegDataSample()
|
||||
gt_semantic_seg = np.random.randint(
|
||||
0, num_classes, (1, h, w), dtype=np.uint8)
|
||||
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
|
||||
gt_sem_seg_data = dict(data=gt_semantic_seg)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
mm_inputs['data_sample'] = data_sample.to_dict()
|
||||
packed_inputs.append(mm_inputs)
|
||||
|
||||
return packed_inputs
|
||||
data_samples.append(data_sample.to_dict())
|
||||
|
||||
return data_samples
|
||||
|
||||
def _demo_mm_model_output(self,
|
||||
data_samples,
|
||||
batch_size=2,
|
||||
image_shapes=(3, 64, 64),
|
||||
num_classes=5):
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size. Default to 2.
|
||||
image_shapes (List[tuple], Optional): image shape.
|
||||
Default to (3, 64, 64)
|
||||
num_classes (int): number of different classes.
|
||||
Default to 5.
|
||||
"""
|
||||
results_dict = dict()
|
||||
_, h, w = image_shapes
|
||||
seg_logit = torch.randn(batch_size, num_classes, h, w)
|
||||
results_dict['seg_logits'] = seg_logit
|
||||
seg_pred = np.random.randint(
|
||||
0, num_classes, (batch_size, h, w), dtype=np.uint8)
|
||||
seg_pred = torch.LongTensor(seg_pred)
|
||||
results_dict['pred_sem_seg'] = seg_pred
|
||||
|
||||
batch_datasampes = [
|
||||
SegDataSample()
|
||||
for _ in range(results_dict['pred_sem_seg'].shape[0])
|
||||
]
|
||||
for key, value in results_dict.items():
|
||||
for i in range(value.shape[0]):
|
||||
setattr(batch_datasampes[i], key, PixelData(data=value[i]))
|
||||
|
||||
_predictions = []
|
||||
for pred in batch_datasampes:
|
||||
if isinstance(pred, BaseDataElement):
|
||||
_predictions.append(pred.to_dict())
|
||||
else:
|
||||
_predictions.append(pred)
|
||||
return _predictions
|
||||
for data_sample in data_samples:
|
||||
data_sample['seg_logits'] = dict(
|
||||
data=torch.randn(num_classes, h, w))
|
||||
data_sample['pred_sem_seg'] = dict(
|
||||
data=torch.randint(0, num_classes, (1, h, w)))
|
||||
return data_samples
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
|
||||
data_batch = self._demo_mm_inputs()
|
||||
predictions = self._demo_mm_model_output()
|
||||
data_samples = self._demo_mm_inputs()
|
||||
data_samples = self._demo_mm_model_output(data_samples)
|
||||
|
||||
iou_metric = IoUMetric(iou_metrics=['mIoU'])
|
||||
iou_metric.dataset_meta = dict(
|
||||
classes=['wall', 'building', 'sky', 'floor', 'tree'],
|
||||
label_map=dict(),
|
||||
reduce_zero_label=False)
|
||||
iou_metric.process(data_batch, predictions)
|
||||
iou_metric.process([0] * len(data_samples), data_samples)
|
||||
res = iou_metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
@ -6,8 +6,11 @@ from mmcv.cnn import ConvModule
|
||||
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
||||
InterpConv, UNet, UpConvBlock)
|
||||
from mmseg.models.utils import Upsample
|
||||
from mmseg.utils import register_all_modules
|
||||
from .utils import check_norm_state
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def test_unet_basic_conv_block():
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -3,10 +3,10 @@ from unittest import TestCase
|
||||
|
||||
from mmseg.models import SegDataPreProcessor
|
||||
|
||||
# import torch
|
||||
# from mmengine.structures import PixelData
|
||||
import torch
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
# from mmseg.structures import SegDataSample
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestSegDataPreProcessor(TestCase):
|
||||
@ -31,16 +31,19 @@ class TestSegDataPreProcessor(TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
SegDataPreProcessor(bgr_to_rgb=True, rgb_to_bgr=True)
|
||||
|
||||
# def test_forward(self):
|
||||
# data_sample = SegDataSample()
|
||||
# data_sample.gt_sem_seg = PixelData(
|
||||
# **{'data': torch.randint(0, 10, (1, 11, 10))})
|
||||
# processor = SegDataPreProcessor(
|
||||
# mean=[0, 0, 0], std=[1, 1, 1], size=(20, 20))
|
||||
# data = {
|
||||
# 'inputs': torch.randint(0, 256, (3, 11, 10)),
|
||||
# 'data_sample': data_sample
|
||||
# }
|
||||
# inputs, data_samples = processor([data, data], training=True)
|
||||
# self.assertEqual(inputs.shape, (2, 3, 20, 20))
|
||||
# self.assertEqual(len(data_samples), 2)
|
||||
def test_forward(self):
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = PixelData(
|
||||
**{'data': torch.randint(0, 10, (1, 11, 10))})
|
||||
processor = SegDataPreProcessor(
|
||||
mean=[0, 0, 0], std=[1, 1, 1], size=(20, 20))
|
||||
data = {
|
||||
'inputs': [
|
||||
torch.randint(0, 256, (3, 11, 10)),
|
||||
torch.randint(0, 256, (3, 11, 10))
|
||||
],
|
||||
'data_samples': [data_sample, data_sample]
|
||||
}
|
||||
out = processor(data, training=True)
|
||||
self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
|
||||
self.assertEqual(len(out['data_samples']), 2)
|
||||
|
@ -34,14 +34,14 @@ def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):
|
||||
else:
|
||||
image_shapes = [image_shapes] * batch_size
|
||||
|
||||
packed_inputs = []
|
||||
inputs = []
|
||||
data_samples = []
|
||||
for idx in range(batch_size):
|
||||
image_shape = image_shapes[idx]
|
||||
c, h, w = image_shape
|
||||
image = np.random.randint(0, 255, size=image_shape, dtype=np.uint8)
|
||||
|
||||
mm_inputs = dict()
|
||||
mm_inputs['inputs'] = torch.from_numpy(image)
|
||||
mm_input = torch.from_numpy(image)
|
||||
|
||||
img_meta = {
|
||||
'img_id': idx,
|
||||
@ -62,10 +62,9 @@ def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):
|
||||
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
|
||||
gt_sem_seg_data = dict(data=gt_semantic_seg)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
mm_inputs['data_sample'] = data_sample
|
||||
packed_inputs.append(mm_inputs)
|
||||
|
||||
return packed_inputs
|
||||
inputs.append(mm_input)
|
||||
data_samples.append(data_sample)
|
||||
return dict(inputs=inputs, data_samples=data_samples)
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
@ -226,27 +225,36 @@ def _test_encoder_decoder_forward(cfg_file):
|
||||
segmentor = revert_sync_batchnorm(segmentor)
|
||||
|
||||
# Test forward train
|
||||
batch_inputs, data_samples = segmentor.data_preprocessor(
|
||||
packed_inputs, True)
|
||||
losses = segmentor.forward(batch_inputs, data_samples, mode='loss')
|
||||
data = segmentor.data_preprocessor(packed_inputs, True)
|
||||
losses = segmentor.forward(**data, mode='loss')
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
packed_inputs = _demo_mm_inputs(
|
||||
batch_size=1, image_shapes=(3, 32, 32), num_classes=num_classes)
|
||||
batch_inputs, data_samples = segmentor.data_preprocessor(
|
||||
packed_inputs, False)
|
||||
data = segmentor.data_preprocessor(packed_inputs, False)
|
||||
with torch.no_grad():
|
||||
segmentor.eval()
|
||||
# Test forward predict
|
||||
batch_results = segmentor.forward(
|
||||
batch_inputs, data_samples, mode='predict')
|
||||
batch_results = segmentor.forward(**data, mode='predict')
|
||||
assert len(batch_results) == 1
|
||||
assert is_list_of(batch_results, SegDataSample)
|
||||
assert batch_results[0].pred_sem_seg.shape == (32, 32)
|
||||
assert batch_results[0].seg_logits.data.shape == (num_classes, 32, 32)
|
||||
assert batch_results[0].gt_sem_seg.shape == (32, 32)
|
||||
|
||||
# Test forward tensor
|
||||
batch_results = segmentor.forward(
|
||||
batch_inputs, data_samples, mode='tensor')
|
||||
batch_results = segmentor.forward(**data, mode='tensor')
|
||||
assert isinstance(batch_results, Tensor) or is_tuple_of(
|
||||
batch_results, Tensor)
|
||||
|
||||
# Test forward predict without ground truth
|
||||
data.pop('data_samples')
|
||||
batch_results = segmentor.forward(**data, mode='predict')
|
||||
assert len(batch_results) == 1
|
||||
assert is_list_of(batch_results, SegDataSample)
|
||||
assert batch_results[0].pred_sem_seg.shape == (32, 32)
|
||||
|
||||
# Test forward tensor without ground truth
|
||||
batch_results = segmentor.forward(**data, mode='tensor')
|
||||
assert isinstance(batch_results, Tensor) or is_tuple_of(
|
||||
batch_results, Tensor)
|
||||
|
@ -29,8 +29,8 @@ class TestSegLocalVisualizer(TestCase):
|
||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
def test_add_datasample_forward(gt_sem_seg):
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = gt_sem_seg
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
@ -42,7 +42,7 @@ class TestSegLocalVisualizer(TestCase):
|
||||
|
||||
# test out_file
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
data_sample)
|
||||
|
||||
assert os.path.exists(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
@ -57,22 +57,16 @@ class TestSegLocalVisualizer(TestCase):
|
||||
data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
data_sample.pred_sem_seg = pred_sem_seg
|
||||
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample)
|
||||
data_sample)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w * 2, 3))
|
||||
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_gt=False)
|
||||
out_file, image, data_sample, draw_gt=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w, 3))
|
||||
@ -104,8 +98,8 @@ class TestSegLocalVisualizer(TestCase):
|
||||
gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
def test_cityscapes_add_datasample_forward(gt_sem_seg):
|
||||
gt_seg_data_sample = SegDataSample()
|
||||
gt_seg_data_sample.gt_sem_seg = gt_sem_seg
|
||||
data_sample = SegDataSample()
|
||||
data_sample.gt_sem_seg = gt_sem_seg
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
seg_local_visualizer = SegLocalVisualizer(
|
||||
@ -125,11 +119,11 @@ class TestSegLocalVisualizer(TestCase):
|
||||
[0, 60, 100], [0, 80, 100], [0, 0, 230],
|
||||
[119, 11, 32]])
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
data_sample)
|
||||
|
||||
# test out_file
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample)
|
||||
data_sample)
|
||||
assert os.path.exists(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'))
|
||||
@ -143,22 +137,16 @@ class TestSegLocalVisualizer(TestCase):
|
||||
data=torch.randint(0, num_class, (1, h, w)))
|
||||
pred_sem_seg = PixelData(**pred_sem_seg_data)
|
||||
|
||||
pred_seg_data_sample = SegDataSample()
|
||||
pred_seg_data_sample.pred_sem_seg = pred_sem_seg
|
||||
data_sample.pred_sem_seg = pred_sem_seg
|
||||
|
||||
seg_local_visualizer.add_datasample(out_file, image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample)
|
||||
data_sample)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w * 2, 3))
|
||||
|
||||
seg_local_visualizer.add_datasample(
|
||||
out_file,
|
||||
image,
|
||||
gt_seg_data_sample,
|
||||
pred_seg_data_sample,
|
||||
draw_gt=False)
|
||||
out_file, image, data_sample, draw_gt=False)
|
||||
self._assert_image_and_shape(
|
||||
osp.join(tmp_dir, 'vis_data', 'vis_image',
|
||||
out_file + '_0.png'), (h, w, 3))
|
||||
|
@ -79,14 +79,15 @@ def main():
|
||||
|
||||
# benchmark with 200 batches and take the average
|
||||
for i, data in enumerate(data_loader):
|
||||
batch_inputs, data_samples = model.data_preprocessor(data, True)
|
||||
|
||||
data = model.data_preprocessor(data, True)
|
||||
inputs = data['inputs']
|
||||
data_samples = data['data_samples']
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
model(batch_inputs, data_samples, mode='predict')
|
||||
model(inputs, data_samples, mode='predict')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
Loading…
x
Reference in New Issue
Block a user