diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py index ebad6cba..7c03fb78 100644 --- a/mmengine/analysis/print_helper.py +++ b/mmengine/analysis/print_helper.py @@ -4,7 +4,7 @@ # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py from collections import defaultdict -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from rich import box @@ -675,8 +675,8 @@ def complexity_stats_table( def get_model_complexity_info( model: nn.Module, - input_shape: tuple = None, - inputs: Optional[torch.Tensor] = None, + input_shape: Optional[tuple] = None, + inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None, show_table: bool = True, show_arch: bool = True, ): @@ -684,10 +684,13 @@ def get_model_complexity_info( Args: model (nn.Module): The model to analyze. - input_shape (tuple): The input shape of the model. - inputs (torch.Tensor, optional): The input tensor of the model. - If not given the input tensor will be generated automatically - with the given input_shape. + input_shape (tuple, optional): The input shape of the model. + If inputs is not specified, the input_shape should be set. + Defaults to None. + inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]): + The input tensor(s) of the model. If not given the input tensor + will be generated automatically with the given input_shape. + Defaults to None. show_table (bool): Whether to show the complexity table. Defaults to True. show_arch (bool): Whether to show the complexity arch.