[Enhance] Complement type hint of get_model_complexity_info() (#1064)
* Complement type hint of get_model_complexity_info() The type of `inputs` should be one of `torch.Tensor`, `tuple[torch.Tensor, ...]` and `None`. Signed-off-by: Shengjiang QUAN <qsj287068067@126.com> * Update print_helper.py --------- Signed-off-by: Shengjiang QUAN <qsj287068067@126.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/912/head
parent
b2ad2210b5
commit
5da24ed102
|
@ -4,7 +4,7 @@
|
|||
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from rich import box
|
||||
|
@ -675,8 +675,8 @@ def complexity_stats_table(
|
|||
|
||||
def get_model_complexity_info(
|
||||
model: nn.Module,
|
||||
input_shape: tuple = None,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
input_shape: Optional[tuple] = None,
|
||||
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
|
||||
show_table: bool = True,
|
||||
show_arch: bool = True,
|
||||
):
|
||||
|
@ -684,10 +684,13 @@ def get_model_complexity_info(
|
|||
|
||||
Args:
|
||||
model (nn.Module): The model to analyze.
|
||||
input_shape (tuple): The input shape of the model.
|
||||
inputs (torch.Tensor, optional): The input tensor of the model.
|
||||
If not given the input tensor will be generated automatically
|
||||
with the given input_shape.
|
||||
input_shape (tuple, optional): The input shape of the model.
|
||||
If inputs is not specified, the input_shape should be set.
|
||||
Defaults to None.
|
||||
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
|
||||
The input tensor(s) of the model. If not given the input tensor
|
||||
will be generated automatically with the given input_shape.
|
||||
Defaults to None.
|
||||
show_table (bool): Whether to show the complexity table.
|
||||
Defaults to True.
|
||||
show_arch (bool): Whether to show the complexity arch.
|
||||
|
|
Loading…
Reference in New Issue