mmpretrain/tests/test_models/test_backbones/test_inception_v3.py

57 lines
2.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from types import MethodType
from unittest import TestCase
import torch
from mmcls.models import InceptionV3
from mmcls.models.backbones.inception_v3 import InceptionAux
class TestInceptionV3(TestCase):
DEFAULT_ARGS = dict(num_classes=10, aux_logits=False, dropout=0.)
def test_structure(self):
# Test without auxiliary branch.
model = InceptionV3(**self.DEFAULT_ARGS)
self.assertIsNone(model.AuxLogits)
# Test with auxiliary branch.
cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
model = InceptionV3(**cfg)
self.assertIsInstance(model.AuxLogits, InceptionAux)
def test_init_weights(self):
cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
model = InceptionV3(**cfg)
init_info = {}
def get_init_info(self, *args):
for name, param in self.named_parameters():
init_info[name] = ''.join(
self._params_init_info[param]['init_info'])
model._dump_init_info = MethodType(get_init_info, model)
model.init_weights()
self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.1, bias=0',
init_info['Conv2d_1a_3x3.conv.weight'])
self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.01, bias=0',
init_info['AuxLogits.conv0.conv.weight'])
self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.001, bias=0',
init_info['AuxLogits.fc.weight'])
def test_forward(self):
inputs = torch.rand(2, 3, 299, 299)
model = InceptionV3(**self.DEFAULT_ARGS)
aux_out, out = model(inputs)
self.assertIsNone(aux_out)
self.assertEqual(out.shape, (2, 10))
cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
model = InceptionV3(**cfg)
aux_out, out = model(inputs)
self.assertEqual(aux_out.shape, (2, 10))
self.assertEqual(out.shape, (2, 10))