[Enhance] Complement type hint of get_model_complexity_info() (#1064)

* Complement type hint of get_model_complexity_info()

The type of `inputs` should be one of `torch.Tensor`,
`tuple[torch.Tensor, ...]` and `None`.

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>

* Update print_helper.py

---------

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
sjiang95 2023-04-10 20:57:54 +09:00 committed by GitHub
parent b2ad2210b5
commit 5da24ed102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from rich import box from rich import box
@ -675,8 +675,8 @@ def complexity_stats_table(
def get_model_complexity_info( def get_model_complexity_info(
model: nn.Module, model: nn.Module,
input_shape: tuple = None, input_shape: Optional[tuple] = None,
inputs: Optional[torch.Tensor] = None, inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
show_table: bool = True, show_table: bool = True,
show_arch: bool = True, show_arch: bool = True,
): ):
@ -684,10 +684,13 @@ def get_model_complexity_info(
Args: Args:
model (nn.Module): The model to analyze. model (nn.Module): The model to analyze.
input_shape (tuple): The input shape of the model. input_shape (tuple, optional): The input shape of the model.
inputs (torch.Tensor, optional): The input tensor of the model. If inputs is not specified, the input_shape should be set.
If not given the input tensor will be generated automatically Defaults to None.
with the given input_shape. inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
The input tensor(s) of the model. If not given the input tensor
will be generated automatically with the given input_shape.
Defaults to None.
show_table (bool): Whether to show the complexity table. show_table (bool): Whether to show the complexity table.
Defaults to True. Defaults to True.
show_arch (bool): Whether to show the complexity arch. show_arch (bool): Whether to show the complexity arch.