57 lines
2.0 KiB
Python
57 lines
2.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from types import MethodType
|
|
from unittest import TestCase
|
|
|
|
import torch
|
|
|
|
from mmcls.models import InceptionV3
|
|
from mmcls.models.backbones.inception_v3 import InceptionAux
|
|
|
|
|
|
class TestInceptionV3(TestCase):
    """Unit tests for the ``InceptionV3`` model."""

    # Baseline constructor kwargs; individual tests override entries
    # (e.g. enabling ``aux_logits``) through ``_build_model``.
    DEFAULT_ARGS = dict(num_classes=10, aux_logits=False, dropout=0.)

    def _build_model(self, **overrides):
        """Build an ``InceptionV3`` from ``DEFAULT_ARGS`` plus overrides."""
        return InceptionV3(**dict(self.DEFAULT_ARGS, **overrides))

    def test_structure(self):
        """``AuxLogits`` is present only when ``aux_logits`` is enabled."""
        # Auxiliary branch disabled -> no AuxLogits module.
        plain_model = self._build_model()
        self.assertIsNone(plain_model.AuxLogits)

        # Auxiliary branch enabled -> AuxLogits is an InceptionAux head.
        aux_model = self._build_model(aux_logits=True)
        self.assertIsInstance(aux_model.AuxLogits, InceptionAux)

    def test_init_weights(self):
        """``init_weights`` records the expected truncated-normal inits."""
        model = self._build_model(aux_logits=True)

        # Populated by the hook below: parameter name -> joined init record.
        init_info = {}

        def get_init_info(module, *args):
            # Bound onto the model as ``_dump_init_info`` so that the
            # per-parameter init records are captured into ``init_info``.
            for name, param in module.named_parameters():
                init_info[name] = ''.join(
                    module._params_init_info[param]['init_info'])

        model._dump_init_info = MethodType(get_init_info, model)
        model.init_weights()

        # Checked in insertion order, so the first mismatch is reported
        # exactly as it would be with sequential assertions.
        expected_inits = {
            'Conv2d_1a_3x3.conv.weight':
                'TruncNormalInit: a=-2, b=2, mean=0, std=0.1, bias=0',
            'AuxLogits.conv0.conv.weight':
                'TruncNormalInit: a=-2, b=2, mean=0, std=0.01, bias=0',
            'AuxLogits.fc.weight':
                'TruncNormalInit: a=-2, b=2, mean=0, std=0.001, bias=0',
        }
        for param_name, init_msg in expected_inits.items():
            self.assertIn(init_msg, init_info[param_name])

    def test_forward(self):
        """Check output shapes with and without the auxiliary head."""
        inputs = torch.rand(2, 3, 299, 299)

        # Without the auxiliary head, the aux output slot is None.
        aux_out, out = self._build_model()(inputs)
        self.assertIsNone(aux_out)
        self.assertEqual(out.shape, (2, 10))

        # With the auxiliary head, both outputs carry class logits.
        aux_out, out = self._build_model(aux_logits=True)(inputs)
        self.assertEqual(aux_out.shape, (2, 10))
        self.assertEqual(out.shape, (2, 10))