2022-09-18 10:47:44 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
from mmdet.models.backbones.res2net import Bottle2neck
|
|
|
|
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
|
|
|
|
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
|
|
|
|
from mmdet.models.layers import SimplifiedBasicBlock
|
2022-09-20 10:57:33 +08:00
|
|
|
from torch.nn.modules import GroupNorm
|
|
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
2022-09-18 10:47:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
def is_block(modules):
    """Check whether ``modules`` is a ResNet-family building block.

    Args:
        modules (object): A single object to test. Despite the plural name,
            it is passed straight to one ``isinstance`` call, not iterated.

    Returns:
        bool: True if ``modules`` is an instance (or subclass instance) of
        ``BasicBlock``, ``Bottleneck``, ResNeXt's ``Bottleneck``, Res2Net's
        ``Bottle2neck`` or ``SimplifiedBasicBlock``; False otherwise.
    """
    # isinstance already yields a bool, so return it directly instead of
    # branching to literal True/False.
    return isinstance(modules, (BasicBlock, Bottleneck, BottleneckX,
                                Bottle2neck, SimplifiedBasicBlock))
|
|
|
|
|
|
|
|
|
|
|
|
def is_norm(modules):
    """Check whether ``modules`` is one of the supported normalization layers.

    Args:
        modules (object): A single object to test. Despite the plural name,
            it is passed straight to one ``isinstance`` call, not iterated.

    Returns:
        bool: True if ``modules`` is a ``GroupNorm`` or a ``_BatchNorm``
        subclass instance (e.g. ``BatchNorm2d``); False otherwise.
    """
    # isinstance already yields a bool, so return it directly instead of
    # branching to literal True/False.
    return isinstance(modules, (GroupNorm, _BatchNorm))
|
|
|
|
|
|
|
|
|
|
|
|
def check_norm_state(modules, train_state):
    """Verify that every batch-norm layer is in the expected train/eval mode.

    Only ``_BatchNorm`` subclass instances are inspected; every other module
    (including ``GroupNorm``) is skipped. Iteration stops at the first
    batch-norm layer whose ``training`` flag differs from ``train_state``.

    Args:
        modules (Iterable): Modules to inspect, e.g. ``model.modules()``.
        train_state (bool): Expected value of each batch-norm ``training``
            attribute.

    Returns:
        bool: False if some batch-norm layer's ``training`` differs from
        ``train_state``; True otherwise (including when there are none).
    """
    return not any(
        layer.training != train_state
        for layer in modules
        if isinstance(layer, _BatchNorm))
|