Compare commits

...

17 Commits

Author SHA1 Message Date
Mashiro 390ba2fbb2
Bump version to 0.10.7 2025-03-04 20:20:27 +08:00
Epiphany d620552c2c
[fix] fix ci () 2025-02-27 15:13:38 +08:00
Mashiro 41fa84a9a9
[Fix] remove torch dependencies in `build_function.py` () 2025-02-18 16:38:51 +08:00
Mashiro 698782f920
[Fix] Fix deploy ci ()
* [Enhance] Support trigger ci manually

* [Fix] Fix deploy CI
2025-01-15 18:11:05 +08:00
Mashiro e60ab1dde3
[Enhance] Support trigger ci manually () 2025-01-15 18:07:34 +08:00
Mashiro 8ec837814e
[Enhance] Support trigger ci manually () 2025-01-15 17:58:23 +08:00
Qian Zhao a4475f5eea
Update deploy.yml () 2025-01-15 17:34:07 +08:00
Mashiro a8c74c346d
Bump version to v0.10.6 () 2025-01-13 19:20:26 +08:00
Epiphany 9124ebf7a2
[Enhance] ensure type in cfg ()
* ensure type in cfg

* change import level
2024-11-06 20:26:47 +08:00
Epiphany 2e0ab7a922
[Fix] fix Adafactor optim on torch2.5 and fix compatibility ()
* fix Adafactor optim on torch2.5 and fix compatibility

* fix runtest error
2024-11-05 21:22:46 +08:00
Epiphany fc59364d64
[Fix] fix error when pytest>=8.2 () 2024-11-05 20:43:17 +08:00
Tibor Reiss 4183cf0829
Fix return in finally () 2024-11-04 14:39:25 +08:00
Mashiro cc3b74b5e8
[Fix] Fix lint ()
* [Fix] Fix lint

* [Fix] Fix lint

* Update mmengine/dist/utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2024-11-02 22:23:51 +08:00
Mashiro c9b59962d6
Yehc/bump version to v0.10.5 ()
* fix check builtin module

* bump version to v0.10.5

* bump version to v0.10.5
2024-09-20 15:42:06 +08:00
Mashiro 5e736b143b
fix check builtin module () 2024-09-11 18:45:24 +08:00
Chris Jiang 85c83ba616
Update is_mlu_available ()
* Update is_mlu_available

to adapt torch_mlu main, the torch.is_mlu_available method is removed

* Update utils.py

* Update utils.py
2024-05-30 10:07:45 +08:00
fanqiNO1 d1f1aabf81
[Feature] Support calculating loss during validation () 2024-05-17 15:27:53 +08:00
59 changed files with 393 additions and 290 deletions

View File

@@ -1,6 +1,8 @@
 name: deploy
-on: push
+on:
+  - push
+  - workflow_dispatch

 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -9,13 +11,14 @@ concurrency:
 jobs:
   build-n-publish:
     runs-on: ubuntu-latest
-    if: startsWith(github.event.ref, 'refs/tags')
+    if: |
+      startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python 3.7
-        uses: actions/setup-python@v2
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.10.13
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.10.13
       - name: Install wheel
         run: pip install wheel
       - name: Build MMEngine
@@ -27,13 +30,14 @@ jobs:
   build-n-publish-lite:
     runs-on: ubuntu-latest
-    if: startsWith(github.event.ref, 'refs/tags')
+    if: |
+      startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python 3.7
-        uses: actions/setup-python@v2
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.10.13
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.10.13
       - name: Install wheel
         run: pip install wheel
       - name: Build MMEngine-lite

View File

@@ -11,10 +11,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
-      - name: Set up Python 3.7
+      - name: Set up Python 3.10.15
        uses: actions/setup-python@v2
         with:
-          python-version: 3.7
+          python-version: '3.10.15'
       - name: Install pre-commit hook
         run: |
           pip install pre-commit

View File

@@ -1,5 +1,9 @@
 name: pr_stage_test

+env:
+  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
+
 on:
   pull_request:
     paths-ignore:
@@ -21,152 +25,114 @@ concurrency:
 jobs:
   build_cpu:
     runs-on: ubuntu-22.04
+    defaults:
+      run:
+        shell: bash -l {0}
     strategy:
       matrix:
-        python-version: [3.7]
-        include:
-          - torch: 1.8.1
-            torchvision: 0.9.1
+        python-version: ['3.9']
+        torch: ['2.0.0']
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+      - name: Check out repo
+        uses: actions/checkout@v3
+      - name: Setup conda env
+        uses: conda-incubator/setup-miniconda@v2
         with:
+          auto-update-conda: true
+          miniconda-version: "latest"
+          use-only-tar-bz2: true
+          activate-environment: test
           python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        run: python -m pip install pip --upgrade
-      - name: Upgrade wheel
-        run: python -m pip install wheel --upgrade
-      - name: Install PyTorch
-        run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
+      - name: Update pip
         run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      - name: Run unittests and generate coverage report
+          python -m pip install --upgrade pip wheel
+      - name: Install dependencies
         run: |
-          coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
-          coverage xml
-          coverage report -m
-      # Upload coverage report for python3.7 && pytorch1.8.1 cpu
-      - name: Upload coverage to Codecov
+          apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
+          python -m pip install torch==${{matrix.torch}}
+          python -m pip install -e . -v
+          python -m pip install -r requirements/tests.txt
+          python -m pip install openmim
+          mim install mmcv coverage
+      - name: Run unit tests with coverage
+        run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
+      - name: Upload Coverage to Codecov
         uses: codecov/codecov-action@v3
         with:
           file: ./coverage.xml
           flags: unittests
           env_vars: OS,PYTHON
           name: codecov-umbrella
           fail_ci_if_error: false
-  build_cu102:
+  build_gpu:
     runs-on: ubuntu-22.04
-    container:
-      image: pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel
+    defaults:
+      run:
+        shell: bash -l {0}
     env:
       MKL_THREADING_LAYER: GNU
     strategy:
       matrix:
-        python-version: [3.7]
+        python-version: ['3.9','3.10']
+        torch: ['2.0.0','2.3.1','2.5.1']
+        cuda: ['cu118']
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+      - name: Check out repo
+        uses: actions/checkout@v3
+      - name: Setup conda env
+        uses: conda-incubator/setup-miniconda@v2
         with:
+          auto-update-conda: true
+          miniconda-version: "latest"
+          use-only-tar-bz2: true
+          activate-environment: test
           python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        run: pip install pip --upgrade
-      - name: Fetch GPG keys
+      - name: Update pip
         run: |
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
-      - name: Install system dependencies
-        run: apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
+          python -m pip install --upgrade pip wheel
+      - name: Install dependencies
         run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      - name: Run unittests and generate coverage report
-        run: |
-          coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
-          coverage xml
-          coverage report -m
+          apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
+          python -m pip install torch==${{matrix.torch}} --index-url https://download.pytorch.org/whl/${{matrix.cuda}}
+          python -m pip install -e . -v
+          python -m pip install -r requirements/tests.txt
+          python -m pip install openmim
+          mim install mmcv coverage
+      - name: Run unit tests with coverage
+        run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
+      - name: Upload Coverage to Codecov
+        uses: codecov/codecov-action@v3
-  build_cu117:
-    runs-on: ubuntu-22.04
-    container:
-      image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
-    strategy:
-      matrix:
-        python-version: [3.9]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        run: pip install pip --upgrade
-      - name: Fetch GPG keys
-        run: |
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
-      - name: Install system dependencies
-        run: apt-get update && apt-get install -y git ffmpeg libturbojpeg
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
-        run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      # Distributed related unit test may randomly error in PyTorch 1.13.0
-      - name: Run unittests and generate coverage report
-        run: |
-          coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist/
-          coverage xml
-          coverage report -m
-  build_windows:
-    runs-on: windows-2022
-    strategy:
-      matrix:
-        python-version: [3.7]
-        platform: [cpu, cu111]
-        torch: [1.8.1]
-        torchvision: [0.9.1]
-        include:
-          - python-version: 3.8
-            platform: cu118
-            torch: 2.1.0
-            torchvision: 0.16.0
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        # Windows CI could fail If we call `pip install pip --upgrade` directly.
-        run: python -m pip install pip --upgrade
-      - name: Install PyTorch
-        run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
-        run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      - name: Run CPU unittests
-        run: pytest tests/ --ignore tests/test_dist
-        if: ${{ matrix.platform == 'cpu' }}
-      - name: Run GPU unittests
-        # Skip testing distributed related unit tests since the memory of windows CI is limited
-        run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
-        if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
+  # build_windows:
+  #   runs-on: windows-2022
+  #   strategy:
+  #     matrix:
+  #       python-version: [3.9]
+  #       platform: [cpu, cu111]
+  #       torch: [1.8.1]
+  #       torchvision: [0.9.1]
+  #       include:
+  #         - python-version: 3.8
+  #           platform: cu118
+  #           torch: 2.1.0
+  #           torchvision: 0.16.0
+  #   steps:
+  #     - uses: actions/checkout@v3
+  #     - name: Set up Python ${{ matrix.python-version }}
+  #       uses: actions/setup-python@v4
+  #       with:
+  #         python-version: ${{ matrix.python-version }}
+  #     - name: Upgrade pip
+  #       # Windows CI could fail If we call `pip install pip --upgrade` directly.
+  #       run: python -m pip install pip wheel --upgrade
+  #     - name: Install PyTorch
+  #       run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
+  #     - name: Build MMEngine from source
+  #       run: pip install -e . -v
+  #     - name: Install unit tests dependencies
+  #       run: |
+  #         pip install -r requirements/tests.txt
+  #         pip install openmim
+  #         mim install mmcv
+  #     - name: Run CPU unittests
+  #       run: pytest tests/ --ignore tests/test_dist
+  #       if: ${{ matrix.platform == 'cpu' }}
+  #     - name: Run GPU unittests
+  #       # Skip testing distributed related unit tests since the memory of windows CI is limited
+  #       run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
+  #       if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}

View File

@@ -1,7 +1,11 @@
 exclude: ^tests/data/
 repos:
-  - repo: https://gitee.com/openmmlab/mirrors-flake8
-    rev: 5.0.4
+  - repo: https://github.com/pre-commit/pre-commit
+    rev: v4.0.0
+    hooks:
+      - id: validate_manifest
+  - repo: https://github.com/PyCQA/flake8
+    rev: 7.1.1
     hooks:
       - id: flake8
   - repo: https://gitee.com/openmmlab/mirrors-isort
@@ -13,7 +17,7 @@ repos:
     hooks:
       - id: yapf
   - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
-    rev: v4.3.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -55,7 +59,7 @@ repos:
         args: ["mmengine", "tests"]
       - id: remove-improper-eol-in-cn-docs
   - repo: https://gitee.com/openmmlab/mirrors-mypy
-    rev: v0.812
+    rev: v1.2.0
     hooks:
       - id: mypy
         exclude: |-
@@ -63,3 +67,4 @@ repos:
           ^examples
           | ^docs
           )
+        additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@@ -1,7 +1,11 @@
 exclude: ^tests/data/
 repos:
+  - repo: https://github.com/pre-commit/pre-commit
+    rev: v4.0.0
+    hooks:
+      - id: validate_manifest
   - repo: https://github.com/PyCQA/flake8
-    rev: 5.0.4
+    rev: 7.1.1
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/isort
@@ -13,7 +17,7 @@ repos:
     hooks:
       - id: yapf
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -34,12 +38,8 @@ repos:
           - mdformat-openmmlab
          - mdformat_frontmatter
          - linkify-it-py
-  - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.1
-    hooks:
-      - id: codespell
   - repo: https://github.com/myint/docformatter
-    rev: v1.3.1
+    rev: 06907d0
     hooks:
       - id: docformatter
         args: ["--in-place", "--wrap-descriptions", "79"]
@@ -55,7 +55,7 @@ repos:
         args: ["mmengine", "tests"]
       - id: remove-improper-eol-in-cn-docs
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.812
+    rev: v1.2.0
     hooks:
       - id: mypy
         exclude: |-
@@ -63,3 +63,4 @@ repos:
           ^examples
           | ^docs
           )
+        additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@@ -59,7 +59,7 @@ English | [简体中文](README_zh-CN.md)
 ## What's New

-v0.10.4 was released on 2024-4-23.
+v0.10.6 was released on 2025-01-13.

 Highlights:

View File

@@ -59,7 +59,7 @@
 ## 最近进展

-最新版本 v0.10.4 在 2024.4.23 发布。
+最新版本 v0.10.5 在 2024.9.11 发布。

 版本亮点:

View File

@@ -1,5 +1,9 @@
 # Changelog of v0.x

+## v0.10.5 (11/9/2024)
+
+- Fix `_is_builtin_module`. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1571
+
 ## v0.10.4 (23/4/2024)

 ### New Features & Enhancements

View File

@@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
                         '"type" and "constructor" are not in '
                         f'optimizer, but got {name}={optim}')
                 optim_wrappers[name] = optim
-            return OptimWrapperDict(**optim_wrappers)
+            return OptimWrapperDict(**optim_wrappers)  # type: ignore
         else:
             raise TypeError('optimizer wrapper should be an OptimWrapper '
                             f'object or dict, but got {optim_wrapper}')

View File

@@ -361,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
         map_location: Union[str, Callable] = 'default',
         callback: Optional[Callable] = None,
     ) -> dict:
-        """override this method since colossalai resume optimizer from filename
+        """Override this method since colossalai resume optimizer from filename
         directly."""
         self.logger.info(f'Resume checkpoint from {filename}')

View File

@@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
             init_dist(launcher, backend, **kwargs)

     def convert_model(self, model: nn.Module) -> nn.Module:
-        """convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
+        """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
         (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

         Args:

View File

@@ -393,7 +393,7 @@ class Config:
     def __init__(
         self,
-        cfg_dict: dict = None,
+        cfg_dict: Optional[dict] = None,
         cfg_text: Optional[str] = None,
         filename: Optional[Union[str, Path]] = None,
         env_variables: Optional[dict] = None,
@@ -1227,7 +1228,8 @@ class Config:
             if base_code is not None:
                 base_code = ast.Expression(  # type: ignore
                     body=base_code.value)  # type: ignore
-                base_files = eval(compile(base_code, '', mode='eval'))
+                base_files = eval(compile(base_code, '',
+                                          mode='eval'))  # type: ignore
             else:
                 base_files = []
         elif file_format in ('.yml', '.yaml', '.json'):
@@ -1288,7 +1289,7 @@ class Config:
     def _merge_a_into_b(a: dict,
                         b: dict,
                         allow_list_keys: bool = False) -> dict:
-        """merge dict ``a`` into dict ``b`` (non-inplace).
+        """Merge dict ``a`` into dict ``b`` (non-inplace).

         Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
         in-place modifications.
@@ -1358,22 +1359,22 @@ class Config:
     @property
     def filename(self) -> str:
-        """get file name of config."""
+        """Get file name of config."""
         return self._filename

     @property
     def text(self) -> str:
-        """get config text."""
+        """Get config text."""
         return self._text

     @property
     def env_variables(self) -> dict:
-        """get used environment variables."""
+        """Get used environment variables."""
         return self._env_variables

     @property
     def pretty_text(self) -> str:
-        """get formatted python config text."""
+        """Get formatted python config text."""
         indent = 4
@@ -1727,17 +1728,17 @@ class Config:
 class DictAction(Action):
-    """
-    argparse action to split an argument into KEY=VALUE form
-    on the first = and append to a dictionary. List options can
-    be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
-    brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
-    list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
+    """Argparse action to split an argument into KEY=VALUE form on the first =
+    and append to a dictionary.
+
+    List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
+    or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
+    brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
     """

     @staticmethod
     def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
-        """parse int/float/bool value in the string."""
+        """Parse int/float/bool value in the string."""
         try:
             return int(val)
         except ValueError:
@@ -1822,7 +1823,7 @@ class DictAction(Action):
                  parser: ArgumentParser,
                  namespace: Namespace,
                  values: Union[str, Sequence[Any], None],
-                 option_string: str = None):
+                 option_string: str = None):  # type: ignore
         """Parse Variables in string and add them into argparser.

         Args:
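For orientation, `DictAction` is normally attached to an argparse option; a minimal usage sketch (option name and values are illustrative, not from this diff):

import argparse

from mmengine.config import DictAction

parser = argparse.ArgumentParser()
parser.add_argument('--cfg-options', nargs='+', action=DictAction)
args = parser.parse_args(['--cfg-options', 'optimizer.lr=0.01', 'ids=[1,2,3]'])
print(args.cfg_options)  # {'optimizer.lr': 0.01, 'ids': [1, 2, 3]}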

View File

@@ -12,6 +12,7 @@ from mmengine.fileio import load
 from mmengine.utils import check_file_exist

 PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable))
+SYSTEM_PYTHON_PREFIX = '/usr/lib/python'

 MODULE2PACKAGE = {
     'mmcls': 'mmcls',
@@ -176,7 +177,8 @@ def _is_builtin_module(module_name: str) -> bool:
         return False
     origin_path = osp.abspath(origin_path)
     if ('site-package' in origin_path or 'dist-package' in origin_path
-            or not origin_path.startswith(PYTHON_ROOT_DIR)):
+            or not origin_path.startswith(
+                (PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))):
         return False
     else:
         return True
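The patched check hinges on where a module's import origin lives; roughly the same classification can be sketched with the standard library alone (module name and paths are illustrative):

import importlib.util
import os.path as osp
import sys

# Resolve where a module is actually loaded from.
spec = importlib.util.find_spec('json')
origin = osp.abspath(spec.origin)  # e.g. /usr/lib/python3.10/json/__init__.py

# A module counts as "builtin" if it lives under the interpreter prefix or
# the system python directory, and not under site-packages/dist-packages.
python_root = osp.dirname(osp.dirname(sys.executable))
is_builtin = ('site-package' not in origin and 'dist-package' not in origin
              and origin.startswith((python_root, '/usr/lib/python')))
print(is_builtin)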

View File

@@ -16,6 +16,12 @@ try:
 except Exception:
     IS_NPU_AVAILABLE = False

+try:
+    import torch_mlu  # noqa: F401
+    IS_MLU_AVAILABLE = hasattr(torch, 'mlu') and torch.mlu.is_available()
+except Exception:
+    IS_MLU_AVAILABLE = False
+
 try:
     import torch_dipu  # noqa: F401
     IS_DIPU_AVAILABLE = True
@@ -64,7 +70,7 @@ def is_npu_available() -> bool:

 def is_mlu_available() -> bool:
     """Returns True if Cambricon PyTorch and mlu devices exist."""
-    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+    return IS_MLU_AVAILABLE

 def is_mps_available() -> bool:

View File

@@ -563,7 +563,7 @@ def cast_data_device(
         Tensor or list or dict: ``data`` was casted to ``device``.
     """
     if out is not None:
-        if type(data) != type(out):
+        if type(data) is not type(out):
             raise TypeError(
                 'out should be the same type with data, but got data is '
                 f'{type(data)} and out is {type(data)}')
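The `!=` to `is not` switch is the comparison flake8's E721 asks for: classes are singleton objects, so identity is the unambiguous check. A quick illustration, not part of the change:

class Box:
    pass

a, b = Box(), Box()
print(type(a) is type(b))      # True: both resolve to the same class object
print(type(a) is not type(0))  # True: Box and int are different classes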

View File

@@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
         self.out_file_path = out_file_path

     def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
-        """transfer tensors in predictions to CPU."""
+        """Transfer tensors in predictions to CPU."""
         self.results.extend(_to_cpu(predictions))

     def compute_metrics(self, results: list) -> dict:
-        """dump the prediction results to a pickle file."""
+        """Dump the prediction results to a pickle file."""
         dump(results, self.out_file_path)
         print_log(
             f'Results has been saved to {self.out_file_path}.',
@@ -188,7 +188,7 @@ class DumpResults(BaseMetric):

 def _to_cpu(data: Any) -> Any:
-    """transfer all tensors and BaseDataElement to cpu."""
+    """Transfer all tensors and BaseDataElement to cpu."""
     if isinstance(data, (Tensor, BaseDataElement)):
         return data.to('cpu')
     elif isinstance(data, list):
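`DumpResults` collects predictions on CPU and writes them out when metrics are computed; a minimal usage sketch (the output path is illustrative):

from mmengine.evaluator import DumpResults

metric = DumpResults(out_file_path='work_dirs/results.pkl')
metric.process(data_batch=None, predictions=[dict(score=0.9)])
metric.compute_metrics(metric.results)  # dumps the collected list to the pickle file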

View File

@@ -233,7 +233,7 @@ class ProfilerHook(Hook):
             self._export_chrome_trace(runner)

     def after_train_iter(self, runner, batch_idx, data_batch, outputs):
-        """profiler will call `step` method if it is not closed."""
+        """Profiler will call `step` method if it is not closed."""
         if not self._closed:
             self.profiler.step()
         if runner.iter == self.profile_times - 1 and not self.by_epoch:

View File

@@ -58,7 +58,7 @@ class HistoryBuffer:
         self._statistics_methods.setdefault('mean', HistoryBuffer.mean)

     def update(self, log_val: Union[int, float], count: int = 1) -> None:
-        """update the log history.
+        """Update the log history.

         If the length of the buffer exceeds ``self._max_length``, the oldest
         element will be removed from the buffer.
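A short sketch of the windowing behavior the docstring describes (assuming the `max_length` keyword of `HistoryBuffer`):

from mmengine.logging import HistoryBuffer

buf = HistoryBuffer(max_length=4)
for v in [1, 2, 3, 4, 5]:
    buf.update(v)      # once 5 arrives, the oldest value (1) is dropped
print(buf.mean())      # 3.5, the mean of the kept window [2, 3, 4, 5]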

View File

@@ -443,8 +443,7 @@ def _get_host_info() -> str:
         host = f'{getuser()}@{gethostname()}'
     except Exception as e:
         warnings.warn(f'Host or user not found: {str(e)}')
-    finally:
-        return host
+    return host

 def _get_logging_file_handlers() -> Dict:
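The deleted `finally: return` pattern is worth spelling out: a `return` inside `finally` discards any in-flight exception, which is rarely intended. An illustration, not from the diff:

def swallowed():
    try:
        raise RuntimeError('boom')
    finally:
        return 'ok'  # the RuntimeError is silently discarded here

print(swallowed())  # prints 'ok'; the exception never reaches the caller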

View File

@@ -253,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
             dict or list: Data in the same format as the model input.
         """
         data = self.cast_data(data)  # type: ignore
-        _batch_inputs = data['inputs']
+        _batch_inputs = data['inputs']  # type: ignore
         # Process data with `pseudo_collate`.
         if is_seq_of(_batch_inputs, torch.Tensor):
             batch_inputs = []
             for _batch_input in _batch_inputs:
                 # channel transform
                 if self._channel_conversion:
-                    _batch_input = _batch_input[[2, 1, 0], ...]
+                    _batch_input = _batch_input[[2, 1, 0], ...]  # type: ignore
                 # Convert to float after channel conversion to ensure
                 # efficiency
-                _batch_input = _batch_input.float()
+                _batch_input = _batch_input.float()  # type: ignore
                 # Normalization.
                 if self._enable_normalize:
                     if self.mean.shape[0] == 3:
@@ -302,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
         else:
             raise TypeError('Output of `cast_data` should be a dict of '
                             'list/tuple with inputs and data_samples, '
-                            f'but got {type(data)}: {data}')
-        data['inputs'] = batch_inputs
-        data.setdefault('data_samples', None)
-        return data
+                            f'but got {type(data)}: {data}')  # type: ignore
+        data['inputs'] = batch_inputs  # type: ignore
+        data.setdefault('data_samples', None)  # type: ignore
+        return data  # type: ignore

View File

@@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):

 def bias_init_with_prob(prior_prob):
-    """initialize conv/fc bias value according to a given probability value."""
+    """Initialize conv/fc bias value according to a given probability value."""
     bias_init = float(-np.log((1 - prior_prob) / prior_prob))
     return bias_init

@@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
                   std: float = 1.,
                   a: float = -2.,
                   b: float = 2.) -> Tensor:
-    r"""Fills the input Tensor with values drawn from a truncated
-    normal distribution. The values are effectively drawn from the
-    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
-    with values outside :math:`[a, b]` redrawn until they are within
-    the bounds. The method used for generating the random values works
-    best when :math:`a \leq \text{mean} \leq b`.
+    r"""Fills the input Tensor with values drawn from a truncated normal
+    distribution. The values are effectively drawn from the normal distribution
+    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
+    :math:`[a, b]` redrawn until they are within the bounds. The method used
+    for generating the random values works best when :math:`a \leq \text{mean}
+    \leq b`.

     Modified from
     https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
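For reference, a usage sketch of `trunc_normal_` (assuming the `mmengine.model.weight_init` import path):

import torch
from mmengine.model.weight_init import trunc_normal_

w = torch.empty(3, 5)
trunc_normal_(w, mean=0., std=0.02, a=-2., b=2.)        # fills w in place
print(float(w.min()) >= -2. and float(w.max()) <= 2.)   # values stay in [a, b]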

View File

@@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
         auto_wrap_policy: Union[str, Callable, None] = None,
         backward_prefetch: Union[str, BackwardPrefetch, None] = None,
         mixed_precision: Union[dict, MixedPrecision, None] = None,
-        param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
+        param_init_fn: Union[str, Callable[
+            [nn.Module], None]] = None,  # type: ignore # noqa: E501
         use_orig_params: bool = True,
         **kwargs,
     ):
@@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
         optim: torch.optim.Optimizer,
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
-        """copied from pytorch 2.0.1 which has fixed some bugs."""
+        """Copied from pytorch 2.0.1 which has fixed some bugs."""
         state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
             model)
         return FullyShardedDataParallel._optim_state_dict_impl(
@@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
         state_dict_config: Optional[StateDictConfig] = None,
         optim_state_dict_config: Optional[OptimStateDictConfig] = None,
     ) -> StateDictSettings:
-        """copied from pytorch 2.0.1 which has fixed some bugs."""
+        """Copied from pytorch 2.0.1 which has fixed some bugs."""
         import torch.distributed.fsdp._traversal_utils as traversal_utils
         _state_dict_type_to_config = {
             StateDictType.FULL_STATE_DICT: FullStateDictConfig,

View File

@@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
             self._inner_count += 1

     def state_dict(self) -> dict:
-        """Get the state dictionary of :attr:`optimizer` and
-        :attr:`apex_amp`.
+        """Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.

         Based on the state dictionary of the optimizer, the returned state
         dictionary will add a key named "apex_amp".

View File

@@ -25,7 +25,11 @@ def register_torch_optimizers() -> List[str]:
         _optim = getattr(torch.optim, module_name)
         if inspect.isclass(_optim) and issubclass(_optim,
                                                   torch.optim.Optimizer):
-            OPTIMIZERS.register_module(module=_optim)
+            if module_name == 'Adafactor':
+                OPTIMIZERS.register_module(
+                    name='TorchAdafactor', module=_optim)
+            else:
+                OPTIMIZERS.register_module(module=_optim)
             torch_optimizers.append(module_name)
     return torch_optimizers
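With this registration, `type='Adafactor'` keeps resolving to the previously registered implementation, while PyTorch's own class (added in torch 2.5) stays reachable under the alias. A hedged usage sketch:

import torch.nn as nn

from mmengine.registry import OPTIMIZERS

model = nn.Linear(4, 2)
# On torch >= 2.5 this builds torch.optim.Adafactor via the alias.
optim = OPTIMIZERS.build(
    dict(type='TorchAdafactor', params=model.parameters(), lr=1e-2))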

View File

@@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
         self._validate_cfg()

     def _validate_cfg(self) -> None:
-        """verify the correctness of the config."""
+        """Verify the correctness of the config."""
         if not isinstance(self.paramwise_cfg, dict):
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
@@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
                 raise ValueError('base_wd should not be None')

     def _is_in(self, param_group: dict, param_group_list: list) -> bool:
-        """check whether the `param_group` is in the`param_group_list`"""
+        """Check whether the `param_group` is in the`param_group_list`"""
         assert is_list_of(param_group_list, dict)
         param = set(param_group['params'])
         param_set = set()

View File

@@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
             self.optim_wrappers[name].load_state_dict(_state_dict)

     def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
-        """A generator to get the name and corresponding
-        :obj:`OptimWrapper`"""
+        """A generator to get the name and corresponding :obj:`OptimWrapper`"""
         yield from self.optim_wrappers.items()

     def values(self) -> Iterator[OptimWrapper]:

View File

@@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):

 @PARAM_SCHEDULERS.register_module()
 class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
-    r"""Sets the learning rate of each parameter group according to the
-    1cycle learning rate policy. The 1cycle policy anneals the learning
-    rate from an initial learning rate to some maximum learning rate and then
-    from that maximum learning rate to some minimum learning rate much lower
-    than the initial learning rate.
-    This policy was initially described in the paper `Super-Convergence:
-    Very Fast Training of Neural Networks Using Large Learning Rates`_.
+    r"""Sets the learning rate of each parameter group according to the 1cycle
+    learning rate policy. The 1cycle policy anneals the learning rate from an
+    initial learning rate to some maximum learning rate and then from that
+    maximum learning rate to some minimum learning rate much lower than the
+    initial learning rate. This policy was initially described in the paper
+    `Super-Convergence: Very Fast Training of Neural Networks Using Large
+    Learning Rates`_.

     The 1cycle learning rate policy changes the learning rate after every
     batch. `step` should be called after a batch has been used for training.

View File

@@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):

 @PARAM_SCHEDULERS.register_module()
 class CosineAnnealingParamScheduler(_ParamScheduler):
-    r"""Set the parameter value of each parameter group using a cosine
-    annealing schedule, where :math:`\eta_{max}` is set to the initial value
-    and :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+    r"""Set the parameter value of each parameter group using a cosine annealing
+    schedule, where :math:`\eta_{max}` is set to the initial value and
+    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

     .. math::
         \begin{aligned}
@@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):

     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
-    """
+    """  # noqa: E501

     def __init__(self,
                  optimizer: Union[Optimizer, BaseOptimWrapper],
@@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):

 @PARAM_SCHEDULERS.register_module()
 class OneCycleParamScheduler(_ParamScheduler):
-    r"""Sets the parameters of each parameter group according to the
-    1cycle learning rate policy. The 1cycle policy anneals the learning
-    rate from an initial learning rate to some maximum learning rate and then
-    from that maximum learning rate to some minimum learning rate much lower
-    than the initial learning rate.
-    This policy was initially described in the paper `Super-Convergence:
-    Very Fast Training of Neural Networks Using Large Learning Rates`_.
+    r"""Sets the parameters of each parameter group according to the 1cycle
+    learning rate policy. The 1cycle policy anneals the learning rate from an
+    initial learning rate to some maximum learning rate and then from that
+    maximum learning rate to some minimum learning rate much lower than the
+    initial learning rate. This policy was initially described in the paper
+    `Super-Convergence: Very Fast Training of Neural Networks Using Large
+    Learning Rates`_.

     The 1cycle learning rate policy changes the learning rate after every
     batch. `step` should be called after a batch has been used for training.

View File

@@ -4,7 +4,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Optional, Union

 from mmengine.config import Config, ConfigDict
-from mmengine.utils import ManagerMixin
+from mmengine.utils import ManagerMixin, digit_version
 from .registry import Registry

 if TYPE_CHECKING:
@@ -232,6 +232,21 @@ def build_model_from_cfg(
     return build_from_cfg(cfg, registry, default_args)

+def build_optimizer_from_cfg(
+        cfg: Union[dict, ConfigDict, Config],
+        registry: Registry,
+        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
+    import torch
+    from ..logging import print_log
+    if 'type' in cfg \
+            and 'Adafactor' == cfg['type'] \
+            and digit_version(torch.__version__) >= digit_version('2.5.0'):
+        print_log(
+            'the torch version of Adafactor is registered as TorchAdafactor')
+    return build_from_cfg(cfg, registry, default_args)
+
+
 def build_scheduler_from_cfg(
     cfg: Union[dict, ConfigDict, Config],
     registry: Registry,

View File

@@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
     @classmethod
     @contextmanager
     def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
-        """overwrite the current default scope with `scope_name`"""
+        """Overwrite the current default scope with `scope_name`"""
         if scope_name is None:
             yield
         else:
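`overwrite_default_scope` is a context manager that swaps the scope only inside the `with` block; a minimal sketch (instance and scope names are illustrative):

from mmengine.registry import DefaultScope

DefaultScope.get_instance('demo', scope_name='mmengine')
with DefaultScope.overwrite_default_scope('mmdet'):
    print(DefaultScope.get_current_instance().scope_name)  # 'mmdet'
print(DefaultScope.get_current_instance().scope_name)      # 'mmengine' again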

View File

@@ -332,7 +332,7 @@ class Registry:
         return root

     def import_from_location(self) -> None:
-        """import modules from the pre-defined locations in self._location."""
+        """Import modules from the pre-defined locations in self._location."""
         if not self._imported:
             # Avoid circular import
             from ..logging import print_log

View File

@@ -6,8 +6,8 @@ More datails can be found at
 https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
 """
-from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
-                              build_scheduler_from_cfg)
+from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
+                              build_runner_from_cfg, build_scheduler_from_cfg)
 from .registry import Registry

 # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@@ -35,7 +35,7 @@ MODEL_WRAPPERS = Registry('model_wrapper')
 WEIGHT_INITIALIZERS = Registry('weight initializer')

 # mangage all kinds of optimizers like `SGD` and `Adam`
-OPTIMIZERS = Registry('optimizer')
+OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)

 # manage optimizer wrapper
 OPTIM_WRAPPERS = Registry('optim_wrapper')
 # manage constructors that customize the optimization hyperparameters.

View File

@@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
     if current_scope.scope_name != scope:  # type: ignore
         print_log(
             'The current default scope '  # type: ignore
-            f'"{current_scope.scope_name}" is not "{scope}", '
+            f'"{current_scope.scope_name}" is not "{scope}", '  # type: ignore
             '`init_default_scope` will force set the current'
             f'default scope to "{scope}".',
             logger='current',

View File

@@ -540,7 +540,7 @@ class FlexibleRunner:

     @property
     def hooks(self):
-        """list[:obj:`Hook`]: A list of registered hooks."""
+        """List[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks

     @property
@@ -1117,7 +1117,7 @@ class FlexibleRunner:
         return '\n'.join(stage_hook_infos)

     def load_or_resume(self):
-        """load or resume checkpoint."""
+        """Load or resume checkpoint."""
         if self._has_loaded:
             return None
@@ -1539,7 +1539,7 @@ class FlexibleRunner:
             file_client_args: Optional[dict] = None,
             save_optimizer: bool = True,
             save_param_scheduler: bool = True,
-            meta: dict = None,
+            meta: Optional[dict] = None,
             by_epoch: bool = True,
             backend_args: Optional[dict] = None,
     ):

View File

@@ -309,7 +309,7 @@ class CheckpointLoader:

     @classmethod
     def load_checkpoint(cls, filename, map_location=None, logger='current'):
-        """load checkpoint through URL scheme path.
+        """Load checkpoint through URL scheme path.

         Args:
             filename (str): checkpoint file name with given prefix
@@ -332,7 +332,7 @@ class CheckpointLoader:

 @CheckpointLoader.register_scheme(prefixes='')
 def load_from_local(filename, map_location):
-    """load checkpoint by local file path.
+    """Load checkpoint by local file path.

     Args:
         filename (str): local checkpoint file path
@@ -353,7 +353,7 @@ def load_from_http(filename,
                    map_location=None,
                    model_dir=None,
                    progress=os.isatty(0)):
-    """load checkpoint through HTTP or HTTPS scheme path. In distributed
+    """Load checkpoint through HTTP or HTTPS scheme path. In distributed
     setting, this function only download checkpoint at local rank 0.

     Args:
@@ -386,7 +386,7 @@ def load_from_http(filename,

 @CheckpointLoader.register_scheme(prefixes='pavi://')
 def load_from_pavi(filename, map_location=None):
-    """load checkpoint through the file path prefixed with pavi. In distributed
+    """Load checkpoint through the file path prefixed with pavi. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.

@@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):

 @CheckpointLoader.register_scheme(
     prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
 def load_from_ceph(filename, map_location=None, backend='petrel'):
-    """load checkpoint through the file path prefixed with s3. In distributed
+    """Load checkpoint through the file path prefixed with s3. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.

@@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):

 @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
 def load_from_torchvision(filename, map_location=None):
-    """load checkpoint through the file path prefixed with modelzoo or
+    """Load checkpoint through the file path prefixed with modelzoo or
     torchvision.

     Args:
@@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):

 @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
 def load_from_openmmlab(filename, map_location=None):
-    """load checkpoint through the file path prefixed with open-mmlab or
+    """Load checkpoint through the file path prefixed with open-mmlab or
     openmmlab.

     Args:
@@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):

 @CheckpointLoader.register_scheme(prefixes='mmcls://')
 def load_from_mmcls(filename, map_location=None):
-    """load checkpoint through the file path prefixed with mmcls.
+    """Load checkpoint through the file path prefixed with mmcls.

     Args:
         filename (str): checkpoint file path with mmcls prefix

View File

@@ -8,8 +8,10 @@ import torch
 from torch.utils.data import DataLoader

 from mmengine.evaluator import Evaluator
-from mmengine.logging import print_log
+from mmengine.logging import HistoryBuffer, print_log
 from mmengine.registry import LOOPS
+from mmengine.structures import BaseDataElement
+from mmengine.utils import is_list_of
 from .amp import autocast
 from .base_loop import BaseLoop
 from .utils import calc_dynamic_intervals
@@ -363,17 +365,26 @@ class ValLoop(BaseLoop):
                 logger='current',
                 level=logging.WARNING)
         self.fp16 = fp16
+        self.val_loss: Dict[str, HistoryBuffer] = dict()

     def run(self) -> dict:
         """Launch validation."""
         self.runner.call_hook('before_val')
         self.runner.call_hook('before_val_epoch')
         self.runner.model.eval()
+
+        # clear val loss
+        self.val_loss.clear()
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)

         # compute metrics
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
+
+        if self.val_loss:
+            loss_dict = _parse_losses(self.val_loss, 'val')
+            metrics.update(loss_dict)
+
         self.runner.call_hook('after_val_epoch', metrics=metrics)
         self.runner.call_hook('after_val')
         return metrics
@@ -391,6 +402,9 @@ class ValLoop(BaseLoop):
         # outputs should be sequence of BaseDataElement
         with autocast(enabled=self.fp16):
             outputs = self.runner.model.val_step(data_batch)
+
+        outputs, self.val_loss = _update_losses(outputs, self.val_loss)
+
         self.evaluator.process(data_samples=outputs, data_batch=data_batch)
         self.runner.call_hook(
             'after_val_iter',
@@ -435,17 +449,26 @@ class TestLoop(BaseLoop):
                 logger='current',
                 level=logging.WARNING)
         self.fp16 = fp16
+        self.test_loss: Dict[str, HistoryBuffer] = dict()

     def run(self) -> dict:
         """Launch test."""
         self.runner.call_hook('before_test')
         self.runner.call_hook('before_test_epoch')
         self.runner.model.eval()
+
+        # clear test loss
+        self.test_loss.clear()
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)

         # compute metrics
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
+
+        if self.test_loss:
+            loss_dict = _parse_losses(self.test_loss, 'test')
+            metrics.update(loss_dict)
+
         self.runner.call_hook('after_test_epoch', metrics=metrics)
         self.runner.call_hook('after_test')
         return metrics
@@ -462,9 +485,66 @@ class TestLoop(BaseLoop):
         # predictions should be sequence of BaseDataElement
         with autocast(enabled=self.fp16):
             outputs = self.runner.model.test_step(data_batch)
+
+        outputs, self.test_loss = _update_losses(outputs, self.test_loss)
+
         self.evaluator.process(data_samples=outputs, data_batch=data_batch)
         self.runner.call_hook(
             'after_test_iter',
             batch_idx=idx,
             data_batch=data_batch,
             outputs=outputs)
+
+
+def _parse_losses(losses: Dict[str, HistoryBuffer],
+                  stage: str) -> Dict[str, float]:
+    """Parses the raw losses of the network.
+
+    Args:
+        losses (dict): raw losses of the network.
+        stage (str): The stage of loss, e.g., 'val' or 'test'.
+
+    Returns:
+        dict[str, float]: The key is the loss name, and the value is the
+        average loss.
+    """
+    all_loss = 0
+    loss_dict: Dict[str, float] = dict()
+
+    for loss_name, loss_value in losses.items():
+        avg_loss = loss_value.mean()
+        loss_dict[loss_name] = avg_loss
+        if 'loss' in loss_name:
+            all_loss += avg_loss
+
+    loss_dict[f'{stage}_loss'] = all_loss
+    return loss_dict
+
+
+def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
+    """Update and record the losses of the network.
+
+    Args:
+        outputs (list): The outputs of the network.
+        losses (dict): The losses of the network.
+
+    Returns:
+        list: The updated outputs of the network.
+        dict: The updated losses of the network.
+    """
+    if isinstance(outputs[-1],
+                  BaseDataElement) and outputs[-1].keys() == ['loss']:
+        loss = outputs[-1].loss  # type: ignore
+        outputs = outputs[:-1]
+    else:
+        loss = dict()
+
+    for loss_name, loss_value in loss.items():
+        if loss_name not in losses:
+            losses[loss_name] = HistoryBuffer()
+        if isinstance(loss_value, torch.Tensor):
+            losses[loss_name].update(loss_value.item())
+        elif is_list_of(loss_value, torch.Tensor):
+            for loss_value_i in loss_value:
+                losses[loss_name].update(loss_value_i.item())
+    return outputs, losses
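Taken together: when `val_step`/`test_step` appends a `BaseDataElement` whose only key is `loss`, each entry is buffered per iteration, averaged per epoch, and a summed `<stage>_loss` is merged into the metrics. An illustrative sketch of what `_parse_losses` yields (numbers made up):

from mmengine.logging import HistoryBuffer

val_loss = {'loss_cls': HistoryBuffer(), 'loss_bbox': HistoryBuffer()}
for v in (0.9, 0.7):
    val_loss['loss_cls'].update(v)
for v in (0.4, 0.2):
    val_loss['loss_bbox'].update(v)
# _parse_losses(val_loss, 'val') would return approximately:
# {'loss_cls': 0.8, 'loss_bbox': 0.3, 'val_loss': 1.1}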

View File

@@ -579,7 +579,7 @@ class Runner:

     @property
     def hooks(self):
-        """list[:obj:`Hook`]: A list of registered hooks."""
+        """List[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks

     @property
@@ -720,7 +720,7 @@ class Runner:

     def build_logger(self,
                      log_level: Union[int, str] = 'INFO',
-                     log_file: str = None,
+                     log_file: Optional[str] = None,
                      **kwargs) -> MMLogger:
         """Build a global asscessable MMLogger.
@@ -1677,7 +1677,7 @@ class Runner:
         return '\n'.join(stage_hook_infos)

     def load_or_resume(self) -> None:
-        """load or resume checkpoint."""
+        """Load or resume checkpoint."""
         if self._has_loaded:
             return None

View File

@@ -387,7 +387,7 @@ class BaseDataElement:
         return dict(self.metainfo_items())

     def __setattr__(self, name: str, value: Any):
-        """setattr is only used to set data."""
+        """Setattr is only used to set data."""
         if name in ('_metainfo_fields', '_data_fields'):
             if not hasattr(self, name):
                 super().__setattr__(name, value)

View File

@@ -135,7 +135,7 @@ class InstanceData(BaseDataElement):
     """

     def __setattr__(self, name: str, value: Sized):
-        """setattr is only used to set data.
+        """Setattr is only used to set data.

         The value must have the attribute of `__len__` and have the same length
         of `InstanceData`.

View File

@@ -92,10 +92,25 @@ class MultiProcessTestCase(TestCase):
     # Constructor patches current instance test method to
     # assume the role of the main process and join its subprocesses,
     # or run the underlying test function.
-    def __init__(self, method_name: str = 'runTest') -> None:
+    def __init__(self,
+                 method_name: str = 'runTest',
+                 methodName: str = 'runTest') -> None:
+        # methodName is the correct naming in unittest
+        # and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != 'runTest':
+            method_name = methodName
         super().__init__(method_name)
-        fn = getattr(self, method_name)
-        setattr(self, method_name, self.join_or_run(fn))
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != 'runTest':
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f'no such test method in {self.__class__}: {methodName}'
+                ) from e

     def setUp(self) -> None:
         super().setUp()

View File

@@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
                          check_hash=False,
                          file_name=None):
         r"""Loads the Torch serialized object at the given URL.
+
         If downloaded file is a zip file, it will be automatically decompressed

         If the object is already present in `model_dir`, it's deserialized and
         returned.

View File

@@ -67,7 +67,7 @@ class TimeCounter:
         instance.log_interval = log_interval
         instance.warmup_interval = warmup_interval
-        instance.with_sync = with_sync
+        instance.with_sync = with_sync  # type: ignore
         instance.tag = tag
         instance.logger = logger
@@ -127,7 +127,7 @@ class TimeCounter:
                 self.print_time(elapsed)

     def print_time(self, elapsed: Union[int, float]) -> None:
-        """print times per count."""
+        """Print times per count."""
         if self.__count >= self.warmup_interval:
             self.__pure_inf_time += elapsed

View File

@@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):

 def is_seq_of(seq: Any,
               expected_type: Union[Type, tuple],
-              seq_type: Type = None) -> bool:
+              seq_type: Optional[Type] = None) -> bool:
     """Check whether it is a sequence of some type.

     Args:
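A quick sketch of `is_seq_of`, including the `seq_type` restriction the new `Optional` annotation covers:

from mmengine.utils import is_seq_of

print(is_seq_of([1, 2, 3], int))              # True
print(is_seq_of([1, 'a'], int))               # False: mixed element types
print(is_seq_of((1, 2), int, seq_type=list))  # False: tuple, not a list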

View File

@@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
         else:
             raise e

-    possible_path = osp.join(pkg.location, package)
+    possible_path = osp.join(pkg.location, package)  # type: ignore
     if osp.exists(possible_path):
         return possible_path
     else:
-        return osp.join(pkg.location, package2module(package))
+        return osp.join(pkg.location, package2module(package))  # type: ignore

 def package2module(package: str):

View File

@@ -3,7 +3,7 @@ import sys
 from collections.abc import Iterable
 from multiprocessing import Pool
 from shutil import get_terminal_size
-from typing import Callable, Sequence
+from typing import Callable, Optional, Sequence

 from .timer import Timer
@@ -54,7 +54,7 @@ class ProgressBar:
         self.timer = Timer()

     def update(self, num_tasks: int = 1):
-        """update progressbar.
+        """Update progressbar.

         Args:
             num_tasks (int): Update step size.
@@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
 def track_parallel_progress(func: Callable,
                             tasks: Sequence,
                             nproc: int,
-                            initializer: Callable = None,
-                            initargs: tuple = None,
+                            initializer: Optional[Callable] = None,
+                            initargs: Optional[tuple] = None,
                             bar_width: int = 50,
                             chunksize: int = 1,
                             skip_first: bool = False,
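For orientation, a minimal sketch of calling `track_parallel_progress` (the worker function is illustrative and must be importable for multiprocessing):

from mmengine.utils import track_parallel_progress

def square(x):
    return x * x

# Maps `square` over the tasks in a 4-process pool with a progress bar.
results = track_parallel_progress(square, tasks=list(range(100)), nproc=4)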

View File

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from multiprocessing import Pool
-from typing import Callable, Iterable, Sized
+from typing import Callable, Iterable, Optional, Sized

 from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
                            TaskProgressColumn, TextColumn, TimeRemainingColumn)
@@ -47,7 +47,7 @@ def _tasks_with_index(tasks):

 def track_progress_rich(func: Callable,
                         tasks: Iterable = tuple(),
-                        task_num: int = None,
+                        task_num: Optional[int] = None,
                         nproc: int = 1,
                         chunksize: int = 1,
                         description: str = 'Processing',

View File

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-__version__ = '0.10.4'
+__version__ = '0.10.7'

 def parse_version_info(version_str):

View File

@@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
         pass

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         pass
@@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):

     def _dump(self, value_dict: dict, file_path: str,
               file_format: str) -> None:
-        """dump dict to file.
+        """Dump dict to file.

         Args:
             value_dict (dict) : The dict data to saved.
@@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
         self._wandb.log(scalar_dict, commit=self._commit)

     def close(self) -> None:
-        """close an opened wandb object."""
+        """Close an opened wandb object."""
         if hasattr(self, '_wandb'):
             self._wandb.join()
@@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
             self.add_scalar(key, value, step)

     def close(self):
-        """close an opened tensorboard object."""
+        """Close an opened tensorboard object."""
         if hasattr(self, '_tensorboard'):
             self._tensorboard.close()
@@ -1135,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
             self._neptune[k].append(v, step=step)

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         if hasattr(self, '_neptune'):
             self._neptune.stop()
@@ -1282,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
             self.add_scalar(key, value, step, **kwargs)

     def close(self) -> None:
-        """close an opened dvclive object."""
+        """Close an opened dvclive object."""
         if not hasattr(self, '_dvclive'):
             return

View File

@@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):

     @master_only
     def get_backend(self, name) -> 'BaseVisBackend':
-        """get vis backend by name.
+        """Get vis backend by name.

         Args:
             name (str): The name of vis backend
@@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
         pass

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         for vis_backend in self._vis_backends.values():
             vis_backend.close()

View File

@@ -843,8 +843,8 @@ class TestConfig:
         assert cfg_dict['item4'] == 'test'
         assert '_delete_' not in cfg_dict['item1']
-        assert type(cfg_dict['item1']) == ConfigDict
-        assert type(cfg_dict['item2']) == ConfigDict
+        assert type(cfg_dict['item1']) is ConfigDict
+        assert type(cfg_dict['item2']) is ConfigDict

     def _merge_intermediate_variable(self):

View File

@@ -300,8 +300,8 @@ except ImportError:
                 get_inputs.append(filepath)

             with build_temporary_directory() as tmp_dir, \
-                    patch.object(backend, 'put', side_effect=put),\
-                    patch.object(backend, 'get', side_effect=get),\
+                    patch.object(backend, 'put', side_effect=put), \
+                    patch.object(backend, 'get', side_effect=get), \
                     patch.object(backend, 'exists', return_value=False):
                 tmp_dir = tmp_dir.replace('\\', '/')
                 dst = f'{tmp_dir}/dir'
@@ -351,7 +351,7 @@ except ImportError:
             with build_temporary_directory() as tmp_dir, \
                     patch.object(backend, 'copyfile_from_local',
-                                 side_effect=copyfile_from_local),\
+                                 side_effect=copyfile_from_local), \
                     patch.object(backend, 'exists', return_value=False):
                 backend.copytree_from_local(tmp_dir, self.petrel_dir)
@@ -427,7 +427,7 @@ except ImportError:
             def remove(filepath):
                 inputs.append(filepath)

-            with build_temporary_directory() as tmp_dir,\
+            with build_temporary_directory() as tmp_dir, \
                     patch.object(backend, 'remove', side_effect=remove):
                 backend.rmtree(tmp_dir)

View File

@@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
 class TestAveragedModel(TestCase):
     """Test the AveragedModel class.

-    Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
+    Some test cases are referenced from
+    https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
     """  # noqa: E501

     def _test_swa_model(self, net_device, avg_device):

View File

@@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
             rtol=0)

     def test_scheduler_before_optim_warning(self):
-        """warns if scheduler is used before optimizer."""
+        """Warns if scheduler is used before optimizer."""

         def call_sch_before_optim():
             scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)

View File

@@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
             rtol=0)

     def test_scheduler_before_optim_warning(self):
-        """warns if scheduler is used before optimizer."""
+        """Warns if scheduler is used before optimizer."""

         def call_sch_before_optim():
             scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)

View File

@@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
             rtol=0)

     def test_scheduler_before_optim_warning(self):
-        """warns if scheduler is used before optimizer."""
+        """Warns if scheduler is used before optimizer."""

         def call_sch_before_optim():
             scheduler = StepParamScheduler(

View File

@@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():

         def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                   *args, **kwargs):
-            """load checkpoints."""
+            """Load checkpoints."""

             # Names of some parameters in has been changed.
             version = local_metadata.get('version', None)

View File

@@ -2226,7 +2226,7 @@ class TestRunner(TestCase):

         @HOOKS.register_module(force=True)
         class TestWarmupHook(Hook):
-            """test custom train loop."""
+            """Test custom train loop."""

             def before_warmup_iter(self, runner, data_batch=None):
                 before_warmup_iter_results.append('before')

View File

@@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
         return metainfo, data

     def is_equal(self, x, y):
-        assert type(x) == type(y)
+        assert type(x) is type(y)
         if isinstance(
                 x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
             return x == y
@@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):

         # test new() with no arguments
         new_instances = instances.new()
-        assert type(new_instances) == type(instances)
+        assert type(new_instances) is type(instances)
         # After deepcopy, the address of new data'element will be same as
         # origin, but when change new data' element will not effect the origin
         # element and will have new address
@@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
         # test new() with arguments
         metainfo, data = self.setup_data()
         new_instances = instances.new(metainfo=metainfo, **data)
-        assert type(new_instances) == type(instances)
+        assert type(new_instances) is type(instances)
         assert id(new_instances.gt_instances) != id(instances.gt_instances)
         _, new_data = self.setup_data()
         new_instances.set_data(new_data)
@@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
         metainfo, data = self.setup_data()
         instances = BaseDataElement(metainfo=metainfo, **data)
         new_instances = instances.clone()
-        assert type(new_instances) == type(instances)
+        assert type(new_instances) is type(instances)

     def test_set_metainfo(self):
         metainfo, _ = self.setup_data()

View File

@@ -45,7 +45,7 @@ class MockVisBackend:
             self._add_scalars = True

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         self._close = True