201 lines
6.1 KiB
Python
201 lines
6.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmrazor.models import * # noqa:F403,F401
|
|
from mmrazor.registry import MODELS
|
|
|
|
# Register plain torch layers in ``MODELS`` under the names that the op
# configs in this file use as their ``type``.
for _alias, _layer in (
        ('torchConv2d', nn.Conv2d),
        ('torchMaxPool2d', nn.MaxPool2d),
        ('torchAvgPool2d', nn.AvgPool2d),
):
    MODELS.register_module(name=_alias, module=_layer, force=True)
|
|
|
|
|
|
class TestDiffOP(TestCase):
    """Smoke tests for ``DiffMutableOP`` built through the ``MODELS`` registry.

    Each test builds the op from a config dict and exercises one part of
    its interface: ``forward_arch_param``, ``forward_fixed``, ``forward``,
    the ``choices`` / ``is_fixed`` properties, and candidates that carry
    their own module kwargs.
    """

    @staticmethod
    def _conv_op_cfg(with_padding=True):
        """Return a ``DiffMutableOP`` config with three conv candidates.

        Args:
            with_padding (bool): If True, every candidate gets
                ``padding=kernel_size // 2`` (1, 2 and 3 for the 3x3, 5x5
                and 7x7 kernels). If False, no ``padding`` key is set.

        Returns:
            dict: A freshly created config whose candidates are named
            ``torch_conv2d_3x3``, ``torch_conv2d_5x5`` and
            ``torch_conv2d_7x7`` and whose ``module_kwargs`` are
            ``in_channels=32, out_channels=32, stride=1``.
        """
        candidates = dict()
        for kernel_size in (3, 5, 7):
            candidate = dict(type='torchConv2d', kernel_size=kernel_size)
            if with_padding:
                candidate['padding'] = kernel_size // 2
            name = 'torch_conv2d_{0}x{0}'.format(kernel_size)
            candidates[name] = candidate

        return dict(
            type='DiffMutableOP',
            candidates=candidates,
            module_kwargs=dict(in_channels=32, out_channels=32, stride=1))

    def test_forward_arch_param(self):
        """``forward_arch_param`` yields an output for explicit arch params."""
        op_cfg = self._conv_op_cfg()

        op = MODELS.build(op_cfg)
        inputs = torch.randn(4, 32, 64, 64)

        # random architecture parameters, one entry per candidate
        arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
        output = op.forward_arch_param(inputs, arch_param=arch_param)
        assert output is not None

        # test with an all-ones arch_param (identical value for every choice)
        # NOTE(review): this comment used to read "some element of arch_param
        # is 0", which the all-ones tensor below does not exercise -- confirm
        # the intended input.
        arch_param = nn.Parameter(torch.ones(op.num_choices))
        output = op.forward_arch_param(inputs, arch_param=arch_param)
        assert output is not None

    def test_forward_fixed(self):
        """``forward_fixed`` works once a choice has been fixed."""
        # candidates are built without padding in this test
        op_cfg = self._conv_op_cfg(with_padding=False)

        op = MODELS.build(op_cfg)
        inputs = torch.randn(4, 32, 64, 64)

        op.fix_chosen('torch_conv2d_7x7')
        output = op.forward_fixed(inputs)

        assert output is not None
        assert op.is_fixed is True

    def test_forward(self):
        """``forward`` works both before and after the op is fixed."""
        op_cfg = self._conv_op_cfg()

        op = MODELS.build(op_cfg)
        inputs = torch.randn(4, 32, 64, 64)

        # test set_forward_args
        arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
        op.set_forward_args(arch_param=arch_param)
        output = op.forward(inputs)
        assert output is not None

        # test dump_chosen: expected to fail while the op is not fixed
        with pytest.raises(AssertionError):
            op.dump_chosen()

        # test forward when is_fixed is True
        op.fix_chosen('torch_conv2d_7x7')
        output = op.forward(inputs)
        # previously this result was computed but never checked
        assert output is not None

    def test_property(self):
        """Check ``choices`` and the ``is_fixed`` property."""
        op_cfg = self._conv_op_cfg()

        op = MODELS.build(op_cfg)

        assert len(op.choices) == 3

        # test is_fixed property
        assert op.is_fixed is False

        # test is_fixed setting
        op.fix_chosen('torch_conv2d_5x5')

        # assigning to is_fixed directly is expected to be rejected
        with pytest.raises(AttributeError):
            op.is_fixed = True

        # test fix choice when is_fixed is True
        with pytest.raises(AttributeError):
            op.fix_chosen('torch_conv2d_3x3')

    def test_module_kwargs(self):
        """Candidates may carry their own kwargs; no ``module_kwargs`` given."""
        op_cfg = dict(
            type='DiffMutableOP',
            candidates=dict(
                torch_conv2d_3x3=dict(
                    type='torchConv2d',
                    kernel_size=3,
                    in_channels=32,
                    out_channels=32,
                    stride=1,
                ),
                torch_conv2d_5x5=dict(
                    type='torchConv2d',
                    kernel_size=5,
                    in_channels=32,
                    out_channels=32,
                    stride=1,
                ),
                torch_conv2d_7x7=dict(
                    type='torchConv2d',
                    kernel_size=7,
                    in_channels=32,
                    out_channels=32,
                    stride=1,
                ),
                torch_maxpool_3x3=dict(
                    type='torchMaxPool2d',
                    kernel_size=3,
                    stride=1,
                ),
                torch_avgpool_3x3=dict(
                    type='torchAvgPool2d',
                    kernel_size=3,
                    stride=1,
                ),
            ),
        )
        op = MODELS.build(op_cfg)
        inputs = torch.randn(4, 32, 64, 64)

        op.fix_chosen('torch_avgpool_3x3')
        output = op.forward(inputs)
        assert output is not None