mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
201 lines
6.0 KiB
Python
201 lines
6.0 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
from unittest import TestCase
|
||
|
|
||
|
import pytest
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from mmrazor.models import * # noqa:F403,F401
|
||
|
from mmrazor.registry import MODELS
|
||
|
|
||
|
MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True)
|
||
|
MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True)
|
||
|
MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True)
|
||
|
|
||
|
|
||
|
class TestDiffOP(TestCase):
|
||
|
|
||
|
def test_forward_arch_param(self):
|
||
|
op_cfg = dict(
|
||
|
type='DiffOP',
|
||
|
candidate_ops=dict(
|
||
|
torch_conv2d_3x3=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=3,
|
||
|
padding=1,
|
||
|
),
|
||
|
torch_conv2d_5x5=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=5,
|
||
|
padding=2,
|
||
|
),
|
||
|
torch_conv2d_7x7=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=7,
|
||
|
padding=3,
|
||
|
),
|
||
|
),
|
||
|
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||
|
|
||
|
op = MODELS.build(op_cfg)
|
||
|
input = torch.randn(4, 32, 64, 64)
|
||
|
|
||
|
arch_param = op.build_arch_param()
|
||
|
output = op.forward_arch_param(input, arch_param=arch_param)
|
||
|
assert output is not None
|
||
|
|
||
|
output = op.forward_arch_param(input, arch_param=None)
|
||
|
assert output is not None
|
||
|
|
||
|
# test when some element of arch_param is 0
|
||
|
arch_param = op.build_arch_param()
|
||
|
arch_param = nn.Parameter(torch.ones(op.num_choices))
|
||
|
output = op.forward_arch_param(input, arch_param=arch_param)
|
||
|
assert output is not None
|
||
|
|
||
|
def test_forward_fixed(self):
|
||
|
op_cfg = dict(
|
||
|
type='DiffOP',
|
||
|
candidate_ops=dict(
|
||
|
torch_conv2d_3x3=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=3,
|
||
|
),
|
||
|
torch_conv2d_5x5=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=5,
|
||
|
),
|
||
|
torch_conv2d_7x7=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=7,
|
||
|
),
|
||
|
),
|
||
|
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||
|
|
||
|
op = MODELS.build(op_cfg)
|
||
|
input = torch.randn(4, 32, 64, 64)
|
||
|
|
||
|
op.fix_chosen('torch_conv2d_7x7')
|
||
|
output = op.forward_fixed(input)
|
||
|
|
||
|
assert output is not None
|
||
|
assert op.is_fixed is True
|
||
|
|
||
|
def test_forward(self):
|
||
|
op_cfg = dict(
|
||
|
type='DiffOP',
|
||
|
candidate_ops=dict(
|
||
|
torch_conv2d_3x3=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=3,
|
||
|
padding=1,
|
||
|
),
|
||
|
torch_conv2d_5x5=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=5,
|
||
|
padding=2,
|
||
|
),
|
||
|
torch_conv2d_7x7=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=7,
|
||
|
padding=3,
|
||
|
),
|
||
|
),
|
||
|
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||
|
|
||
|
op = MODELS.build(op_cfg)
|
||
|
input = torch.randn(4, 32, 64, 64)
|
||
|
|
||
|
# test set_forward_args
|
||
|
arch_param = op.build_arch_param()
|
||
|
op.set_forward_args(arch_param=arch_param)
|
||
|
output = op.forward(input)
|
||
|
assert output is not None
|
||
|
|
||
|
# test forward when is_fixed is True
|
||
|
op.fix_chosen('torch_conv2d_7x7')
|
||
|
output = op.forward(input)
|
||
|
|
||
|
def test_property(self):
|
||
|
op_cfg = dict(
|
||
|
type='DiffOP',
|
||
|
candidate_ops=dict(
|
||
|
torch_conv2d_3x3=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=3,
|
||
|
padding=1,
|
||
|
),
|
||
|
torch_conv2d_5x5=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=5,
|
||
|
padding=2,
|
||
|
),
|
||
|
torch_conv2d_7x7=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=7,
|
||
|
padding=3,
|
||
|
),
|
||
|
),
|
||
|
module_kwargs=dict(in_channels=32, out_channels=32, stride=1))
|
||
|
|
||
|
op = MODELS.build(op_cfg)
|
||
|
|
||
|
assert len(op.choices) == 3
|
||
|
|
||
|
# test is_fixed propty
|
||
|
assert op.is_fixed is False
|
||
|
|
||
|
# test is_fixed setting
|
||
|
op.fix_chosen('torch_conv2d_5x5')
|
||
|
|
||
|
with pytest.raises(AttributeError):
|
||
|
op.is_fixed = True
|
||
|
|
||
|
# test fix choice when is_fixed is True
|
||
|
with pytest.raises(AttributeError):
|
||
|
op.fix_chosen('torch_conv2d_3x3')
|
||
|
|
||
|
def test_module_kwargs(self):
|
||
|
op_cfg = dict(
|
||
|
type='DiffOP',
|
||
|
candidate_ops=dict(
|
||
|
torch_conv2d_3x3=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=3,
|
||
|
in_channels=32,
|
||
|
out_channels=32,
|
||
|
stride=1,
|
||
|
),
|
||
|
torch_conv2d_5x5=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=5,
|
||
|
in_channels=32,
|
||
|
out_channels=32,
|
||
|
stride=1,
|
||
|
),
|
||
|
torch_conv2d_7x7=dict(
|
||
|
type='torchConv2d',
|
||
|
kernel_size=7,
|
||
|
in_channels=32,
|
||
|
out_channels=32,
|
||
|
stride=1,
|
||
|
),
|
||
|
torch_maxpool_3x3=dict(
|
||
|
type='torchMaxPool2d',
|
||
|
kernel_size=3,
|
||
|
stride=1,
|
||
|
),
|
||
|
torch_avgpool_3x3=dict(
|
||
|
type='torchAvgPool2d',
|
||
|
kernel_size=3,
|
||
|
stride=1,
|
||
|
),
|
||
|
),
|
||
|
)
|
||
|
op = MODELS.build(op_cfg)
|
||
|
input = torch.randn(4, 32, 64, 64)
|
||
|
|
||
|
op.fix_chosen('torch_avgpool_3x3')
|
||
|
output = op.forward(input)
|
||
|
assert output is not None
|