# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.optim import OptimWrapper
from mmengine.structures import PixelData
from torch import nn
from torch.optim import SGD

from mmseg.models import SegDataPreProcessor
from mmseg.models.decode_heads.cascade_decode_head import \
    BaseCascadeDecodeHead
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample


def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions (N, C, H, W).
        num_classes (int): Number of semantic classes.
    """
    (N, C, H, W) = input_shape

    imgs = torch.randn(*input_shape)
    segs = torch.randint(
        low=0, high=num_classes - 1, size=(N, H, W), dtype=torch.long)

    img_metas = [{
        'img_shape': (H, W),
        'ori_shape': (H, W),
        'pad_shape': (H, W, C),
        'filename': '.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]

    data_samples = [
        SegDataSample(
            gt_sem_seg=PixelData(data=segs[i]), metainfo=img_metas[i])
        for i in range(N)
    ]

    # the data preprocessor expects the batch dict to carry an 'inputs' key
    mm_inputs = {
        'inputs': torch.FloatTensor(imgs),
        'data_samples': data_samples
    }

    return mm_inputs


@MODELS.register_module()
class ExampleBackbone(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def init_weights(self, pretrained=None):
        pass

    def forward(self, x):
        return [self.conv(x)]


@MODELS.register_module()
class ExampleDecodeHead(BaseDecodeHead):

    def __init__(self, num_classes=19, out_channels=None):
        super().__init__(
            3, 3, num_classes=num_classes, out_channels=out_channels)

    def forward(self, inputs):
        return self.cls_seg(inputs[0])


@MODELS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):

    def __init__(self):
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs, prev_out):
        return self.cls_seg(inputs[0])


def _segmentor_forward_train_test(segmentor):
    if isinstance(segmentor.decode_head, nn.ModuleList):
        num_classes = segmentor.decode_head[-1].num_classes
    else:
        num_classes = segmentor.decode_head.num_classes

    # prepare demo inputs (default batch size 1)
    mm_inputs = _demo_mm_inputs(num_classes=num_classes)

    # convert to cuda Tensor if applicable
    if torch.cuda.is_available():
        segmentor = segmentor.cuda()

    # make sure a data preprocessor is attached
    if not hasattr(segmentor, 'data_preprocessor') \
            or segmentor.data_preprocessor is None:
        segmentor.data_preprocessor = SegDataPreProcessor()

    mm_inputs = segmentor.data_preprocessor(mm_inputs, True)
    imgs = mm_inputs.pop('inputs')
    data_samples = mm_inputs.pop('data_samples')

    # create optimizer wrapper
    optimizer = SGD(segmentor.parameters(), lr=0.1)
    optim_wrapper = OptimWrapper(optimizer)

    # Test forward train
    losses = segmentor.forward(imgs, data_samples, mode='loss')
    assert isinstance(losses, dict)

    # Test train_step
    data_batch = dict(inputs=imgs, data_samples=data_samples)
    outputs = segmentor.train_step(data_batch, optim_wrapper)
    assert isinstance(outputs, dict)
    assert 'loss' in outputs

    # Test val_step
    with torch.no_grad():
        segmentor.eval()
        data_batch = dict(inputs=imgs, data_samples=data_samples)
        outputs = segmentor.val_step(data_batch)
        assert isinstance(outputs, list)

    # Test forward simple test
    with torch.no_grad():
        segmentor.eval()
        data_batch = dict(inputs=imgs, data_samples=data_samples)
        results = segmentor.forward(imgs, data_samples, mode='tensor')
        assert isinstance(results, torch.Tensor)
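

# A minimal sketch (not part of the original utilities) of how the helpers
# above might be exercised. The helper name `_demo_encoder_decoder_forward`
# and the config values are illustrative assumptions; it assumes mmseg's
# `EncoderDecoder` segmentor built from the Example* modules registered above.
def _demo_encoder_decoder_forward():
    cfg = dict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    # build the segmentor from the registry and run the shared
    # loss / train_step / val_step / tensor-forward checks
    segmentor = MODELS.build(cfg)
    _segmentor_forward_train_test(segmentor)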