From 5da24ed1021c7ebd17a5d7c89999089f15a5d4cf Mon Sep 17 00:00:00 2001 From: sjiang95 <51251025+sjiang95@users.noreply.github.com> Date: Mon, 10 Apr 2023 20:57:54 +0900 Subject: [PATCH] [Enhance] Complement type hint of get_model_complexity_info() (#1064) * Complement type hint of get_model_complexity_info() The type of `inputs` should be one of `torch.Tensor`, `tuple[torch.Tensor, ...]` and `None`. Signed-off-by: Shengjiang QUAN * Update print_helper.py --------- Signed-off-by: Shengjiang QUAN Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/analysis/print_helper.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py index ebad6cba..7c03fb78 100644 --- a/mmengine/analysis/print_helper.py +++ b/mmengine/analysis/print_helper.py @@ -4,7 +4,7 @@ # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py from collections import defaultdict -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from rich import box @@ -675,8 +675,8 @@ def complexity_stats_table( def get_model_complexity_info( model: nn.Module, - input_shape: tuple = None, - inputs: Optional[torch.Tensor] = None, + input_shape: Optional[tuple] = None, + inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None, show_table: bool = True, show_arch: bool = True, ): @@ -684,10 +684,13 @@ def get_model_complexity_info( Args: model (nn.Module): The model to analyze. - input_shape (tuple): The input shape of the model. - inputs (torch.Tensor, optional): The input tensor of the model. - If not given the input tensor will be generated automatically - with the given input_shape. + input_shape (tuple, optional): The input shape of the model. + If inputs is not specified, the input_shape should be set. + Defaults to None. + inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]): + The input tensor(s) of the model. If not given the input tensor + will be generated automatically with the given input_shape. + Defaults to None. show_table (bool): Whether to show the complexity table. Defaults to True. show_arch (bool): Whether to show the complexity arch.