mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] get_model_complexity_info() supports multiple inputs (#1065)
This commit is contained in:
parent
43165160e6
commit
fafb476e58
@ -12,6 +12,7 @@ from rich.console import Console
|
|||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from mmengine.utils import is_tuple_of
|
||||||
from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
|
from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
|
||||||
parameter_count)
|
parameter_count)
|
||||||
|
|
||||||
@ -675,19 +676,38 @@ def complexity_stats_table(
|
|||||||
|
|
||||||
def get_model_complexity_info(
|
def get_model_complexity_info(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
input_shape: Optional[tuple] = None,
|
input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
|
||||||
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
|
None] = None,
|
||||||
|
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
|
||||||
|
None] = None,
|
||||||
show_table: bool = True,
|
show_table: bool = True,
|
||||||
show_arch: bool = True,
|
show_arch: bool = True,
|
||||||
):
|
):
|
||||||
"""Interface to get the complexity of a model.
|
"""Interface to get the complexity of a model.
|
||||||
|
|
||||||
|
The parameter `inputs` are fed to the forward method of model.
|
||||||
|
If `inputs` is not specified, the `input_shape` is required and
|
||||||
|
it will be used to construct the dummy input fed to model.
|
||||||
|
If the forward of model requires two or more inputs, the `inputs`
|
||||||
|
should be a tuple of tensor or the `input_shape` should be a tuple
|
||||||
|
of tuple which each element will be constructed into a dumpy input.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # the forward of model accepts only one input
|
||||||
|
>>> input_shape = (3, 224, 224)
|
||||||
|
>>> get_model_complexity_info(model, input_shape=input_shape)
|
||||||
|
>>> # the forward of model accepts two or more inputs
|
||||||
|
>>> input_shape = ((3, 224, 224), (3, 10))
|
||||||
|
>>> get_model_complexity_info(model, input_shape=input_shape)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The model to analyze.
|
model (nn.Module): The model to analyze.
|
||||||
input_shape (tuple, optional): The input shape of the model.
|
input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...]], None]):
|
||||||
If inputs is not specified, the input_shape should be set.
|
The input shape of the model.
|
||||||
|
If "inputs" is not specified, the "input_shape" should be set.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
|
inputs (torch.Tensor, tuple[torch.Tensor, ...] or Tuple[Any, ...],\
|
||||||
|
optional]):
|
||||||
The input tensor(s) of the model. If not given the input tensor
|
The input tensor(s) of the model. If not given the input tensor
|
||||||
will be generated automatically with the given input_shape.
|
will be generated automatically with the given input_shape.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
@ -705,7 +725,21 @@ def get_model_complexity_info(
|
|||||||
raise ValueError('"input_shape" and "inputs" cannot be both set.')
|
raise ValueError('"input_shape" and "inputs" cannot be both set.')
|
||||||
|
|
||||||
if inputs is None:
|
if inputs is None:
|
||||||
inputs = (torch.randn(1, *input_shape), )
|
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
|
||||||
|
inputs = (torch.randn(1, *input_shape), )
|
||||||
|
elif is_tuple_of(input_shape, tuple) and all([
|
||||||
|
is_tuple_of(one_input_shape, int)
|
||||||
|
for one_input_shape in input_shape # type: ignore
|
||||||
|
]): # tuple of tuple of int, construct multiple tensors
|
||||||
|
inputs = tuple([
|
||||||
|
torch.randn(1, *one_input_shape)
|
||||||
|
for one_input_shape in input_shape # type: ignore
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'"input_shape" should be either a `tuple of int` (to construct'
|
||||||
|
'one input tensor) or a `tuple of tuple of int` (to construct'
|
||||||
|
'multiple input tensors).')
|
||||||
|
|
||||||
flop_handler = FlopAnalyzer(model, inputs)
|
flop_handler = FlopAnalyzer(model, inputs)
|
||||||
activation_handler = ActivationAnalyzer(model, inputs)
|
activation_handler = ActivationAnalyzer(model, inputs)
|
||||||
|
108
tests/test_analysis/test_print_helper.py
Normal file
108
tests/test_analysis/test_print_helper.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from mmengine.analysis.complexity_analysis import FlopAnalyzer, parameter_count
|
||||||
|
from mmengine.analysis.print_helper import get_model_complexity_info
|
||||||
|
from mmengine.utils import digit_version
|
||||||
|
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||||
|
|
||||||
|
|
||||||
|
class NetAcceptOneTensor(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.l1 = nn.Linear(in_features=5, out_features=6)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
out = self.l1(x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class NetAcceptTwoTensors(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.l1 = nn.Linear(in_features=5, out_features=6)
|
||||||
|
self.l2 = nn.Linear(in_features=7, out_features=6)
|
||||||
|
|
||||||
|
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
||||||
|
out = self.l1(x1) + self.l2(x2)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class NetAcceptOneTensorAndOneScalar(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.l1 = nn.Linear(in_features=5, out_features=6)
|
||||||
|
self.l2 = nn.Linear(in_features=5, out_features=6)
|
||||||
|
|
||||||
|
def forward(self, x1: torch.Tensor, r) -> torch.Tensor:
|
||||||
|
out = r * self.l1(x1) + (1 - r) * self.l2(x1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_model_complexity_info():
|
||||||
|
input1 = torch.randn(1, 9, 5)
|
||||||
|
input_shape1 = (9, 5)
|
||||||
|
input2 = torch.randn(1, 9, 7)
|
||||||
|
input_shape2 = (9, 7)
|
||||||
|
scalar = 0.3
|
||||||
|
|
||||||
|
# test a network that accepts one tensor as input
|
||||||
|
model = NetAcceptOneTensor()
|
||||||
|
complexity_info = get_model_complexity_info(model=model, inputs=input1)
|
||||||
|
flops = FlopAnalyzer(model=model, inputs=input1).total()
|
||||||
|
params = parameter_count(model=model)['']
|
||||||
|
assert complexity_info['flops'] == flops
|
||||||
|
assert complexity_info['params'] == params
|
||||||
|
|
||||||
|
complexity_info = get_model_complexity_info(
|
||||||
|
model=model, input_shape=input_shape1)
|
||||||
|
flops = FlopAnalyzer(
|
||||||
|
model=model, inputs=(torch.randn(1, *input_shape1), )).total()
|
||||||
|
assert complexity_info['flops'] == flops
|
||||||
|
|
||||||
|
# test a network that accepts two tensors as input
|
||||||
|
model = NetAcceptTwoTensors()
|
||||||
|
complexity_info = get_model_complexity_info(
|
||||||
|
model=model, inputs=(input1, input2))
|
||||||
|
flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total()
|
||||||
|
params = parameter_count(model=model)['']
|
||||||
|
assert complexity_info['flops'] == flops
|
||||||
|
assert complexity_info['params'] == params
|
||||||
|
|
||||||
|
complexity_info = get_model_complexity_info(
|
||||||
|
model=model, input_shape=(input_shape1, input_shape2))
|
||||||
|
inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2))
|
||||||
|
flops = FlopAnalyzer(model=model, inputs=inputs).total()
|
||||||
|
assert complexity_info['flops'] == flops
|
||||||
|
|
||||||
|
# test a network that accepts one tensor and one scalar as input
|
||||||
|
model = NetAcceptOneTensorAndOneScalar()
|
||||||
|
# For pytorch<1.9, a scalar input is not acceptable for torch.jit,
|
||||||
|
# wrap it to `torch.tensor`. See https://github.com/pytorch/pytorch/blob/cd9dd653e98534b5d3a9f2576df2feda40916f1d/torch/csrc/jit/python/python_arg_flatten.cpp#L90. # noqa: E501
|
||||||
|
scalar = torch.tensor([
|
||||||
|
scalar
|
||||||
|
]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar
|
||||||
|
complexity_info = get_model_complexity_info(
|
||||||
|
model=model, inputs=(input1, scalar))
|
||||||
|
flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total()
|
||||||
|
params = parameter_count(model=model)['']
|
||||||
|
assert complexity_info['flops'] == flops
|
||||||
|
assert complexity_info['params'] == params
|
||||||
|
|
||||||
|
# `get_model_complexity_info()` should throw `ValueError`
|
||||||
|
# when neithor `inputs` nor `input_shape` is specified
|
||||||
|
with pytest.raises(ValueError, match='should be set'):
|
||||||
|
get_model_complexity_info(model)
|
||||||
|
|
||||||
|
# `get_model_complexity_info()` should throw `ValueError`
|
||||||
|
# when both `inputs` and `input_shape` are specified
|
||||||
|
model = NetAcceptOneTensor()
|
||||||
|
with pytest.raises(ValueError, match='cannot be both set'):
|
||||||
|
get_model_complexity_info(
|
||||||
|
model, inputs=input1, input_shape=input_shape1)
|
Loading…
x
Reference in New Issue
Block a user