# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import pytest
import torch
from torch import nn

from mmcv.cnn import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                      kaiming_init, normal_init, uniform_init, xavier_init)


def test_constant_init():
    conv_module = nn.Conv2d(3, 16, 3)
    constant_init(conv_module, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))
    assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    constant_init(conv_module_no_bias, 0.1)
    assert conv_module_no_bias.weight.allclose(
        torch.full_like(conv_module_no_bias.weight, 0.1))


def test_xavier_init():
    conv_module = nn.Conv2d(3, 16, 3)
    xavier_init(conv_module, bias=0.1)
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    xavier_init(conv_module, distribution='uniform')
    # TODO: sanity check of weight distribution, e.g. mean, std
    with pytest.raises(AssertionError):
        xavier_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    xavier_init(conv_module_no_bias)


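def test_xavier_init_uniform_bound():
    # A minimal sketch of the TODO sanity check above, not part of the
    # original test. It assumes xavier_init(distribution='uniform') wraps
    # torch.nn.init.xavier_uniform_ with gain=1, so every weight must lie
    # within the Glorot bound sqrt(6 / (fan_in + fan_out)).
    conv_module = nn.Conv2d(3, 16, 3)
    xavier_init(conv_module, distribution='uniform')
    receptive_field = conv_module.kernel_size[0] * conv_module.kernel_size[1]
    fan_in = conv_module.in_channels * receptive_field
    fan_out = conv_module.out_channels * receptive_field
    bound = (6.0 / (fan_in + fan_out)) ** 0.5
    assert conv_module.weight.abs().max().item() <= bound + 1e-6

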
def test_normal_init():
    conv_module = nn.Conv2d(3, 16, 3)
    normal_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    normal_init(conv_module_no_bias)
    # TODO: sanity check distribution, e.g. mean, std


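def test_normal_init_statistics():
    # A minimal sketch of the TODO sanity checks above, not part of the
    # original test. It assumes normal_init defaults to mean=0, std=1; the
    # tolerances are deliberately loose (several standard errors for the
    # 16 * 3 * 3 * 3 = 432 weight samples) to keep the check non-flaky.
    conv_module = nn.Conv2d(3, 16, 3)
    normal_init(conv_module)
    assert abs(conv_module.weight.mean().item()) < 0.3
    assert abs(conv_module.weight.std().item() - 1.0) < 0.3

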
def test_uniform_init():
    conv_module = nn.Conv2d(3, 16, 3)
    uniform_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    uniform_init(conv_module_no_bias)


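def test_uniform_init_range():
    # A minimal sketch of the TODO sanity check above, not part of the
    # original test. It assumes uniform_init draws weights from U(a, b) with
    # the defaults a=0, b=1, so every weight should land in [0, 1].
    conv_module = nn.Conv2d(3, 16, 3)
    uniform_init(conv_module)
    assert conv_module.weight.min().item() >= 0.0
    assert conv_module.weight.max().item() <= 1.0

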
def test_kaiming_init():
    conv_module = nn.Conv2d(3, 16, 3)
    kaiming_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    kaiming_init(conv_module, distribution='uniform')
    with pytest.raises(AssertionError):
        kaiming_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    kaiming_init(conv_module_no_bias)


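def test_kaiming_init_statistics():
    # A minimal sketch of the TODO sanity check above, not part of the
    # original test. It assumes kaiming_init defaults to a normal
    # distribution with mode='fan_out' and nonlinearity='relu', i.e.
    # std = sqrt(2 / fan_out) with fan_out = 16 * 3 * 3 = 144; the tolerance
    # is deliberately loose to keep the check non-flaky.
    conv_module = nn.Conv2d(3, 16, 3)
    kaiming_init(conv_module)
    expected_std = (2.0 / 144) ** 0.5
    assert abs(conv_module.weight.std().item() - expected_std) < 0.05

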
def test_caffe2_xavier_init():
    conv_module = nn.Conv2d(3, 16, 3)
    caffe2_xavier_init(conv_module)


def test_bias_init_with_prob():
    conv_module = nn.Conv2d(3, 16, 3)
    prior_prob = 0.1
    normal_init(conv_module, bias=bias_init_with_prob(prior_prob))
    # TODO: sanity check of weight distribution, e.g. mean, std
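    # bias_init_with_prob(p) returns the inverse sigmoid -log((1 - p) / p),
    # so applying a sigmoid to the initialized bias recovers the prior
    # probability p; the test recomputes that value directly below.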
    bias = float(-np.log((1 - prior_prob) / prior_prob))
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))