mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Complement type hint of get_model_complexity_info() (#1064)
* Complement type hint of get_model_complexity_info() The type of `inputs` should be one of `torch.Tensor`, `tuple[torch.Tensor, ...]` and `None`. Signed-off-by: Shengjiang QUAN <qsj287068067@126.com> * Update print_helper.py --------- Signed-off-by: Shengjiang QUAN <qsj287068067@126.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
b2ad2210b5
commit
5da24ed102
@ -4,7 +4,7 @@
|
|||||||
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py
|
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from rich import box
|
from rich import box
|
||||||
@ -675,8 +675,8 @@ def complexity_stats_table(
|
|||||||
|
|
||||||
def get_model_complexity_info(
|
def get_model_complexity_info(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
input_shape: tuple = None,
|
input_shape: Optional[tuple] = None,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
|
||||||
show_table: bool = True,
|
show_table: bool = True,
|
||||||
show_arch: bool = True,
|
show_arch: bool = True,
|
||||||
):
|
):
|
||||||
@ -684,10 +684,13 @@ def get_model_complexity_info(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The model to analyze.
|
model (nn.Module): The model to analyze.
|
||||||
input_shape (tuple): The input shape of the model.
|
input_shape (tuple, optional): The input shape of the model.
|
||||||
inputs (torch.Tensor, optional): The input tensor of the model.
|
If inputs is not specified, the input_shape should be set.
|
||||||
If not given the input tensor will be generated automatically
|
Defaults to None.
|
||||||
with the given input_shape.
|
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
|
||||||
|
The input tensor(s) of the model. If not given the input tensor
|
||||||
|
will be generated automatically with the given input_shape.
|
||||||
|
Defaults to None.
|
||||||
show_table (bool): Whether to show the complexity table.
|
show_table (bool): Whether to show the complexity table.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
show_arch (bool): Whether to show the complexity arch.
|
show_arch (bool): Whether to show the complexity arch.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user