# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX


def is_block(modules):
    """Check if ``modules`` is a ResNet building block.

    Args:
        modules: Object to test.

    Returns:
        bool: True if ``modules`` is an instance of ``BasicBlock``,
        ``Bottleneck`` (ResNet) or ``Bottleneck`` (ResNeXt).
    """
    return isinstance(modules, (BasicBlock, Bottleneck, BottleneckX))


def is_norm(modules):
    """Check if ``modules`` is one of the supported norm layers.

    Args:
        modules: Object to test.

    Returns:
        bool: True if ``modules`` is an instance of ``GroupNorm`` or of a
        ``_BatchNorm`` subclass.
    """
    return isinstance(modules, (GroupNorm, _BatchNorm))


def all_zeros(modules):
    """Check if the weight (and bias, when present) is all zero.

    Args:
        modules: Object exposing a tensor ``weight`` attribute and,
            optionally, a ``bias`` attribute.

    Returns:
        bool: True if ``weight`` is close to all zeros (under the default
        tolerances of ``torch.allclose``) and ``bias`` is either close to
        all zeros, missing, or ``None``.
    """
    weight = modules.weight.data
    weight_zero = torch.allclose(weight, torch.zeros_like(weight))

    # A layer built without a bias still has a ``bias`` attribute, set to
    # ``None``; ``hasattr`` alone would let ``None.data`` raise. Treat a
    # missing or ``None`` bias the same way: nothing to check.
    bias = getattr(modules, 'bias', None)
    if bias is None:
        bias_zero = True
    else:
        bias_zero = torch.allclose(bias.data, torch.zeros_like(bias.data))

    return weight_zero and bias_zero


def check_norm_state(modules, train_state):
    """Check if every batch-norm layer is in the expected train state.

    Args:
        modules: Iterable of objects to inspect; entries that are not
            ``_BatchNorm`` instances are ignored.
        train_state (bool): Expected value of the ``training`` flag.

    Returns:
        bool: False if any ``_BatchNorm`` entry has a ``training`` flag
        different from ``train_state``; True otherwise (including when
        there are no ``_BatchNorm`` entries).
    """
    return all(mod.training == train_state for mod in modules
               if isinstance(mod, _BatchNorm))