# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import pytest
import torch
from torch import nn

from mmcv.cnn import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                      kaiming_init, normal_init, uniform_init, xavier_init)


def test_constant_init():
    """Check constant_init on conv layers with and without a bias.

    With a bias, weights must equal the given value and the bias (not passed
    explicitly) must be zero. Without a bias, the weights must still be set
    and the missing bias must stay ``None``.
    """
    conv_module = nn.Conv2d(3, 16, 3)
    constant_init(conv_module, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))
    assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))

    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    constant_init(conv_module_no_bias, 0.1)
    # Assert on the bias-free module itself; checking ``conv_module`` here
    # would leave the no-bias code path unverified.
    assert conv_module_no_bias.weight.allclose(
        torch.full_like(conv_module_no_bias.weight, 0.1))
    assert conv_module_no_bias.bias is None


def test_xavier_init():
    """Check xavier_init's bias value, distribution choice and no-bias case.

    An unsupported ``distribution`` must raise ``AssertionError``.
    """
    conv_module = nn.Conv2d(3, 16, 3)
    xavier_init(conv_module, bias=0.1)
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    xavier_init(conv_module, distribution='uniform')
    # TODO: sanity check of weight distribution, e.g. mean, std
    with pytest.raises(AssertionError):
        xavier_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    xavier_init(conv_module_no_bias)


def test_normal_init():
    """Check that normal_init sets the bias and accepts bias-free layers."""
    conv_module = nn.Conv2d(3, 16, 3)
    normal_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    normal_init(conv_module_no_bias)
    # TODO: sanity check distribution, e.g. mean, std


def test_uniform_init():
    """Check that uniform_init sets the bias and accepts bias-free layers."""
    conv_module = nn.Conv2d(3, 16, 3)
    uniform_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    uniform_init(conv_module_no_bias)


def test_kaiming_init():
    """Check kaiming_init's bias value, distribution choice and no-bias case.

    An unsupported ``distribution`` must raise ``AssertionError``.
    """
    conv_module = nn.Conv2d(3, 16, 3)
    kaiming_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    kaiming_init(conv_module, distribution='uniform')
    with pytest.raises(AssertionError):
        kaiming_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    kaiming_init(conv_module_no_bias)


def test_caffe_xavier_init():
    """Smoke test: caffe2_xavier_init runs on a conv layer with a bias."""
    conv_module = nn.Conv2d(3, 16, 3)
    caffe2_xavier_init(conv_module)


def test_bias_init_with_prob():
    """Check bias_init_with_prob(p) against -log((1 - p) / p).

    The computed value is fed to normal_init as the bias, and every bias
    element is compared with the closed-form expectation.
    """
    conv_module = nn.Conv2d(3, 16, 3)
    prior_prob = 0.1
    # Use the same ``prior_prob`` for the call and the expected value so the
    # two cannot drift apart.
    normal_init(conv_module, bias=bias_init_with_prob(prior_prob))
    # TODO: sanity check of weight distribution, e.g. mean, std
    bias = float(-np.log((1 - prior_prob) / prior_prob))
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))