mmcv/tests/test_weight_init.py

78 lines
2.8 KiB
Python

# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import pytest
import torch
from torch import nn
from mmcv.cnn import (bias_init_with_prob, caffe2_xavier_init, constant_init,
kaiming_init, normal_init, uniform_init, xavier_init)
def test_constant_init():
    """Check ``constant_init`` on convolutions with and without a bias.

    With a bias present, the weight must be filled with the requested
    constant and the bias is expected to be all zeros, since no bias value
    is passed. For the bias-free module, the weight of *that* module must be
    filled with the constant.
    """
    conv_module = nn.Conv2d(3, 16, 3)
    constant_init(conv_module, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))
    assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))

    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    constant_init(conv_module_no_bias, 0.1)
    # Verify the module that was just initialised; checking ``conv_module``
    # here again would leave the bias-free path untested.
    assert conv_module_no_bias.weight.allclose(
        torch.full_like(conv_module_no_bias.weight, 0.1))
def test_xavier_init():
    """Exercise ``xavier_init`` on convolutions with and without a bias.

    Checks that the bias is filled with the requested constant, that the
    ``'uniform'`` distribution is accepted, that an unknown distribution
    name raises ``AssertionError``, and that a bias-free module can be
    initialised without error.
    """
    conv = nn.Conv2d(3, 16, 3)
    xavier_init(conv, bias=0.1)
    expected_bias = torch.full_like(conv.bias, 0.1)
    assert torch.allclose(conv.bias, expected_bias)

    # The uniform variant only needs to run without raising.
    xavier_init(conv, distribution='uniform')
    # TODO: sanity check of weight distribution, e.g. mean, std

    # An unsupported distribution name is expected to trip an assertion.
    with pytest.raises(AssertionError):
        xavier_init(conv, distribution='student-t')

    biasless_conv = nn.Conv2d(3, 16, 3, bias=False)
    xavier_init(biasless_conv)
def test_normal_init():
    """Exercise ``normal_init`` on convolutions with and without a bias.

    Only the bias value is asserted; the weight distribution itself is not
    verified yet (see the TODO notes).
    """
    conv = nn.Conv2d(3, 16, 3)
    normal_init(conv, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    expected_bias = torch.full_like(conv.bias, 0.1)
    assert torch.allclose(conv.bias, expected_bias)

    # The bias-free module only needs to initialise without raising.
    biasless_conv = nn.Conv2d(3, 16, 3, bias=False)
    normal_init(biasless_conv)
    # TODO: sanity check distribution, e.g. mean, std
def test_uniform_init():
    """Exercise ``uniform_init`` on convolutions with and without a bias.

    Only the bias value is asserted; the weight distribution itself is not
    verified yet (see the TODO note).
    """
    conv = nn.Conv2d(3, 16, 3)
    uniform_init(conv, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    expected_bias = torch.full_like(conv.bias, 0.1)
    assert torch.allclose(conv.bias, expected_bias)

    # The bias-free module only needs to initialise without raising.
    biasless_conv = nn.Conv2d(3, 16, 3, bias=False)
    uniform_init(biasless_conv)
def test_kaiming_init():
    """Exercise ``kaiming_init`` on convolutions with and without a bias.

    Checks that the bias is filled with the requested constant, that the
    ``'uniform'`` distribution is accepted, that an unknown distribution
    name raises ``AssertionError``, and that a bias-free module can be
    initialised without error.
    """
    conv = nn.Conv2d(3, 16, 3)
    kaiming_init(conv, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    expected_bias = torch.full_like(conv.bias, 0.1)
    assert torch.allclose(conv.bias, expected_bias)

    # The uniform variant only needs to run without raising.
    kaiming_init(conv, distribution='uniform')

    # An unsupported distribution name is expected to trip an assertion.
    with pytest.raises(AssertionError):
        kaiming_init(conv, distribution='student-t')

    biasless_conv = nn.Conv2d(3, 16, 3, bias=False)
    kaiming_init(biasless_conv)
def test_caffe_xavier_init():
    """Smoke-test ``caffe2_xavier_init``: the call must simply not raise."""
    conv = nn.Conv2d(3, 16, 3)
    # No assertion on the resulting values; only successful execution counts.
    caffe2_xavier_init(conv)
def test_bias_init_with_prob():
    """Check that ``bias_init_with_prob`` yields the expected bias value.

    The value returned for ``prior_prob`` is passed to ``normal_init`` as the
    bias, and the resulting bias tensor is compared with
    ``-log((1 - prior_prob) / prior_prob)`` computed independently here.
    """
    conv_module = nn.Conv2d(3, 16, 3)
    prior_prob = 0.1
    # Use the named variable rather than a duplicated literal so that the
    # probability under test and the expected value cannot drift apart.
    normal_init(conv_module, bias=bias_init_with_prob(prior_prob))
    # TODO: sanity check of weight distribution, e.g. mean, std
    bias = float(-np.log((1 - prior_prob) / prior_prob))
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))