# Copyright (c) OpenMMLab. All rights reserved. import pytest import torch from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm from mmseg.models.backbones.hrnet import HRModule, HRNet from mmseg.models.backbones.resnet import BasicBlock, Bottleneck @pytest.mark.parametrize('block', [BasicBlock, Bottleneck]) def test_hrmodule(block): # Test multiscale forward num_channles = (32, 64) in_channels = [c * block.expansion for c in num_channles] hrmodule = HRModule( num_branches=2, blocks=block, in_channels=in_channels, num_blocks=(4, 4), num_channels=num_channles, ) feats = [ torch.randn(1, in_channels[0], 64, 64), torch.randn(1, in_channels[1], 32, 32) ] feats = hrmodule(feats) assert len(feats) == 2 assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64]) assert feats[1].shape == torch.Size([1, in_channels[1], 32, 32]) # Test single scale forward num_channles = (32, 64) in_channels = [c * block.expansion for c in num_channles] hrmodule = HRModule( num_branches=2, blocks=block, in_channels=in_channels, num_blocks=(4, 4), num_channels=num_channles, multiscale_output=False, ) feats = [ torch.randn(1, in_channels[0], 64, 64), torch.randn(1, in_channels[1], 32, 32) ] feats = hrmodule(feats) assert len(feats) == 1 assert feats[0].shape == torch.Size([1, in_channels[0], 64, 64]) def test_hrnet_backbone(): # only have 3 stages extra = dict( stage1=dict( num_modules=1, num_branches=1, block='BOTTLENECK', num_blocks=(4, ), num_channels=(64, )), stage2=dict( num_modules=1, num_branches=2, block='BASIC', num_blocks=(4, 4), num_channels=(32, 64)), stage3=dict( num_modules=4, num_branches=3, block='BASIC', num_blocks=(4, 4, 4), num_channels=(32, 64, 128))) with pytest.raises(AssertionError): # HRNet now only support 4 stages HRNet(extra=extra) extra['stage4'] = dict( num_modules=3, num_branches=3, # should be 4 block='BASIC', num_blocks=(4, 4, 4, 4), num_channels=(32, 64, 128, 256)) with pytest.raises(AssertionError): # len(num_blocks) should equal num_branches HRNet(extra=extra) extra['stage4']['num_branches'] = 4 # Test hrnetv2p_w32 model = HRNet(extra=extra) model.init_weights() model.train() imgs = torch.randn(1, 3, 64, 64) feats = model(imgs) assert len(feats) == 4 assert feats[0].shape == torch.Size([1, 32, 16, 16]) assert feats[3].shape == torch.Size([1, 256, 2, 2]) # Test single scale output model = HRNet(extra=extra, multiscale_output=False) model.init_weights() model.train() imgs = torch.randn(1, 3, 64, 64) feats = model(imgs) assert len(feats) == 1 assert feats[0].shape == torch.Size([1, 32, 16, 16]) # Test HRNET with two stage frozen frozen_stages = 2 model = HRNet(extra, frozen_stages=frozen_stages) model.init_weights() model.train() assert model.norm1.training is False for layer in [model.conv1, model.norm1]: for param in layer.parameters(): assert param.requires_grad is False for i in range(1, frozen_stages + 1): if i == 1: layer = getattr(model, f'layer{i}') transition = getattr(model, f'transition{i}') elif i == 4: layer = getattr(model, f'stage{i}') else: layer = getattr(model, f'stage{i}') transition = getattr(model, f'transition{i}') for mod in layer.modules(): if isinstance(mod, _BatchNorm): assert mod.training is False for param in layer.parameters(): assert param.requires_grad is False for mod in transition.modules(): if isinstance(mod, _BatchNorm): assert mod.training is False for param in transition.parameters(): assert param.requires_grad is False