# mmrazor/tests/test_models/test_mutables/test_diffop.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import pytest
import torch
import torch.nn as nn

from mmrazor.models import *  # noqa:F403,F401
from mmrazor.registry import MODELS
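
# Register plain torch layers in the MODELS registry so that the candidate
# configs below can reference them by their `type` field.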
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)


class TestDiffOP(TestCase):
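    """Unit tests for ``DiffMutableOP``, the differentiable mutable op used
    by gradient-based NAS algorithms (e.g. DARTS) in mmrazor."""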

    def test_forward_arch_param(self):
        op_cfg = dict(
            type='DiffMutableOP',
            candidates=dict(
                torch_conv2d_3x3=dict(
                    type='torchConv2d',
                    kernel_size=3,
                    padding=1,
                ),
                torch_conv2d_5x5=dict(
                    type='torchConv2d',
                    kernel_size=5,
                    padding=2,
                ),
                torch_conv2d_7x7=dict(
                    type='torchConv2d',
                    kernel_size=7,
                    padding=3,
                ),
            ),
            module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
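
        # every candidate is built with the shared `module_kwargs`
        # (in_channels / out_channels / stride) defined above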
        op = MODELS.build(op_cfg)
        input = torch.randn(4, 32, 64, 64)

        arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
        output = op.forward_arch_param(input, arch_param=arch_param)
        assert output is not None

        # test forward with a uniform (all-ones) arch_param
        arch_param = nn.Parameter(torch.ones(op.num_choices))
        output = op.forward_arch_param(input, arch_param=arch_param)
        assert output is not None

    def test_forward_fixed(self):
        op_cfg = dict(
            type='DiffMutableOP',
            candidates=dict(
                torch_conv2d_3x3=dict(
                    type='torchConv2d',
                    kernel_size=3,
                ),
                torch_conv2d_5x5=dict(
                    type='torchConv2d',
                    kernel_size=5,
                ),
                torch_conv2d_7x7=dict(
                    type='torchConv2d',
                    kernel_size=7,
                ),
            ),
            module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
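
        # note: these candidates use no padding, so the chosen conv shrinks
        # the spatial size; the test only checks that a fixed forward runs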
        op = MODELS.build(op_cfg)
        input = torch.randn(4, 32, 64, 64)

        op.fix_chosen('torch_conv2d_7x7')
        output = op.forward_fixed(input)
        assert output is not None
        assert op.is_fixed is True

    def test_forward(self):
        op_cfg = dict(
            type='DiffMutableOP',
            candidates=dict(
                torch_conv2d_3x3=dict(
                    type='torchConv2d',
                    kernel_size=3,
                    padding=1,
                ),
                torch_conv2d_5x5=dict(
                    type='torchConv2d',
                    kernel_size=5,
                    padding=2,
                ),
                torch_conv2d_7x7=dict(
                    type='torchConv2d',
                    kernel_size=7,
                    padding=3,
                ),
            ),
            module_kwargs=dict(in_channels=32, out_channels=32, stride=1))

        op = MODELS.build(op_cfg)
        input = torch.randn(4, 32, 64, 64)

        # test set_forward_args
        arch_param = nn.Parameter(torch.randn(len(op_cfg['candidates'])))
        op.set_forward_args(arch_param=arch_param)
        output = op.forward(input)
        assert output is not None

        # test dump_chosen
        with pytest.raises(AssertionError):
            op.dump_chosen()

        # test forward when is_fixed is True
        op.fix_chosen('torch_conv2d_7x7')
        output = op.forward(input)
        assert output is not None

    def test_property(self):
        op_cfg = dict(
            type='DiffMutableOP',
            candidates=dict(
                torch_conv2d_3x3=dict(
                    type='torchConv2d',
                    kernel_size=3,
                    padding=1,
                ),
                torch_conv2d_5x5=dict(
                    type='torchConv2d',
                    kernel_size=5,
                    padding=2,
                ),
                torch_conv2d_7x7=dict(
                    type='torchConv2d',
                    kernel_size=7,
                    padding=3,
                ),
            ),
            module_kwargs=dict(in_channels=32, out_channels=32, stride=1))

        op = MODELS.build(op_cfg)
        assert len(op.choices) == 3

        # test is_fixed property
        assert op.is_fixed is False

        # test is_fixed setting
        op.fix_chosen('torch_conv2d_5x5')
        with pytest.raises(AttributeError):
            op.is_fixed = True

        # test fix_chosen when is_fixed is True
        with pytest.raises(AttributeError):
            op.fix_chosen('torch_conv2d_3x3')

    def test_module_kwargs(self):
        op_cfg = dict(
            type='DiffMutableOP',
            candidates=dict(
                torch_conv2d_3x3=dict(
                    type='torchConv2d',
                    kernel_size=3,
                    in_channels=32,
                    out_channels=32,
                    stride=1,
                ),
                torch_conv2d_5x5=dict(
                    type='torchConv2d',
                    kernel_size=5,
                    in_channels=32,
                    out_channels=32,
                    stride=1,
                ),
                torch_conv2d_7x7=dict(
                    type='torchConv2d',
                    kernel_size=7,
                    in_channels=32,
                    out_channels=32,
                    stride=1,
                ),
                torch_maxpool_3x3=dict(
                    type='torchMaxPool2d',
                    kernel_size=3,
                    stride=1,
                ),
                torch_avgpool_3x3=dict(
                    type='torchAvgPool2d',
                    kernel_size=3,
                    stride=1,
                ),
            ),
        )
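
        # each candidate carries its own kwargs here instead of a shared
        # `module_kwargs`, so pooling candidates (which take no channel
        # arguments) can sit alongside the conv candidates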
        op = MODELS.build(op_cfg)
        input = torch.randn(4, 32, 64, 64)

        op.fix_chosen('torch_avgpool_3x3')
        output = op.forward(input)
        assert output is not None