diff --git a/docs/en/advanced_tutorials/model_analysis.md b/docs/en/advanced_tutorials/model_analysis.md index 0817c59e..124ad3f6 100644 --- a/docs/en/advanced_tutorials/model_analysis.md +++ b/docs/en/advanced_tutorials/model_analysis.md @@ -70,24 +70,24 @@ The return outputs is dict, which contains the following keys: - `out_table`: print related information by table ``` -┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┓ -┃ module ┃ #parameters or shape ┃ #flops ┃ #activations ┃ -┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━┩ -│ model │ 0.44K │ 0.4K │ 40 │ -│ fc1 │ 0.11K │ 100 │ 10 │ -│ fc1.weight │ (10, 10) │ │ │ -│ fc1.bias │ (10,) │ │ │ -│ fc2 │ 0.11K │ 100 │ 10 │ -│ fc2.weight │ (10, 10) │ │ │ -│ fc2.bias │ (10,) │ │ │ -│ inner │ 0.22K │ 0.2K │ 20 │ -│ inner.fc1 │ 0.11K │ 100 │ 10 │ -│ inner.fc1.weight │ (10, 10) │ │ │ -│ inner.fc1.bias │ (10,) │ │ │ -│ inner.fc2 │ 0.11K │ 100 │ 10 │ -│ inner.fc2.weight │ (10, 10) │ │ │ -│ inner.fc2.bias │ (10,) │ │ │ -└─────────────────────┴──────────────────────┴────────┴──────────────┘ ++---------------------+----------------------+--------+--------------+ +| module | #parameters or shape | #flops | #activations | ++---------------------+----------------------+--------+--------------+ +| model | 0.44K | 0.4K | 40 | +| fc1 | 0.11K | 100 | 10 | +| fc1.weight | (10, 10) | | | +| fc1.bias | (10,) | | | +| fc2 | 0.11K | 100 | 10 | +| fc2.weight | (10, 10) | | | +| fc2.bias | (10,) | | | +| inner | 0.22K | 0.2K | 20 | +| inner.fc1 | 0.11K | 100 | 10 | +| inner.fc1.weight | (10, 10) | | | +| inner.fc1.bias | (10,) | | | +| inner.fc2 | 0.11K | 100 | 10 | +| inner.fc2.weight | (10, 10) | | | +| inner.fc2.bias | (10,) | | | ++---------------------+----------------------+--------+--------------+ ``` - `out_arch`: print related information by network layers diff --git a/mmengine/analysis/complexity_analysis.py b/mmengine/analysis/complexity_analysis.py index 0c701be3..435e5fe5 100644 --- a/mmengine/analysis/complexity_analysis.py +++ b/mmengine/analysis/complexity_analysis.py @@ -5,6 +5,7 @@ from collections import defaultdict from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union import torch.nn as nn +from rich import box from rich.console import Console from rich.table import Table from torch import Tensor @@ -341,7 +342,8 @@ def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str: rows.append(('model', format_size(count.pop('')))) fill(0, '') - table = Table(title=f'parameter count of {model.__class__.__name__}') + table = Table( + title=f'parameter count of {model.__class__.__name__}', box=box.ASCII2) table.add_column('name') table.add_column('#elements or shape') diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py index eeb0854c..c2ede9de 100644 --- a/mmengine/analysis/print_helper.py +++ b/mmengine/analysis/print_helper.py @@ -7,6 +7,7 @@ from collections import defaultdict from typing import Any, Dict, Iterable, List, Optional, Set, Tuple import torch +from rich import box from rich.console import Console from rich.table import Table from torch import nn @@ -421,7 +422,7 @@ def _stats_table_format( root_prefix = root_name + ('.' if root_name else '') fill(indent_lvl=1, prefix=root_prefix) - table = Table() + table = Table(box=box.ASCII2) for header in headers: table.add_column(header)