diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index a2d303ff..0ae0e1a1 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -28,7 +28,7 @@ test:
     - pip install pillow==6.2.2
     - pip install -e .
     - python -c "import mmcls; print(mmcls.__version__)"
-    # - echo "Start testing..."
-    # - pip install pytest coverage
-    # - coverage run --source mmcls -m pytest tests/
-    # - coverage report -m
+    - echo "Start testing..."
+    - pip install pytest coverage
+    - coverage run --source mmcls -m pytest tests/
+    - coverage report -m
diff --git a/mmcls/models/__init__.py b/mmcls/models/__init__.py
index eab103a0..6dd17cda 100644
--- a/mmcls/models/__init__.py
+++ b/mmcls/models/__init__.py
@@ -1,4 +1,4 @@
-from .builder import build_model
-from .registry import MODELS
+from .backbones import *  # noqa: F401,F403
+from .builder import BACKBONES, MODELS, build_backbone, build_model
 
-__all__ = ['build_model', 'MODELS']
+__all__ = ['BACKBONES', 'MODELS', 'build_backbone', 'build_model']
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index e69de29b..5a34fb2a 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -0,0 +1,3 @@
+from .resnet import ResNet, ResNetV1d
+
+__all__ = ['ResNet', 'ResNetV1d']
diff --git a/mmcls/models/backbones/base_backbone.py b/mmcls/models/backbones/base_backbone.py
new file mode 100644
index 00000000..1e163862
--- /dev/null
+++ b/mmcls/models/backbones/base_backbone.py
@@ -0,0 +1,55 @@
+import logging
+from abc import ABCMeta, abstractmethod
+
+import torch.nn as nn
+from mmcv.runner import load_checkpoint
+
+
+class BaseBackbone(nn.Module, metaclass=ABCMeta):
+    """Base backbone.
+
+    This class defines the basic functions of a backbone.
+    Any backbone that inherits this class should at least
+    define its own `forward` function.
+    """
+
+    def __init__(self):
+        super(BaseBackbone, self).__init__()
+
+    def init_weights(self, pretrained=None):
+        """Init backbone weights.
+
+        Args:
+            pretrained (str | None): If pretrained is a string, then it
+                initializes backbone weights by loading the pretrained
+                checkpoint. If pretrained is None, then it follows default
+                initializer or customized initializer in subclasses.
+        """
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            # use default initializer or customized initializer in subclasses
+            pass
+        else:
+            raise TypeError('pretrained must be a str or None.'
+                            f' But received {type(pretrained)}.')
+
+    @abstractmethod
+    def forward(self, x):
+        """Forward computation.
+
+        Args:
+            x (tensor | tuple[tensor]): x could be a torch.Tensor or a tuple
+                of torch.Tensor, containing input data for forward
+                computation.
+ """ + pass + + def train(self, mode=True): + """Set module status before forward computation + + Args: + mode (bool): Whether it is train_mode or test_mode + """ + super(BaseBackbone, self).train(mode) diff --git a/mmcls/models/backbones/resnet.py b/mmcls/models/backbones/resnet.py new file mode 100644 index 00000000..d9b62156 --- /dev/null +++ b/mmcls/models/backbones/resnet.py @@ -0,0 +1,532 @@ +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, + kaiming_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN')): + super(BasicBlock, self).__init__() + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN')): + """Bottleneck block for ResNet. + If style is "pytorch", the stride-two layer is the 3x3 conv layer, + if it is "caffe", the stride-two layer is the first 1x1 conv layer. 
+ """ + super(Bottleneck, self).__init__() + assert style in ['pytorch', 'caffe'] + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + @property + def norm3(self): + return getattr(self, self.norm3_name) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ResLayer(nn.Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. 
+            Default: dict(type='BN')
+    """
+
+    def __init__(self,
+                 block,
+                 inplanes,
+                 planes,
+                 num_blocks,
+                 stride=1,
+                 avg_down=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 **kwargs):
+        self.block = block
+
+        downsample = None
+        if stride != 1 or inplanes != planes * block.expansion:
+            downsample = []
+            conv_stride = stride
+            if avg_down and stride != 1:
+                conv_stride = 1
+                downsample.append(
+                    nn.AvgPool2d(
+                        kernel_size=stride,
+                        stride=stride,
+                        ceil_mode=True,
+                        count_include_pad=False))
+            downsample.extend([
+                build_conv_layer(
+                    conv_cfg,
+                    inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=conv_stride,
+                    bias=False),
+                build_norm_layer(norm_cfg, planes * block.expansion)[1]
+            ])
+            downsample = nn.Sequential(*downsample)
+
+        layers = []
+        layers.append(
+            block(
+                inplanes=inplanes,
+                planes=planes,
+                stride=stride,
+                downsample=downsample,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                **kwargs))
+        inplanes = planes * block.expansion
+        for i in range(1, num_blocks):
+            layers.append(
+                block(
+                    inplanes=inplanes,
+                    planes=planes,
+                    stride=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    **kwargs))
+        super(ResLayer, self).__init__(*layers)
+
+
+@BACKBONES.register_module()
+class ResNet(BaseBackbone):
+    """ResNet backbone.
+
+    Args:
+        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+        in_channels (int): Number of input image channels. Normally 3.
+        base_channels (int): Number of base channels of hidden layer.
+        num_stages (int): Resnet stages, normally 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        deep_stem (bool): Replace the 7x7 conv in the input stem with three
+            3x3 convs.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): Whether to use zero init for the last norm
+            layer in resblocks to let them behave as identity.
+
+    Example:
+        >>> from mmcls.models import ResNet
+        >>> import torch
+        >>> self = ResNet(depth=18, out_indices=(0, 1, 2, 3))
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 64, 8, 8)
+        (1, 128, 4, 4)
+        (1, 256, 2, 2)
+        (1, 512, 1, 1)
+    """
+
+    arch_settings = {
+        18: (BasicBlock, (2, 2, 2, 2)),
+        34: (BasicBlock, (3, 4, 6, 3)),
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self,
+                 depth,
+                 in_channels=3,
+                 base_channels=64,
+                 num_stages=4,
+                 strides=(1, 2, 2, 2),
+                 dilations=(1, 1, 1, 1),
+                 out_indices=(3, ),
+                 style='pytorch',
+                 deep_stem=False,
+                 avg_down=False,
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 norm_eval=True,
+                 with_cp=False,
+                 zero_init_residual=True):
+        super(ResNet, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for resnet')
+        self.depth = depth
+        self.base_channels = base_channels
+        self.num_stages = num_stages
+        assert num_stages >= 1 and num_stages <= 4
+        self.strides = strides
+        self.dilations = dilations
+        assert len(strides) == len(dilations) == num_stages
+        self.out_indices = out_indices
+        assert max(out_indices) < num_stages
+        self.style = style
+        self.deep_stem = deep_stem
+        self.avg_down = avg_down
+        self.frozen_stages = frozen_stages
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.with_cp = with_cp
+        self.norm_eval = norm_eval
+        self.zero_init_residual = zero_init_residual
+        self.block, stage_blocks = self.arch_settings[depth]
+        self.stage_blocks = stage_blocks[:num_stages]
+        self.inplanes = base_channels
+
+        self._make_stem_layer(in_channels, base_channels)
+
+        self.res_layers = []
+        for i, num_blocks in enumerate(self.stage_blocks):
+            stride = strides[i]
+            dilation = dilations[i]
+            planes = base_channels * 2**i
+            res_layer = self.make_res_layer(
+                block=self.block,
+                inplanes=self.inplanes,
+                planes=planes,
+                num_blocks=num_blocks,
+                stride=stride,
+                dilation=dilation,
+                style=self.style,
+                avg_down=self.avg_down,
+                with_cp=with_cp,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg)
+            self.inplanes = planes * self.block.expansion
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, res_layer)
+            self.res_layers.append(layer_name)
+
+        self._freeze_stages()
+
+        self.feat_dim = self.block.expansion * base_channels * 2**(
+            len(self.stage_blocks) - 1)
+
+    def make_res_layer(self, **kwargs):
+        return ResLayer(**kwargs)
+
+    @property
+    def norm1(self):
+        return getattr(self, self.norm1_name)
+
+    def _make_stem_layer(self, in_channels, base_channels):
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    base_channels // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, base_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    base_channels // 2,
+                    base_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, base_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    base_channels // 2,
+                    base_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, base_channels)[1],
+                nn.ReLU(inplace=True))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                base_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, base_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+            self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        super(ResNet, self).init_weights(pretrained)
+        if pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+
+            if self.zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)
+
+    def forward(self, x):
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
+        x = self.maxpool(x)
+        outs = []
+        for i, layer_name in enumerate(self.res_layers):
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def train(self, mode=True):
+        super(ResNet, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval mode has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+    """ResNetV1d variant described in
+    `Bag of Tricks <https://arxiv.org/abs/1812.01187>`_.
+
+    Compared with default ResNet (ResNetV1b), ResNetV1d replaces the 7x7 conv
+    in the input stem with three 3x3 convs. In the downsampling block, a 2x2
+    avg_pool with stride 2 is added before the conv, whose stride is changed
+    to 1.
+ """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/mmcls/models/builder.py b/mmcls/models/builder.py index 34955208..51e5db06 100644 --- a/mmcls/models/builder.py +++ b/mmcls/models/builder.py @@ -1,6 +1,7 @@ import torch.nn as nn from mmcv.utils import Registry, build_from_cfg +BACKBONES = Registry('backbone') MODELS = Registry('model') @@ -14,5 +15,9 @@ def build(cfg, registry, default_args=None): return build_from_cfg(cfg, registry, default_args) +def build_backbone(cfg): + return build(cfg, BACKBONES) + + def build_model(cfg, train_cfg=None, test_cfg=None): return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/tests/test_backbones/test_resnet.py b/tests/test_backbones/test_resnet.py new file mode 100644 index 00000000..b2d51001 --- /dev/null +++ b/tests/test_backbones/test_resnet.py @@ -0,0 +1,308 @@ +import pytest +import torch +from torch.nn.modules import AvgPool2d +from torch.nn.modules.batchnorm import _BatchNorm + +from mmcls.models.backbones import ResNet, ResNetV1d +from mmcls.models.backbones.resnet import BasicBlock, Bottleneck, ResLayer + + +def is_block(modules): + """Check if is ResNet building block.""" + if isinstance(modules, (BasicBlock, Bottleneck)): + return True + return False + + +def is_norm(modules): + """Check if is one of the norms.""" + if isinstance(modules, (_BatchNorm, )): + return True + return False + + +def all_zeros(modules): + """Check if the weight(and bias) is all zero.""" + weight_zero = torch.equal(modules.weight.data, + torch.zeros_like(modules.weight.data)) + if hasattr(modules, 'bias'): + bias_zero = torch.equal(modules.bias.data, + torch.zeros_like(modules.bias.data)) + else: + bias_zero = True + + return weight_zero and bias_zero + + +def check_norm_state(modules, train_state): + """Check if norm layer is in correct train state.""" + for mod in modules: + if isinstance(mod, _BatchNorm): + if mod.training != train_state: + return False + return True + + +def test_resnet_basic_block(): + # Test BasicBlock structure and forward + block = BasicBlock(64, 64) + assert block.conv1.in_channels == 64 + assert block.conv1.out_channels == 64 + assert block.conv1.kernel_size == (3, 3) + assert block.conv2.in_channels == 64 + assert block.conv2.out_channels == 64 + assert block.conv2.kernel_size == (3, 3) + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + # Test BasicBlock with checkpoint forward + block = BasicBlock(64, 64, with_cp=True) + assert block.with_cp + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + +def test_resnet_bottleneck(): + + with pytest.raises(AssertionError): + # Style must be in ['pytorch', 'caffe'] + Bottleneck(64, 64, style='tensorflow') + + # Test Bottleneck with checkpoint forward + block = Bottleneck(64, 16, with_cp=True) + assert block.with_cp + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + # Test Bottleneck style + block = Bottleneck(64, 64, stride=2, style='pytorch') + assert block.conv1.stride == (1, 1) + assert block.conv2.stride == (2, 2) + block = Bottleneck(64, 64, stride=2, style='caffe') + assert block.conv1.stride == (2, 2) + assert block.conv2.stride == (1, 1) + + # Test Bottleneck forward + block = Bottleneck(64, 16) + x = torch.randn(1, 64, 56, 56) + x_out = block(x) + assert x_out.shape == torch.Size([1, 64, 56, 56]) + + +def 
+def test_resnet_res_layer():
+    # Test ResLayer of 3 Bottleneck w/o downsample
+    layer = ResLayer(Bottleneck, 64, 16, 3)
+    assert len(layer) == 3
+    assert layer[0].conv1.in_channels == 64
+    assert layer[0].conv1.out_channels == 16
+    for i in range(1, len(layer)):
+        assert layer[i].conv1.in_channels == 64
+        assert layer[i].conv1.out_channels == 16
+    for i in range(len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 64, 56, 56])
+
+    # Test ResLayer of 3 Bottleneck with downsample
+    layer = ResLayer(Bottleneck, 64, 64, 3)
+    assert layer[0].downsample[0].out_channels == 256
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 56, 56])
+
+    # Test ResLayer of 3 Bottleneck with stride=2
+    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2)
+    assert layer[0].downsample[0].out_channels == 256
+    assert layer[0].downsample[0].stride == (2, 2)
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 28, 28])
+
+    # Test ResLayer of 3 Bottleneck with stride=2 and average downsample
+    layer = ResLayer(Bottleneck, 64, 64, 3, stride=2, avg_down=True)
+    assert isinstance(layer[0].downsample[0], AvgPool2d)
+    assert layer[0].downsample[1].out_channels == 256
+    assert layer[0].downsample[1].stride == (1, 1)
+    for i in range(1, len(layer)):
+        assert layer[i].downsample is None
+    x = torch.randn(1, 64, 56, 56)
+    x_out = layer(x)
+    assert x_out.shape == torch.Size([1, 256, 28, 28])
+
+
+def test_resnet_backbone():
+    """Test ResNet backbone."""
+    with pytest.raises(KeyError):
+        # ResNet depth should be in [18, 34, 50, 101, 152]
+        ResNet(20)
+
+    with pytest.raises(AssertionError):
+        # In ResNet: 1 <= num_stages <= 4
+        ResNet(50, num_stages=0)
+
+    with pytest.raises(AssertionError):
+        # In ResNet: 1 <= num_stages <= 4
+        ResNet(50, num_stages=5)
+
+    with pytest.raises(AssertionError):
+        # len(strides) == len(dilations) == num_stages
+        ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
+
+    with pytest.raises(TypeError):
+        # pretrained must be a string path
+        model = ResNet(50)
+        model.init_weights(pretrained=0)
+
+    with pytest.raises(AssertionError):
+        # Style must be in ['pytorch', 'caffe']
+        ResNet(50, style='tensorflow')
+
+    # Test ResNet50 norm_eval=True
+    model = ResNet(50, norm_eval=True)
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    # Test ResNet50 with torchvision pretrained weight
+    model = ResNet(depth=50, norm_eval=True)
+    model.init_weights('torchvision://resnet50')
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    # Test ResNet50 with first stage frozen
+    frozen_stages = 1
+    model = ResNet(50, frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+    assert model.norm1.training is False
+    for layer in [model.conv1, model.norm1]:
+        for param in layer.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    # Test ResNet50V1d with first stage frozen
+    model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
+    assert len(model.stem) == 9
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.stem, False)
+    for param in model.stem.parameters():
+        assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    # Test ResNet18 forward
+    model = ResNet(18, out_indices=(0, 1, 2, 3))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 64, 56, 56])
+    assert feat[1].shape == torch.Size([1, 128, 28, 28])
+    assert feat[2].shape == torch.Size([1, 256, 14, 14])
+    assert feat[3].shape == torch.Size([1, 512, 7, 7])
+
+    # Test ResNet50 with BatchNorm forward
+    model = ResNet(50, out_indices=(0, 1, 2, 3))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    # Test ResNet50 with layers 1, 2, 3 out forward
+    model = ResNet(50, out_indices=(0, 1, 2))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+
+    # Test ResNet50 with layer 3 (top feature map) out forward
+    model = ResNet(50, out_indices=(3, ))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat.shape == torch.Size([1, 2048, 7, 7])
+
+    # Test ResNet50 with checkpoint forward
+    model = ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
+    for m in model.modules():
+        if is_block(m):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    # Test ResNet50 zero initialization of residual
+    model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
+    model.init_weights()
+    for m in model.modules():
+        if isinstance(m, Bottleneck):
+            assert all_zeros(m.norm3)
+        elif isinstance(m, BasicBlock):
+            assert all_zeros(m.norm2)
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
+
+    # Test ResNetV1d forward
+    model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 256, 56, 56])
+    assert feat[1].shape == torch.Size([1, 512, 28, 28])
+    assert feat[2].shape == torch.Size([1, 1024, 14, 14])
+    assert feat[3].shape == torch.Size([1, 2048, 7, 7])
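
Usage note (not part of the patch): a minimal sketch of the API this diff introduces. It assumes the patched mmcls and its mmcv dependency are installed; the config keys mirror the ResNet constructor arguments above, and the 'type' key is resolved against the new BACKBONES registry by mmcv's build_from_cfg.

    import torch

    from mmcls.models import build_backbone

    cfg = dict(
        type='ResNet',  # looked up in the BACKBONES registry
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        norm_cfg=dict(type='BN', requires_grad=True))

    backbone = build_backbone(cfg)
    backbone.init_weights()  # random init; pass a checkpoint path to load weights
    backbone.eval()

    with torch.no_grad():
        feat = backbone(torch.rand(1, 3, 224, 224))
    print(feat.shape)  # expected torch.Size([1, 2048, 7, 7]) per the tests above

With a single entry in out_indices, forward returns one tensor rather than a tuple, which is why feat can be used directly here.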
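Extension note (also not part of the patch): a sketch of how a third-party backbone would plug into the BACKBONES registry added in builder.py. TinyNet is a hypothetical class used only for illustration.

    import torch
    import torch.nn as nn

    from mmcls.models.backbones.base_backbone import BaseBackbone
    from mmcls.models.builder import BACKBONES, build_backbone


    @BACKBONES.register_module()
    class TinyNet(BaseBackbone):
        """Hypothetical backbone, only to illustrate the registry flow."""

        def __init__(self, out_channels=32):
            super(TinyNet, self).__init__()
            self.conv = nn.Conv2d(3, out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            # BaseBackbone declares forward as abstract, so a subclass must
            # define it before the class can be instantiated.
            return self.conv(x)


    model = build_backbone(dict(type='TinyNet', out_channels=16))
    assert model(torch.rand(1, 3, 8, 8)).shape == (1, 16, 8, 8)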