[Enhance] Complement type hint of get_model_complexity_info() (#1064)

* Complement type hint of get_model_complexity_info()

The type of `inputs` should be one of `torch.Tensor`,
`tuple[torch.Tensor, ...]` and `None`.

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>

* Update print_helper.py

---------

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/912/head
sjiang95 2023-04-10 20:57:54 +09:00 committed by GitHub
parent b2ad2210b5
commit 5da24ed102
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 7 deletions

View File

@ -4,7 +4,7 @@
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from rich import box
@ -675,8 +675,8 @@ def complexity_stats_table(
def get_model_complexity_info(
model: nn.Module,
input_shape: tuple = None,
inputs: Optional[torch.Tensor] = None,
input_shape: Optional[tuple] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
show_table: bool = True,
show_arch: bool = True,
):
@ -684,10 +684,13 @@ def get_model_complexity_info(
Args:
model (nn.Module): The model to analyze.
input_shape (tuple): The input shape of the model.
inputs (torch.Tensor, optional): The input tensor of the model.
If not given the input tensor will be generated automatically
with the given input_shape.
input_shape (tuple, optional): The input shape of the model.
If inputs is not specified, the input_shape should be set.
Defaults to None.
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
The input tensor(s) of the model. If not given the input tensor
will be generated automatically with the given input_shape.
Defaults to None.
show_table (bool): Whether to show the complexity table.
Defaults to True.
show_arch (bool): Whether to show the complexity arch.