mirror of https://github.com/open-mmlab/mmcv.git
326 lines
12 KiB
Python
326 lines
12 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from copy import deepcopy
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmcv.cnn.rfsearch.operator import Conv2dRFSearchOp
|
|
|
|
global_config = dict(
|
|
step=0,
|
|
max_step=12,
|
|
search_interval=1,
|
|
exp_rate=0.5,
|
|
init_alphas=0.01,
|
|
mmin=1,
|
|
mmax=24,
|
|
num_branches=2,
|
|
skip_layer=['stem', 'layer1'])
|
|
|
|
|
|
# test with 3x3 conv
|
|
def test_rfsearch_operator_3x3():
|
|
conv = nn.Conv2d(
|
|
in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
|
|
operator = Conv2dRFSearchOp(conv, global_config)
|
|
x = torch.randn(1, 3, 32, 32)
|
|
|
|
# set no_grad to perform in-place operator
|
|
with torch.no_grad():
|
|
# After expand: (1, 1) (2, 2)
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (2, 2)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After estimate: (2, 2) with branch_weights of [0.5 0.5]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (2, 2)
|
|
assert operator.op_layer.dilation == (2, 2)
|
|
assert operator.op_layer.padding == (2, 2)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After expand: (1, 1) (3, 3)
|
|
operator.expand_rates()
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (3, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
operator.branch_weights[0] = 0.1
|
|
operator.branch_weights[1] = 0.4
|
|
# After estimate: (3, 3) with branch_weights of [0.2 0.8]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (3, 3)
|
|
assert operator.op_layer.dilation == (3, 3)
|
|
assert operator.op_layer.padding == (3, 3)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
|
|
# test with 5x5 conv
|
|
def test_rfsearch_operator_5x5():
|
|
conv = nn.Conv2d(
|
|
in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2)
|
|
operator = Conv2dRFSearchOp(conv, global_config)
|
|
x = torch.randn(1, 3, 32, 32)
|
|
|
|
with torch.no_grad():
|
|
# After expand: (1, 1) (2, 2)
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (2, 2)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After estimate: (2, 2) with branch_weights of [0.5 0.5]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (2, 2)
|
|
assert operator.op_layer.dilation == (2, 2)
|
|
assert operator.op_layer.padding == (4, 4)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After expand: (1, 1) (3, 3)
|
|
operator.expand_rates()
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (3, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
operator.branch_weights[0] = 0.1
|
|
operator.branch_weights[1] = 0.4
|
|
# After estimate: (3, 3) with branch_weights of [0.2 0.8]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (3, 3)
|
|
assert operator.op_layer.dilation == (3, 3)
|
|
assert operator.op_layer.padding == (6, 6)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
|
|
# test with 5x5 conv num_branches=3
|
|
def test_rfsearch_operator_5x5_branch3():
|
|
conv = nn.Conv2d(
|
|
in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2)
|
|
config = deepcopy(global_config)
|
|
config['num_branches'] = 3
|
|
operator = Conv2dRFSearchOp(conv, config)
|
|
x = torch.randn(1, 3, 32, 32)
|
|
|
|
with torch.no_grad():
|
|
# After expand: (1, 1) (2, 2)
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (2, 2)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After estimate: (2, 2) with branch_weights of [0.5 0.5]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (2, 2)
|
|
assert operator.op_layer.dilation == (2, 2)
|
|
assert operator.op_layer.padding == (4, 4)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After expand: (1, 1) (2, 2) (3, 3)
|
|
operator.expand_rates()
|
|
assert len(operator.dilation_rates) == 3
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (2, 2)
|
|
assert operator.dilation_rates[2] == (3, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
operator.branch_weights[0] = 0.1
|
|
operator.branch_weights[1] = 0.3
|
|
operator.branch_weights[2] = 0.6
|
|
# After estimate: (3, 3) with branch_weights of [0.1 0.3 0.6]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (3, 3)
|
|
assert operator.op_layer.dilation == (3, 3)
|
|
assert operator.op_layer.padding == (6, 6)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
|
|
# test with 1x5 conv
|
|
def test_rfsearch_operator_1x5():
|
|
conv = nn.Conv2d(
|
|
in_channels=3,
|
|
out_channels=3,
|
|
kernel_size=(1, 5),
|
|
stride=1,
|
|
padding=(0, 2))
|
|
operator = Conv2dRFSearchOp(conv, global_config)
|
|
x = torch.randn(1, 3, 32, 32)
|
|
|
|
# After expand: (1, 1) (1, 2)
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (1, 2)
|
|
assert torch.all(
|
|
operator.branch_weights.data == global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
with torch.no_grad():
|
|
# After estimate: (1, 2) with branch_weights of [0.5 0.5]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (1, 2)
|
|
assert operator.op_layer.dilation == (1, 2)
|
|
assert operator.op_layer.padding == (0, 4)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After expand: (1, 1) (1, 3)
|
|
operator.expand_rates()
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (1, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
operator.branch_weights[0] = 0.2
|
|
operator.branch_weights[1] = 0.8
|
|
# After estimate: (3, 3) with branch_weights of [0.2 0.8]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (1, 3)
|
|
assert operator.op_layer.dilation == (1, 3)
|
|
assert operator.op_layer.padding == (0, 6)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
|
|
# test with 5x5 conv initial_dilation=(2, 2)
|
|
def test_rfsearch_operator_5x5_d2x2():
|
|
conv = nn.Conv2d(
|
|
in_channels=3,
|
|
out_channels=3,
|
|
kernel_size=5,
|
|
stride=1,
|
|
padding=4,
|
|
dilation=(2, 2))
|
|
operator = Conv2dRFSearchOp(conv, global_config)
|
|
x = torch.randn(1, 3, 32, 32)
|
|
|
|
with torch.no_grad():
|
|
# After expand: (1, 1) (3, 3)
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (3, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After estimate: (2, 2) with branch_weights of [0.5 0.5]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (2, 2)
|
|
assert operator.op_layer.dilation == (2, 2)
|
|
assert operator.op_layer.padding == (4, 4)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After expand: (1, 1) (3, 3)
|
|
operator.expand_rates()
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (3, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
operator.branch_weights[0] = 0.8
|
|
operator.branch_weights[1] = 0.2
|
|
# After estimate: (3, 3) with branch_weights of [0.8 0.2]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.op_layer.dilation == (1, 1)
|
|
assert operator.op_layer.padding == (2, 2)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
|
|
# test with 5x5 conv initial_dilation=(1, 2)
|
|
def test_rfsearch_operator_5x5_d1x2():
|
|
conv = nn.Conv2d(
|
|
in_channels=3,
|
|
out_channels=3,
|
|
kernel_size=5,
|
|
stride=1,
|
|
padding=(2, 4),
|
|
dilation=(1, 2))
|
|
operator = Conv2dRFSearchOp(conv, global_config)
|
|
x = torch.randn(1, 3, 32, 32)
|
|
|
|
with torch.no_grad():
|
|
# After expand: (1, 1) (2, 3)
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (2, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After estimate: (2, 2) with branch_weights of [0.5 0.5]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (2, 2)
|
|
assert operator.op_layer.dilation == (2, 2)
|
|
assert operator.op_layer.padding == (4, 4)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
# After expand: (1, 1) (3, 3)
|
|
operator.expand_rates()
|
|
assert len(operator.dilation_rates) == 2
|
|
assert operator.dilation_rates[0] == (1, 1)
|
|
assert operator.dilation_rates[1] == (3, 3)
|
|
assert torch.all(operator.branch_weights.data ==
|
|
global_config['init_alphas']).item()
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|
|
|
|
operator.branch_weights[0] = 0.1
|
|
operator.branch_weights[1] = 0.8
|
|
# After estimate: (3, 3) with branch_weights of [0.1 0.8]
|
|
operator.estimate_rates()
|
|
assert len(operator.dilation_rates) == 1
|
|
assert operator.dilation_rates[0] == (3, 3)
|
|
assert operator.op_layer.dilation == (3, 3)
|
|
assert operator.op_layer.padding == (6, 6)
|
|
# test forward
|
|
assert operator(x).shape == (1, 3, 32, 32)
|