mmcv/tests/test_cnn/test_rfsearch/test_operator.py

326 lines
12 KiB
Python
Raw Normal View History

[Feature] Support receptive field search of CNN models (#2056) * support rfsearch * add labs for rfsearch * format * format * add docstring and type hints * clean code Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * rm unused func * update code * update code * update code * update details * fix details * support asymmetric kernel * support asymmetric kernel * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * add unit tests for rfsearch * set device for Conv2dRFSearchOp * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * remove unused function search_estimate_only * move unit tests * Update tests/test_cnn/test_rfsearch/test_operator.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/cnn/rfsearch/operator.py Co-authored-by: Yue Zhou <592267829@qq.com> * change logger * Update mmcv/cnn/rfsearch/operator.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: lzyhha <819814373@qq.com> Co-authored-by: Zhongyu Li <44114862+lzyhha@users.noreply.github.com> Co-authored-by: Yue Zhou <592267829@qq.com> [Fix] Fix skip_layer for RF-Next (#2489) * judge skip_layer by fullname * lint * skip_layer first * update unit test
2022-11-22 19:15:55 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from mmcv.cnn.rfsearch.operator import Conv2dRFSearchOp
# Baseline search settings; the tests below pass this dict (or a modified
# copy of it) as the config argument of Conv2dRFSearchOp.
global_config = {
    'step': 0,
    'max_step': 12,
    'search_interval': 1,
    'exp_rate': 0.5,
    'init_alphas': 0.01,
    'mmin': 1,
    'mmax': 24,
    'num_branches': 2,
    'skip_layer': ['stem', 'layer1'],
}
# test with 3x3 conv
def test_rfsearch_operator_3x3():
    """Drive a 3x3 conv through the expand/estimate cycle of the search op.

    Checks the candidate dilation rates, the wrapped layer's dilation and
    padding, and that each forward pass keeps the (1, 3, 32, 32) shape.
    """
    base_conv = nn.Conv2d(
        in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
    rf_op = Conv2dRFSearchOp(base_conv, global_config)
    inputs = torch.randn(1, 3, 32, 32)
    expected_shape = (1, 3, 32, 32)

    def weights_at_init():
        # True when every branch weight equals the configured init value.
        return torch.all(
            rf_op.branch_weights.data == global_config['init_alphas']).item()

    # no_grad lets the test overwrite branch_weights in place further down.
    with torch.no_grad():
        # Fresh operator: candidates (1, 1) and (2, 2), untouched weights.
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (2, 2)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # First estimate with equal weights is expected to settle on (2, 2).
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (2, 2)
        assert rf_op.op_layer.dilation == (2, 2)
        assert rf_op.op_layer.padding == (2, 2)
        assert rf_op(inputs).shape == expected_shape

        # Expanding around (2, 2) should offer (1, 1) and (3, 3) and leave
        # the branch weights at their init value.
        rf_op.expand_rates()
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (3, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # Weight the larger rate more heavily (0.1 vs 0.4); the estimate
        # is then expected to pick (3, 3).
        rf_op.branch_weights[0] = 0.1
        rf_op.branch_weights[1] = 0.4
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (3, 3)
        assert rf_op.op_layer.dilation == (3, 3)
        assert rf_op.op_layer.padding == (3, 3)
        assert rf_op(inputs).shape == expected_shape
# test with 5x5 conv
def test_rfsearch_operator_5x5():
    """Drive a 5x5 conv through the expand/estimate cycle of the search op.

    Same scenario as the 3x3 case, but the expected padding values are
    those asserted for a 5x5 kernel: (4, 4) at dilation 2 and (6, 6) at
    dilation 3.
    """
    base_conv = nn.Conv2d(
        in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2)
    rf_op = Conv2dRFSearchOp(base_conv, global_config)
    inputs = torch.randn(1, 3, 32, 32)
    expected_shape = (1, 3, 32, 32)

    def weights_at_init():
        # True when every branch weight equals the configured init value.
        return torch.all(
            rf_op.branch_weights.data == global_config['init_alphas']).item()

    # no_grad lets the test overwrite branch_weights in place further down.
    with torch.no_grad():
        # Fresh operator: candidates (1, 1) and (2, 2), untouched weights.
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (2, 2)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # First estimate with equal weights is expected to settle on (2, 2).
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (2, 2)
        assert rf_op.op_layer.dilation == (2, 2)
        assert rf_op.op_layer.padding == (4, 4)
        assert rf_op(inputs).shape == expected_shape

        # Expanding around (2, 2) should offer (1, 1) and (3, 3) and leave
        # the branch weights at their init value.
        rf_op.expand_rates()
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (3, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # Weight the larger rate more heavily (0.1 vs 0.4); the estimate
        # is then expected to pick (3, 3).
        rf_op.branch_weights[0] = 0.1
        rf_op.branch_weights[1] = 0.4
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (3, 3)
        assert rf_op.op_layer.dilation == (3, 3)
        assert rf_op.op_layer.padding == (6, 6)
        assert rf_op(inputs).shape == expected_shape
# test with 5x5 conv num_branches=3
def test_rfsearch_operator_5x5_branch3():
    """Exercise the search op on a 5x5 conv with ``num_branches`` set to 3.

    The operator is built from a deep copy of ``global_config`` so the
    override does not leak into other tests. All expectations, including
    the initial branch weight, are read from that same copy so the asserts
    always match the config the operator was actually given.
    """
    base_conv = nn.Conv2d(
        in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2)
    config = deepcopy(global_config)
    config['num_branches'] = 3
    rf_op = Conv2dRFSearchOp(base_conv, config)
    inputs = torch.randn(1, 3, 32, 32)
    expected_shape = (1, 3, 32, 32)

    def weights_at_init():
        # Compare against the config passed to the operator, not the
        # module-level default it was copied from.
        return torch.all(
            rf_op.branch_weights.data == config['init_alphas']).item()

    # no_grad lets the test overwrite branch_weights in place further down.
    with torch.no_grad():
        # Fresh operator: only two candidates, (1, 1) and (2, 2), are
        # expected even though three branches are configured.
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (2, 2)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # First estimate with equal weights is expected to settle on (2, 2).
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (2, 2)
        assert rf_op.op_layer.dilation == (2, 2)
        assert rf_op.op_layer.padding == (4, 4)
        assert rf_op(inputs).shape == expected_shape

        # With three branches the expansion around (2, 2) should keep the
        # current rate as well: (1, 1), (2, 2), (3, 3).
        rf_op.expand_rates()
        assert len(rf_op.dilation_rates) == 3
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (2, 2)
        assert rf_op.dilation_rates[2] == (3, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # Weights 0.1 / 0.3 / 0.6 favour the largest rate; the estimate is
        # expected to pick (3, 3).
        rf_op.branch_weights[0] = 0.1
        rf_op.branch_weights[1] = 0.3
        rf_op.branch_weights[2] = 0.6
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (3, 3)
        assert rf_op.op_layer.dilation == (3, 3)
        assert rf_op.op_layer.padding == (6, 6)
        assert rf_op(inputs).shape == expected_shape
# test with 1x5 conv
def test_rfsearch_operator_1x5():
    """Exercise the search op on an asymmetric 1x5 conv.

    Only the second (width) dilation component is expected to change; the
    first stays at 1 and the first padding component stays at 0 throughout.
    """
    base_conv = nn.Conv2d(
        in_channels=3,
        out_channels=3,
        kernel_size=(1, 5),
        stride=1,
        padding=(0, 2))
    rf_op = Conv2dRFSearchOp(base_conv, global_config)
    inputs = torch.randn(1, 3, 32, 32)
    expected_shape = (1, 3, 32, 32)

    def weights_at_init():
        # True when every branch weight equals the configured init value.
        return torch.all(
            rf_op.branch_weights.data == global_config['init_alphas']).item()

    # These initial checks, including the first forward pass, run with
    # gradient tracking left enabled.
    # Fresh operator: candidates (1, 1) and (1, 2), untouched weights.
    assert len(rf_op.dilation_rates) == 2
    assert rf_op.dilation_rates[0] == (1, 1)
    assert rf_op.dilation_rates[1] == (1, 2)
    assert weights_at_init()
    assert rf_op(inputs).shape == expected_shape

    # no_grad lets the test overwrite branch_weights in place further down.
    with torch.no_grad():
        # First estimate with equal weights is expected to settle on (1, 2).
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (1, 2)
        assert rf_op.op_layer.dilation == (1, 2)
        assert rf_op.op_layer.padding == (0, 4)
        assert rf_op(inputs).shape == expected_shape

        # Expanding around (1, 2) should offer (1, 1) and (1, 3) and leave
        # the branch weights at their init value.
        rf_op.expand_rates()
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (1, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # Weights 0.2 / 0.8 favour the larger rate; the estimate is
        # expected to pick (1, 3).
        rf_op.branch_weights[0] = 0.2
        rf_op.branch_weights[1] = 0.8
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (1, 3)
        assert rf_op.op_layer.dilation == (1, 3)
        assert rf_op.op_layer.padding == (0, 6)
        assert rf_op(inputs).shape == expected_shape
# test with 5x5 conv initial_dilation=(2, 2)
def test_rfsearch_operator_5x5_d2x2():
    """Exercise the search op on a 5x5 conv that starts at dilation (2, 2).

    The final step weights the smaller candidate more heavily, so this test
    also covers the estimate moving down to (1, 1).
    """
    base_conv = nn.Conv2d(
        in_channels=3,
        out_channels=3,
        kernel_size=5,
        stride=1,
        padding=4,
        dilation=(2, 2))
    rf_op = Conv2dRFSearchOp(base_conv, global_config)
    inputs = torch.randn(1, 3, 32, 32)
    expected_shape = (1, 3, 32, 32)

    def weights_at_init():
        # True when every branch weight equals the configured init value.
        return torch.all(
            rf_op.branch_weights.data == global_config['init_alphas']).item()

    # no_grad lets the test overwrite branch_weights in place further down.
    with torch.no_grad():
        # Fresh operator built from dilation (2, 2): candidates (1, 1) and
        # (3, 3), untouched weights.
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (3, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # First estimate with equal weights is expected to land on (2, 2).
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (2, 2)
        assert rf_op.op_layer.dilation == (2, 2)
        assert rf_op.op_layer.padding == (4, 4)
        assert rf_op(inputs).shape == expected_shape

        # Expanding around (2, 2) should offer (1, 1) and (3, 3) and leave
        # the branch weights at their init value.
        rf_op.expand_rates()
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (3, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # Weights 0.8 / 0.2 favour the smaller rate; the estimate is
        # expected to drop to (1, 1).
        rf_op.branch_weights[0] = 0.8
        rf_op.branch_weights[1] = 0.2
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.op_layer.dilation == (1, 1)
        assert rf_op.op_layer.padding == (2, 2)
        assert rf_op(inputs).shape == expected_shape
# test with 5x5 conv initial_dilation=(1, 2)
def test_rfsearch_operator_5x5_d1x2():
    """Exercise the search op on a 5x5 conv with asymmetric dilation (1, 2).

    The conv starts with padding (2, 4); after the first estimate the test
    expects a symmetric rate of (2, 2) and symmetric rates from then on.
    """
    base_conv = nn.Conv2d(
        in_channels=3,
        out_channels=3,
        kernel_size=5,
        stride=1,
        padding=(2, 4),
        dilation=(1, 2))
    rf_op = Conv2dRFSearchOp(base_conv, global_config)
    inputs = torch.randn(1, 3, 32, 32)
    expected_shape = (1, 3, 32, 32)

    def weights_at_init():
        # True when every branch weight equals the configured init value.
        return torch.all(
            rf_op.branch_weights.data == global_config['init_alphas']).item()

    # no_grad lets the test overwrite branch_weights in place further down.
    with torch.no_grad():
        # Fresh operator built from dilation (1, 2): candidates (1, 1) and
        # (2, 3), untouched weights.
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (2, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # First estimate with equal weights is expected to give (2, 2).
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (2, 2)
        assert rf_op.op_layer.dilation == (2, 2)
        assert rf_op.op_layer.padding == (4, 4)
        assert rf_op(inputs).shape == expected_shape

        # Expanding around (2, 2) should offer (1, 1) and (3, 3) and leave
        # the branch weights at their init value.
        rf_op.expand_rates()
        assert len(rf_op.dilation_rates) == 2
        assert rf_op.dilation_rates[0] == (1, 1)
        assert rf_op.dilation_rates[1] == (3, 3)
        assert weights_at_init()
        assert rf_op(inputs).shape == expected_shape

        # Weights 0.1 / 0.8 favour the larger rate; the estimate is
        # expected to pick (3, 3).
        rf_op.branch_weights[0] = 0.1
        rf_op.branch_weights[1] = 0.8
        rf_op.estimate_rates()
        assert len(rf_op.dilation_rates) == 1
        assert rf_op.dilation_rates[0] == (3, 3)
        assert rf_op.op_layer.dilation == (3, 3)
        assert rf_op.op_layer.padding == (6, 6)
        assert rf_op(inputs).shape == expected_shape