# mmpretrain/tests/test_models/test_backbones/test_cspnet.py
#
# 148 lines
# 5.0 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
from unittest import TestCase
import torch
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmcls.models.backbones import CSPDarkNet, CSPResNet, CSPResNeXt
from mmcls.models.backbones.cspnet import (CSPNet, DarknetBottleneck,
ResNetBottleneck, ResNeXtBottleneck)
class TestCSPNet(TestCase):
    """Checks that ``CSPNet`` can be built from a hand-written arch dict."""

    def setUp(self):
        """Prepare a stem factory and a three-stage arch description.

        Each stage is given a different bottleneck class so the structure
        test can tell the stages apart.
        """
        self.stem_fn = partial(
            torch.nn.Conv2d, out_channels=32, kernel_size=3)
        self.arch = {
            'block_fn': (
                DarknetBottleneck,
                ResNetBottleneck,
                ResNeXtBottleneck,
            ),
            'in_channels': (32, 64, 128),
            'out_channels': (64, 128, 256),
            'num_blocks': (1, 2, 8),
            'expand_ratio': (2, 1, 1),
            'bottle_ratio': (3, 1, 1),
            'has_downsampler': True,
            'down_growth': True,
            # Only the third stage receives extra block kwargs.
            'block_args': ({}, {}, {'base_channels': 32}),
        }

    def test_structure(self):
        """Build from the custom arch and verify stage count and types."""
        model = CSPNet(arch=self.arch, stem_fn=self.stem_fn, out_indices=[-1])
        self.assertEqual(len(model.stages), 3)

        # The first block of stage ``idx`` must be exactly the class that
        # was listed at position ``idx`` of ``block_fn``.
        expected_blocks = (
            DarknetBottleneck,
            ResNetBottleneck,
            ResNeXtBottleneck,
        )
        for idx, block_cls in enumerate(expected_blocks):
            self.assertEqual(type(model.stages[idx].blocks[0]), block_cls)
class TestCSPDarkNet(TestCase):
    """Structure and forward tests for ``CSPDarkNet``.

    ``TestCSPResNet`` and ``TestCSPResNeXt`` subclass this case and reuse
    its test methods by overriding :meth:`setUp`, so every check here is
    driven by the attributes configured in ``setUp`` rather than by values
    specific to one backbone.
    """

    def setUp(self):
        """Configure the backbone class under test and its expectations."""
        # Backbone class and the keyword arguments used to build it.
        self.class_name = CSPDarkNet
        self.cfg = dict(depth=53)
        # Expected channel count of each stage output, indexed by stage.
        self.out_channels = [64, 128, 256, 512, 1024]
        # ``out_indices`` value used to request every stage output.
        self.all_out_indices = [0, 1, 2, 3, 4]
        # Passed as ``frozen_stages``; test_forward then inspects the stem
        # and stages ``0..frozen_stages`` inclusive.
        self.frozen_stages = 2
        # Divisors applied to the 224-pixel input before the per-stage
        # halving when computing expected output sizes (presumably the
        # stem's down-sampling factor per spatial axis).
        self.stem_down = (1, 1)
        self.num_stages = 5

    def test_structure(self):
        """Check argument validation and the number of built stages."""
        # An unsupported depth must be rejected. Only the constructor call
        # sits inside the context manager so nothing else can satisfy it.
        cfg = deepcopy(self.cfg)
        cfg['depth'] = 'unknown'
        with self.assertRaisesRegex(AssertionError, 'depth must be one of'):
            self.class_name(**cfg)

        # ``out_indices`` of the wrong type must be rejected.
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = {1: 1}
        with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
            self.class_name(**cfg)

        # An out-of-range index must be rejected.
        cfg['out_indices'] = [0, 13]
        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
            self.class_name(**cfg)

        # The default configuration builds the expected number of stages.
        cfg = deepcopy(self.cfg)
        model = self.class_name(**cfg)
        self.assertEqual(len(model.stages), self.num_stages)

    def test_forward(self):
        """Check output shapes and the effect of ``frozen_stages``."""
        imgs = torch.randn(3, 3, 224, 224)

        # Forward with the default ``out_indices``: a 1-tuple holding the
        # last stage output.
        cfg = deepcopy(self.cfg)
        model = self.class_name(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        self.assertEqual(outs[-1].size(), (3, self.out_channels[-1], 7, 7))

        # Forward with every stage requested.
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = self.all_out_indices
        model = self.class_name(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), len(self.all_out_indices))
        # Integer arithmetic keeps the expected sizes as ints.
        # ``stem_down[0]`` feeds size index 2 and ``stem_down[1]`` size
        # index 3; stage ``i`` is expected to halve both ``i + 1`` times.
        h, w = 224 // self.stem_down[0], 224 // self.stem_down[1]
        for i, out in enumerate(outs):
            scale = 2**(i + 1)
            self.assertEqual(
                out.size(),
                (3, self.out_channels[i], h // scale, w // scale))

        # Frozen stages: after ``train()`` the frozen parts must remain in
        # eval mode and hold no trainable parameters. unittest assertions
        # are used instead of bare ``assert`` so the checks still run
        # under ``python -O``.
        cfg = deepcopy(self.cfg)
        cfg['frozen_stages'] = self.frozen_stages
        model = self.class_name(**cfg)
        model.init_weights()
        model.train()
        self.assertFalse(model.stem.training)
        for param in model.stem.parameters():
            self.assertFalse(param.requires_grad)
        for i in range(self.frozen_stages + 1):
            stage = model.stages[i]
            for mod in stage.modules():
                # Only batch-norm layers are checked for eval mode here.
                if isinstance(mod, _BatchNorm):
                    self.assertFalse(mod.training, i)
            for param in stage.parameters():
                self.assertFalse(param.requires_grad)
class TestCSPResNet(TestCSPDarkNet):
    """Runs the inherited CSPDarkNet checks against ``CSPResNet``.

    Adds one extra test for the ``deep_stem`` option.
    """

    def setUp(self):
        """Point the inherited tests at ``CSPResNet`` with depth 50."""
        self.class_name = CSPResNet
        self.cfg = {'depth': 50}
        self.num_stages = 4
        self.out_channels = [128, 256, 512, 1024]
        self.all_out_indices = [0, 1, 2, 3]
        self.frozen_stages = 2
        self.stem_down = (2, 2)

    def test_deep_stem(self):
        """With ``deep_stem=True`` the stem is three ``ConvModule`` layers."""
        # Fresh dict so the shared config is left untouched.
        deep_cfg = dict(self.cfg, deep_stem=True)
        model = self.class_name(**deep_cfg)

        stem = model.stem
        self.assertEqual(len(stem), 3)
        for layer_idx in range(3):
            self.assertEqual(type(stem[layer_idx]), ConvModule)
class TestCSPResNeXt(TestCSPDarkNet):
    """Runs the inherited CSPDarkNet checks against ``CSPResNeXt``."""

    def setUp(self):
        """Point the inherited tests at ``CSPResNeXt`` with depth 50."""
        self.class_name = CSPResNeXt
        self.cfg = {'depth': 50}
        # Four stage outputs; indices and stage count follow from that.
        self.out_channels = [256, 512, 1024, 2048]
        self.all_out_indices = list(range(4))
        self.num_stages = len(self.out_channels)
        self.stem_down = (2, 2)
        self.frozen_stages = 2
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    from unittest import main

    main()