Compare commits

...

17 Commits

Author SHA1 Message Date
Mashiro
390ba2fbb2
Bump version to 0.10.7 2025-03-04 20:20:27 +08:00
Epiphany
d620552c2c
[fix] fix ci (#1636) 2025-02-27 15:13:38 +08:00
Mashiro
41fa84a9a9
[Fix] remove torch dependencies in build_function.py (#1632) 2025-02-18 16:38:51 +08:00
Mashiro
698782f920
[Fix] Fix deploy ci (#1628)
* [Enhance] Support trigger ci manually

* [Fix] Fix deploy CI
2025-01-15 18:11:05 +08:00
Mashiro
e60ab1dde3
[Enhance] Support trigger ci manually (#1627) 2025-01-15 18:07:34 +08:00
Mashiro
8ec837814e
[Enhance] Support trigger ci manually (#1626) 2025-01-15 17:58:23 +08:00
Qian Zhao
a4475f5eea
Update deploy.yml (#1625) 2025-01-15 17:34:07 +08:00
Mashiro
a8c74c346d
Bump version to v0.10.6 (#1623) 2025-01-13 19:20:26 +08:00
Epiphany
9124ebf7a2
[Enhance] ensure type in cfg (#1602)
* ensure type in cfg

* change import level
2024-11-06 20:26:47 +08:00
Epiphany
2e0ab7a922
[Fix] fix Adafactor optim on torch2.5 and fix compatibility (#1600)
* fix Adafactor opptim on torch2.5 and fix compatibility

* fix runtest error
2024-11-05 21:22:46 +08:00
Epiphany
fc59364d64
[Fix] fix error when pytest>=8.2 (#1601) 2024-11-05 20:43:17 +08:00
Tibor Reiss
4183cf0829
Fix return in finally (#1596) 2024-11-04 14:39:25 +08:00
Mashiro
cc3b74b5e8
[Fix] Fix lint (#1598)
* [Fix] Fix lint

* [Fix] Fix lint

* Update mmengine/dist/utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2024-11-02 22:23:51 +08:00
Mashiro
c9b59962d6
Yehc/bump version to v0.10.5 (#1572)
* fix check builtin module

* bump version to v0.10.5

* bump version to v0.10.5
2024-09-20 15:42:06 +08:00
Mashiro
5e736b143b
fix check builtin module (#1571) 2024-09-11 18:45:24 +08:00
Chris Jiang
85c83ba616
Update is_mlu_available (#1537)
* Update is_mlu_available

to adapt torch_mlu main, the torch.is_mlu_available method is removed

* Update utils.py

* Update utils.py
2024-05-30 10:07:45 +08:00
fanqiNO1
d1f1aabf81
[Feature] Support calculating loss during validation (#1503) 2024-05-17 15:27:53 +08:00
59 changed files with 393 additions and 290 deletions

View File

@ -1,6 +1,8 @@
name: deploy name: deploy
on: push on:
- push
- workflow_dispatch
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref }}
@ -9,13 +11,14 @@ concurrency:
jobs: jobs:
build-n-publish: build-n-publish:
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags') if: |
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: Set up Python 3.7 - name: Set up Python 3.10.13
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: 3.7 python-version: 3.10.13
- name: Install wheel - name: Install wheel
run: pip install wheel run: pip install wheel
- name: Build MMEngine - name: Build MMEngine
@ -27,13 +30,14 @@ jobs:
build-n-publish-lite: build-n-publish-lite:
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags') if: |
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: Set up Python 3.7 - name: Set up Python 3.10.13
uses: actions/setup-python@v2 uses: actions/setup-python@v4
with: with:
python-version: 3.7 python-version: 3.10.13
- name: Install wheel - name: Install wheel
run: pip install wheel run: pip install wheel
- name: Build MMEngine-lite - name: Build MMEngine-lite

View File

@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python 3.7 - name: Set up Python 3.10.15
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: 3.7 python-version: '3.10.15'
- name: Install pre-commit hook - name: Install pre-commit hook
run: | run: |
pip install pre-commit pip install pre-commit

View File

@ -1,5 +1,9 @@
name: pr_stage_test name: pr_stage_test
env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
on: on:
pull_request: pull_request:
paths-ignore: paths-ignore:
@ -21,152 +25,114 @@ concurrency:
jobs: jobs:
build_cpu: build_cpu:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
defaults:
run:
shell: bash -l {0}
strategy: strategy:
matrix: matrix:
python-version: [3.7] python-version: ['3.9']
include: torch: ['2.0.0']
- torch: 1.8.1
torchvision: 0.9.1
steps: steps:
- uses: actions/checkout@v3 - name: Check out repo
- name: Set up Python ${{ matrix.python-version }} uses: actions/checkout@v3
uses: actions/setup-python@v4 - name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
with: with:
auto-update-conda: true
miniconda-version: "latest"
use-only-tar-bz2: true
activate-environment: test
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Upgrade pip - name: Update pip
run: python -m pip install pip --upgrade
- name: Upgrade wheel
run: python -m pip install wheel --upgrade
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Build MMEngine from source
run: pip install -e . -v
- name: Install unit tests dependencies
run: | run: |
pip install -r requirements/tests.txt python -m pip install --upgrade pip wheel
pip install openmim - name: Install dependencies
mim install mmcv
- name: Run unittests and generate coverage report
run: | run: |
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
coverage xml python -m pip install torch==${{matrix.torch}}
coverage report -m python -m pip install -e . -v
# Upload coverage report for python3.7 && pytorch1.8.1 cpu python -m pip install -r requirements/tests.txt
- name: Upload coverage to Codecov python -m pip install openmim
mim install mmcv coverage
- name: Run unit tests with coverage
run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3 uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false
build_cu102: build_gpu:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
container: defaults:
image: pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel run:
shell: bash -l {0}
env: env:
MKL_THREADING_LAYER: GNU MKL_THREADING_LAYER: GNU
strategy: strategy:
matrix: matrix:
python-version: [3.7] python-version: ['3.9','3.10']
torch: ['2.0.0','2.3.1','2.5.1']
cuda: ['cu118']
steps: steps:
- uses: actions/checkout@v3 - name: Check out repo
- name: Set up Python ${{ matrix.python-version }} uses: actions/checkout@v3
uses: actions/setup-python@v4 - name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
with: with:
auto-update-conda: true
miniconda-version: "latest"
use-only-tar-bz2: true
activate-environment: test
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Upgrade pip - name: Update pip
run: pip install pip --upgrade
- name: Fetch GPG keys
run: | run: |
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub python -m pip install --upgrade pip wheel
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub - name: Install dependencies
- name: Install system dependencies
run: apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
- name: Build MMEngine from source
run: pip install -e . -v
- name: Install unit tests dependencies
run: | run: |
pip install -r requirements/tests.txt apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
pip install openmim python -m pip install torch==${{matrix.torch}} --index-url https://download.pytorch.org/whl/${{matrix.cuda}}
mim install mmcv python -m pip install -e . -v
- name: Run unittests and generate coverage report python -m pip install -r requirements/tests.txt
run: | python -m pip install openmim
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist mim install mmcv coverage
coverage xml - name: Run unit tests with coverage
coverage report -m run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3
build_cu117: # build_windows:
runs-on: ubuntu-22.04 # runs-on: windows-2022
container: # strategy:
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel # matrix:
strategy: # python-version: [3.9]
matrix: # platform: [cpu, cu111]
python-version: [3.9] # torch: [1.8.1]
steps: # torchvision: [0.9.1]
- uses: actions/checkout@v3 # include:
- name: Set up Python ${{ matrix.python-version }} # - python-version: 3.8
uses: actions/setup-python@v4 # platform: cu118
with: # torch: 2.1.0
python-version: ${{ matrix.python-version }} # torchvision: 0.16.0
- name: Upgrade pip # steps:
run: pip install pip --upgrade # - uses: actions/checkout@v3
- name: Fetch GPG keys # - name: Set up Python ${{ matrix.python-version }}
run: | # uses: actions/setup-python@v4
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub # with:
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub # python-version: ${{ matrix.python-version }}
- name: Install system dependencies # - name: Upgrade pip
run: apt-get update && apt-get install -y git ffmpeg libturbojpeg # # Windows CI could fail If we call `pip install pip --upgrade` directly.
- name: Build MMEngine from source # run: python -m pip install pip wheel --upgrade
run: pip install -e . -v # - name: Install PyTorch
- name: Install unit tests dependencies # run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
run: | # - name: Build MMEngine from source
pip install -r requirements/tests.txt # run: pip install -e . -v
pip install openmim # - name: Install unit tests dependencies
mim install mmcv # run: |
# Distributed related unit test may randomly error in PyTorch 1.13.0 # pip install -r requirements/tests.txt
- name: Run unittests and generate coverage report # pip install openmim
run: | # mim install mmcv
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist/ # - name: Run CPU unittests
coverage xml # run: pytest tests/ --ignore tests/test_dist
coverage report -m # if: ${{ matrix.platform == 'cpu' }}
# - name: Run GPU unittests
build_windows: # # Skip testing distributed related unit tests since the memory of windows CI is limited
runs-on: windows-2022 # run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
strategy: # if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
matrix:
python-version: [3.7]
platform: [cpu, cu111]
torch: [1.8.1]
torchvision: [0.9.1]
include:
- python-version: 3.8
platform: cu118
torch: 2.1.0
torchvision: 0.16.0
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
# Windows CI could fail If we call `pip install pip --upgrade` directly.
run: python -m pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
- name: Build MMEngine from source
run: pip install -e . -v
- name: Install unit tests dependencies
run: |
pip install -r requirements/tests.txt
pip install openmim
mim install mmcv
- name: Run CPU unittests
run: pytest tests/ --ignore tests/test_dist
if: ${{ matrix.platform == 'cpu' }}
- name: Run GPU unittests
# Skip testing distributed related unit tests since the memory of windows CI is limited
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/ exclude: ^tests/data/
repos: repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8 - repo: https://github.com/pre-commit/pre-commit
rev: 5.0.4 rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://gitee.com/openmmlab/mirrors-isort - repo: https://gitee.com/openmmlab/mirrors-isort
@ -13,7 +17,7 @@ repos:
hooks: hooks:
- id: yapf - id: yapf
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
rev: v4.3.0 rev: v5.0.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: check-yaml - id: check-yaml
@ -55,7 +59,7 @@ repos:
args: ["mmengine", "tests"] args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs - id: remove-improper-eol-in-cn-docs
- repo: https://gitee.com/openmmlab/mirrors-mypy - repo: https://gitee.com/openmmlab/mirrors-mypy
rev: v0.812 rev: v1.2.0
hooks: hooks:
- id: mypy - id: mypy
exclude: |- exclude: |-
@ -63,3 +67,4 @@ repos:
^examples ^examples
| ^docs | ^docs
) )
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/ exclude: ^tests/data/
repos: repos:
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 5.0.4 rev: 7.1.1
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://github.com/PyCQA/isort - repo: https://github.com/PyCQA/isort
@ -13,7 +17,7 @@ repos:
hooks: hooks:
- id: yapf - id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 rev: v5.0.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: check-yaml - id: check-yaml
@ -34,12 +38,8 @@ repos:
- mdformat-openmmlab - mdformat-openmmlab
- mdformat_frontmatter - mdformat_frontmatter
- linkify-it-py - linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter - repo: https://github.com/myint/docformatter
rev: v1.3.1 rev: 06907d0
hooks: hooks:
- id: docformatter - id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"] args: ["--in-place", "--wrap-descriptions", "79"]
@ -55,7 +55,7 @@ repos:
args: ["mmengine", "tests"] args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs - id: remove-improper-eol-in-cn-docs
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812 rev: v1.2.0
hooks: hooks:
- id: mypy - id: mypy
exclude: |- exclude: |-
@ -63,3 +63,4 @@ repos:
^examples ^examples
| ^docs | ^docs
) )
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@ -59,7 +59,7 @@ English | [简体中文](README_zh-CN.md)
## What's New ## What's New
v0.10.4 was released on 2024-4-23. v0.10.6 was released on 2025-01-13.
Highlights: Highlights:

View File

@ -59,7 +59,7 @@
## 最近进展 ## 最近进展
最新版本 v0.10.4 在 2024.4.23 发布。 最新版本 v0.10.5 在 2024.9.11 发布。
版本亮点: 版本亮点:

View File

@ -1,5 +1,9 @@
# Changelog of v0.x # Changelog of v0.x
## v0.10.5 (11/9/2024)
- Fix `_is_builtin_module`. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1571
## v0.10.4 (23/4/2024) ## v0.10.4 (23/4/2024)
### New Features & Enhancements ### New Features & Enhancements

View File

@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
'"type" and "constructor" are not in ' '"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}') f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers) return OptimWrapperDict(**optim_wrappers) # type: ignore
else: else:
raise TypeError('optimizer wrapper should be an OptimWrapper ' raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}') f'object or dict, but got {optim_wrapper}')

View File

@ -361,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
map_location: Union[str, Callable] = 'default', map_location: Union[str, Callable] = 'default',
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
) -> dict: ) -> dict:
"""override this method since colossalai resume optimizer from filename """Override this method since colossalai resume optimizer from filename
directly.""" directly."""
self.logger.info(f'Resume checkpoint from {filename}') self.logger.info(f'Resume checkpoint from {filename}')

View File

@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
init_dist(launcher, backend, **kwargs) init_dist(launcher, backend, **kwargs)
def convert_model(self, model: nn.Module) -> nn.Module: def convert_model(self, model: nn.Module) -> nn.Module:
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm`` """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers. (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.
Args: Args:

View File

@ -393,7 +393,7 @@ class Config:
def __init__( def __init__(
self, self,
cfg_dict: dict = None, cfg_dict: Optional[dict] = None,
cfg_text: Optional[str] = None, cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None, filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None, env_variables: Optional[dict] = None,
@ -1227,7 +1227,8 @@ class Config:
if base_code is not None: if base_code is not None:
base_code = ast.Expression( # type: ignore base_code = ast.Expression( # type: ignore
body=base_code.value) # type: ignore body=base_code.value) # type: ignore
base_files = eval(compile(base_code, '', mode='eval')) base_files = eval(compile(base_code, '',
mode='eval')) # type: ignore
else: else:
base_files = [] base_files = []
elif file_format in ('.yml', '.yaml', '.json'): elif file_format in ('.yml', '.yaml', '.json'):
@ -1288,7 +1289,7 @@ class Config:
def _merge_a_into_b(a: dict, def _merge_a_into_b(a: dict,
b: dict, b: dict,
allow_list_keys: bool = False) -> dict: allow_list_keys: bool = False) -> dict:
"""merge dict ``a`` into dict ``b`` (non-inplace). """Merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications. in-place modifications.
@ -1358,22 +1359,22 @@ class Config:
@property @property
def filename(self) -> str: def filename(self) -> str:
"""get file name of config.""" """Get file name of config."""
return self._filename return self._filename
@property @property
def text(self) -> str: def text(self) -> str:
"""get config text.""" """Get config text."""
return self._text return self._text
@property @property
def env_variables(self) -> dict: def env_variables(self) -> dict:
"""get used environment variables.""" """Get used environment variables."""
return self._env_variables return self._env_variables
@property @property
def pretty_text(self) -> str: def pretty_text(self) -> str:
"""get formatted python config text.""" """Get formatted python config text."""
indent = 4 indent = 4
@ -1727,17 +1728,17 @@ class Config:
class DictAction(Action): class DictAction(Action):
""" """Argparse action to split an argument into KEY=VALUE form on the first =
argparse action to split an argument into KEY=VALUE form and append to a dictionary.
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
""" """
@staticmethod @staticmethod
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]: def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
"""parse int/float/bool value in the string.""" """Parse int/float/bool value in the string."""
try: try:
return int(val) return int(val)
except ValueError: except ValueError:
@ -1822,7 +1823,7 @@ class DictAction(Action):
parser: ArgumentParser, parser: ArgumentParser,
namespace: Namespace, namespace: Namespace,
values: Union[str, Sequence[Any], None], values: Union[str, Sequence[Any], None],
option_string: str = None): option_string: str = None): # type: ignore
"""Parse Variables in string and add them into argparser. """Parse Variables in string and add them into argparser.
Args: Args:

View File

@ -12,6 +12,7 @@ from mmengine.fileio import load
from mmengine.utils import check_file_exist from mmengine.utils import check_file_exist
PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable)) PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable))
SYSTEM_PYTHON_PREFIX = '/usr/lib/python'
MODULE2PACKAGE = { MODULE2PACKAGE = {
'mmcls': 'mmcls', 'mmcls': 'mmcls',
@ -176,7 +177,8 @@ def _is_builtin_module(module_name: str) -> bool:
return False return False
origin_path = osp.abspath(origin_path) origin_path = osp.abspath(origin_path)
if ('site-package' in origin_path or 'dist-package' in origin_path if ('site-package' in origin_path or 'dist-package' in origin_path
or not origin_path.startswith(PYTHON_ROOT_DIR)): or not origin_path.startswith(
(PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))):
return False return False
else: else:
return True return True

View File

@ -16,6 +16,12 @@ try:
except Exception: except Exception:
IS_NPU_AVAILABLE = False IS_NPU_AVAILABLE = False
try:
import torch_mlu # noqa: F401
IS_MLU_AVAILABLE = hasattr(torch, 'mlu') and torch.mlu.is_available()
except Exception:
IS_MLU_AVAILABLE = False
try: try:
import torch_dipu # noqa: F401 import torch_dipu # noqa: F401
IS_DIPU_AVAILABLE = True IS_DIPU_AVAILABLE = True
@ -64,7 +70,7 @@ def is_npu_available() -> bool:
def is_mlu_available() -> bool: def is_mlu_available() -> bool:
"""Returns True if Cambricon PyTorch and mlu devices exist.""" """Returns True if Cambricon PyTorch and mlu devices exist."""
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() return IS_MLU_AVAILABLE
def is_mps_available() -> bool: def is_mps_available() -> bool:

View File

@ -563,7 +563,7 @@ def cast_data_device(
Tensor or list or dict: ``data`` was casted to ``device``. Tensor or list or dict: ``data`` was casted to ``device``.
""" """
if out is not None: if out is not None:
if type(data) != type(out): if type(data) is not type(out):
raise TypeError( raise TypeError(
'out should be the same type with data, but got data is ' 'out should be the same type with data, but got data is '
f'{type(data)} and out is {type(data)}') f'{type(data)} and out is {type(data)}')

View File

@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
self.out_file_path = out_file_path self.out_file_path = out_file_path
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None: def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU.""" """Transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions)) self.results.extend(_to_cpu(predictions))
def compute_metrics(self, results: list) -> dict: def compute_metrics(self, results: list) -> dict:
"""dump the prediction results to a pickle file.""" """Dump the prediction results to a pickle file."""
dump(results, self.out_file_path) dump(results, self.out_file_path)
print_log( print_log(
f'Results has been saved to {self.out_file_path}.', f'Results has been saved to {self.out_file_path}.',
@ -188,7 +188,7 @@ class DumpResults(BaseMetric):
def _to_cpu(data: Any) -> Any: def _to_cpu(data: Any) -> Any:
"""transfer all tensors and BaseDataElement to cpu.""" """Transfer all tensors and BaseDataElement to cpu."""
if isinstance(data, (Tensor, BaseDataElement)): if isinstance(data, (Tensor, BaseDataElement)):
return data.to('cpu') return data.to('cpu')
elif isinstance(data, list): elif isinstance(data, list):

View File

@ -233,7 +233,7 @@ class ProfilerHook(Hook):
self._export_chrome_trace(runner) self._export_chrome_trace(runner)
def after_train_iter(self, runner, batch_idx, data_batch, outputs): def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""profiler will call `step` method if it is not closed.""" """Profiler will call `step` method if it is not closed."""
if not self._closed: if not self._closed:
self.profiler.step() self.profiler.step()
if runner.iter == self.profile_times - 1 and not self.by_epoch: if runner.iter == self.profile_times - 1 and not self.by_epoch:

View File

@ -58,7 +58,7 @@ class HistoryBuffer:
self._statistics_methods.setdefault('mean', HistoryBuffer.mean) self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
def update(self, log_val: Union[int, float], count: int = 1) -> None: def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history. """Update the log history.
If the length of the buffer exceeds ``self._max_length``, the oldest If the length of the buffer exceeds ``self._max_length``, the oldest
element will be removed from the buffer. element will be removed from the buffer.

View File

@ -443,7 +443,6 @@ def _get_host_info() -> str:
host = f'{getuser()}@{gethostname()}' host = f'{getuser()}@{gethostname()}'
except Exception as e: except Exception as e:
warnings.warn(f'Host or user not found: {str(e)}') warnings.warn(f'Host or user not found: {str(e)}')
finally:
return host return host

View File

@ -253,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
dict or list: Data in the same format as the model input. dict or list: Data in the same format as the model input.
""" """
data = self.cast_data(data) # type: ignore data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs'] _batch_inputs = data['inputs'] # type: ignore
# Process data with `pseudo_collate`. # Process data with `pseudo_collate`.
if is_seq_of(_batch_inputs, torch.Tensor): if is_seq_of(_batch_inputs, torch.Tensor):
batch_inputs = [] batch_inputs = []
for _batch_input in _batch_inputs: for _batch_input in _batch_inputs:
# channel transform # channel transform
if self._channel_conversion: if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...] _batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
# Convert to float after channel conversion to ensure # Convert to float after channel conversion to ensure
# efficiency # efficiency
_batch_input = _batch_input.float() _batch_input = _batch_input.float() # type: ignore
# Normalization. # Normalization.
if self._enable_normalize: if self._enable_normalize:
if self.mean.shape[0] == 3: if self.mean.shape[0] == 3:
@ -302,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
else: else:
raise TypeError('Output of `cast_data` should be a dict of ' raise TypeError('Output of `cast_data` should be a dict of '
'list/tuple with inputs and data_samples, ' 'list/tuple with inputs and data_samples, '
f'but got {type(data)}: {data}') f'but got {type(data)}: {data}') # type: ignore
data['inputs'] = batch_inputs data['inputs'] = batch_inputs # type: ignore
data.setdefault('data_samples', None) data.setdefault('data_samples', None) # type: ignore
return data return data # type: ignore

View File

@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):
def bias_init_with_prob(prior_prob): def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value.""" """Initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob)) bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init return bias_init
@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
std: float = 1., std: float = 1.,
a: float = -2., a: float = -2.,
b: float = 2.) -> Tensor: b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated r"""Fills the input Tensor with values drawn from a truncated normal
normal distribution. The values are effectively drawn from the distribution. The values are effectively drawn from the normal distribution
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
with values outside :math:`[a, b]` redrawn until they are within :math:`[a, b]` redrawn until they are within the bounds. The method used
the bounds. The method used for generating the random values works for generating the random values works best when :math:`a \leq \text{mean}
best when :math:`a \leq \text{mean} \leq b`. \leq b`.
Modified from Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

View File

@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
auto_wrap_policy: Union[str, Callable, None] = None, auto_wrap_policy: Union[str, Callable, None] = None,
backward_prefetch: Union[str, BackwardPrefetch, None] = None, backward_prefetch: Union[str, BackwardPrefetch, None] = None,
mixed_precision: Union[dict, MixedPrecision, None] = None, mixed_precision: Union[dict, MixedPrecision, None] = None,
param_init_fn: Union[str, Callable[[nn.Module], None]] = None, param_init_fn: Union[str, Callable[
[nn.Module], None]] = None, # type: ignore # noqa: E501
use_orig_params: bool = True, use_orig_params: bool = True,
**kwargs, **kwargs,
): ):
@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None, group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""copied from pytorch 2.0.1 which has fixed some bugs.""" """Copied from pytorch 2.0.1 which has fixed some bugs."""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type( state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
model) model)
return FullyShardedDataParallel._optim_state_dict_impl( return FullyShardedDataParallel._optim_state_dict_impl(
@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
state_dict_config: Optional[StateDictConfig] = None, state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None, optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings: ) -> StateDictSettings:
"""copied from pytorch 2.0.1 which has fixed some bugs.""" """Copied from pytorch 2.0.1 which has fixed some bugs."""
import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.distributed.fsdp._traversal_utils as traversal_utils
_state_dict_type_to_config = { _state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig, StateDictType.FULL_STATE_DICT: FullStateDictConfig,

View File

@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
self._inner_count += 1 self._inner_count += 1
def state_dict(self) -> dict: def state_dict(self) -> dict:
"""Get the state dictionary of :attr:`optimizer` and """Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.
:attr:`apex_amp`.
Based on the state dictionary of the optimizer, the returned state Based on the state dictionary of the optimizer, the returned state
dictionary will add a key named "apex_amp". dictionary will add a key named "apex_amp".

View File

@ -25,6 +25,10 @@ def register_torch_optimizers() -> List[str]:
_optim = getattr(torch.optim, module_name) _optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim, if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer): torch.optim.Optimizer):
if module_name == 'Adafactor':
OPTIMIZERS.register_module(
name='TorchAdafactor', module=_optim)
else:
OPTIMIZERS.register_module(module=_optim) OPTIMIZERS.register_module(module=_optim)
torch_optimizers.append(module_name) torch_optimizers.append(module_name)
return torch_optimizers return torch_optimizers

View File

@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
self._validate_cfg() self._validate_cfg()
def _validate_cfg(self) -> None: def _validate_cfg(self) -> None:
"""verify the correctness of the config.""" """Verify the correctness of the config."""
if not isinstance(self.paramwise_cfg, dict): if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, ' raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}') f'but got {type(self.paramwise_cfg)}')
@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
raise ValueError('base_wd should not be None') raise ValueError('base_wd should not be None')
def _is_in(self, param_group: dict, param_group_list: list) -> bool: def _is_in(self, param_group: dict, param_group_list: list) -> bool:
"""check whether the `param_group` is in the`param_group_list`""" """Check whether the `param_group` is in the`param_group_list`"""
assert is_list_of(param_group_list, dict) assert is_list_of(param_group_list, dict)
param = set(param_group['params']) param = set(param_group['params'])
param_set = set() param_set = set()

View File

@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
self.optim_wrappers[name].load_state_dict(_state_dict) self.optim_wrappers[name].load_state_dict(_state_dict)
def items(self) -> Iterator[Tuple[str, OptimWrapper]]: def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
"""A generator to get the name and corresponding """A generator to get the name and corresponding :obj:`OptimWrapper`"""
:obj:`OptimWrapper`"""
yield from self.optim_wrappers.items() yield from self.optim_wrappers.items()
def values(self) -> Iterator[OptimWrapper]: def values(self) -> Iterator[OptimWrapper]:

View File

@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
@PARAM_SCHEDULERS.register_module() @PARAM_SCHEDULERS.register_module()
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler): class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
r"""Sets the learning rate of each parameter group according to the r"""Sets the learning rate of each parameter group according to the 1cycle
1cycle learning rate policy. The 1cycle policy anneals the learning learning rate policy. The 1cycle policy anneals the learning rate from an
rate from an initial learning rate to some maximum learning rate and then initial learning rate to some maximum learning rate and then from that
from that maximum learning rate to some minimum learning rate much lower maximum learning rate to some minimum learning rate much lower than the
than the initial learning rate. initial learning rate. This policy was initially described in the paper
This policy was initially described in the paper `Super-Convergence: `Super-Convergence: Very Fast Training of Neural Networks Using Large
Very Fast Training of Neural Networks Using Large Learning Rates`_. Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training. batch. `step` should be called after a batch has been used for training.

View File

@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module() @PARAM_SCHEDULERS.register_module()
class CosineAnnealingParamScheduler(_ParamScheduler): class CosineAnnealingParamScheduler(_ParamScheduler):
r"""Set the parameter value of each parameter group using a cosine r"""Set the parameter value of each parameter group using a cosine annealing
annealing schedule, where :math:`\eta_{max}` is set to the initial value schedule, where :math:`\eta_{max}` is set to the initial value and
and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
.. math:: .. math::
\begin{aligned} \begin{aligned}
@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts: .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983 https://arxiv.org/abs/1608.03983
""" """ # noqa: E501
def __init__(self, def __init__(self,
optimizer: Union[Optimizer, BaseOptimWrapper], optimizer: Union[Optimizer, BaseOptimWrapper],
@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module() @PARAM_SCHEDULERS.register_module()
class OneCycleParamScheduler(_ParamScheduler): class OneCycleParamScheduler(_ParamScheduler):
r"""Sets the parameters of each parameter group according to the r"""Sets the parameters of each parameter group according to the 1cycle
1cycle learning rate policy. The 1cycle policy anneals the learning learning rate policy. The 1cycle policy anneals the learning rate from an
rate from an initial learning rate to some maximum learning rate and then initial learning rate to some maximum learning rate and then from that
from that maximum learning rate to some minimum learning rate much lower maximum learning rate to some minimum learning rate much lower than the
than the initial learning rate. initial learning rate. This policy was initially described in the paper
This policy was initially described in the paper `Super-Convergence: `Super-Convergence: Very Fast Training of Neural Networks Using Large
Very Fast Training of Neural Networks Using Large Learning Rates`_. Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training. batch. `step` should be called after a batch has been used for training.

View File

@ -4,7 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from mmengine.config import Config, ConfigDict from mmengine.config import Config, ConfigDict
from mmengine.utils import ManagerMixin from mmengine.utils import ManagerMixin, digit_version
from .registry import Registry from .registry import Registry
if TYPE_CHECKING: if TYPE_CHECKING:
@ -232,6 +232,21 @@ def build_model_from_cfg(
return build_from_cfg(cfg, registry, default_args) return build_from_cfg(cfg, registry, default_args)
def build_optimizer_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
import torch
from ..logging import print_log
if 'type' in cfg \
and 'Adafactor' == cfg['type'] \
and digit_version(torch.__version__) >= digit_version('2.5.0'):
print_log(
'the torch version of Adafactor is registered as TorchAdafactor')
return build_from_cfg(cfg, registry, default_args)
def build_scheduler_from_cfg( def build_scheduler_from_cfg(
cfg: Union[dict, ConfigDict, Config], cfg: Union[dict, ConfigDict, Config],
registry: Registry, registry: Registry,

View File

@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
@classmethod @classmethod
@contextmanager @contextmanager
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator: def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
"""overwrite the current default scope with `scope_name`""" """Overwrite the current default scope with `scope_name`"""
if scope_name is None: if scope_name is None:
yield yield
else: else:

View File

@ -332,7 +332,7 @@ class Registry:
return root return root
def import_from_location(self) -> None: def import_from_location(self) -> None:
"""import modules from the pre-defined locations in self._location.""" """Import modules from the pre-defined locations in self._location."""
if not self._imported: if not self._imported:
# Avoid circular import # Avoid circular import
from ..logging import print_log from ..logging import print_log

View File

@ -6,8 +6,8 @@ More datails can be found at
https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html. https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
""" """
from .build_functions import (build_model_from_cfg, build_runner_from_cfg, from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
build_scheduler_from_cfg) build_runner_from_cfg, build_scheduler_from_cfg)
from .registry import Registry from .registry import Registry
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@ -35,7 +35,7 @@ MODEL_WRAPPERS = Registry('model_wrapper')
WEIGHT_INITIALIZERS = Registry('weight initializer') WEIGHT_INITIALIZERS = Registry('weight initializer')
# mangage all kinds of optimizers like `SGD` and `Adam` # mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer') OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
# manage optimizer wrapper # manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper') OPTIM_WRAPPERS = Registry('optim_wrapper')
# manage constructors that customize the optimization hyperparameters. # manage constructors that customize the optimization hyperparameters.

View File

@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
if current_scope.scope_name != scope: # type: ignore if current_scope.scope_name != scope: # type: ignore
print_log( print_log(
'The current default scope ' # type: ignore 'The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", ' f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore
'`init_default_scope` will force set the current' '`init_default_scope` will force set the current'
f'default scope to "{scope}".', f'default scope to "{scope}".',
logger='current', logger='current',

View File

@ -540,7 +540,7 @@ class FlexibleRunner:
@property @property
def hooks(self): def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks.""" """List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks return self._hooks
@property @property
@ -1117,7 +1117,7 @@ class FlexibleRunner:
return '\n'.join(stage_hook_infos) return '\n'.join(stage_hook_infos)
def load_or_resume(self): def load_or_resume(self):
"""load or resume checkpoint.""" """Load or resume checkpoint."""
if self._has_loaded: if self._has_loaded:
return None return None
@ -1539,7 +1539,7 @@ class FlexibleRunner:
file_client_args: Optional[dict] = None, file_client_args: Optional[dict] = None,
save_optimizer: bool = True, save_optimizer: bool = True,
save_param_scheduler: bool = True, save_param_scheduler: bool = True,
meta: dict = None, meta: Optional[dict] = None,
by_epoch: bool = True, by_epoch: bool = True,
backend_args: Optional[dict] = None, backend_args: Optional[dict] = None,
): ):

View File

@ -309,7 +309,7 @@ class CheckpointLoader:
@classmethod @classmethod
def load_checkpoint(cls, filename, map_location=None, logger='current'): def load_checkpoint(cls, filename, map_location=None, logger='current'):
"""load checkpoint through URL scheme path. """Load checkpoint through URL scheme path.
Args: Args:
filename (str): checkpoint file name with given prefix filename (str): checkpoint file name with given prefix
@ -332,7 +332,7 @@ class CheckpointLoader:
@CheckpointLoader.register_scheme(prefixes='') @CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location): def load_from_local(filename, map_location):
"""load checkpoint by local file path. """Load checkpoint by local file path.
Args: Args:
filename (str): local checkpoint file path filename (str): local checkpoint file path
@ -353,7 +353,7 @@ def load_from_http(filename,
map_location=None, map_location=None,
model_dir=None, model_dir=None,
progress=os.isatty(0)): progress=os.isatty(0)):
"""load checkpoint through HTTP or HTTPS scheme path. In distributed """Load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0. setting, this function only download checkpoint at local rank 0.
Args: Args:
@ -386,7 +386,7 @@ def load_from_http(filename,
@CheckpointLoader.register_scheme(prefixes='pavi://') @CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None): def load_from_pavi(filename, map_location=None):
"""load checkpoint through the file path prefixed with pavi. In distributed """Load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary setting, this function download ckpt at all ranks to different temporary
directories. directories.
@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):
@CheckpointLoader.register_scheme( @CheckpointLoader.register_scheme(
prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
def load_from_ceph(filename, map_location=None, backend='petrel'): def load_from_ceph(filename, map_location=None, backend='petrel'):
"""load checkpoint through the file path prefixed with s3. In distributed """Load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary setting, this function download ckpt at all ranks to different temporary
directories. directories.
@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None): def load_from_torchvision(filename, map_location=None):
"""load checkpoint through the file path prefixed with modelzoo or """Load checkpoint through the file path prefixed with modelzoo or
torchvision. torchvision.
Args: Args:
@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None): def load_from_openmmlab(filename, map_location=None):
"""load checkpoint through the file path prefixed with open-mmlab or """Load checkpoint through the file path prefixed with open-mmlab or
openmmlab. openmmlab.
Args: Args:
@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes='mmcls://') @CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None): def load_from_mmcls(filename, map_location=None):
"""load checkpoint through the file path prefixed with mmcls. """Load checkpoint through the file path prefixed with mmcls.
Args: Args:
filename (str): checkpoint file path with mmcls prefix filename (str): checkpoint file path with mmcls prefix

View File

@ -8,8 +8,10 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator from mmengine.evaluator import Evaluator
from mmengine.logging import print_log from mmengine.logging import HistoryBuffer, print_log
from mmengine.registry import LOOPS from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from .amp import autocast from .amp import autocast
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals from .utils import calc_dynamic_intervals
@ -363,17 +365,26 @@ class ValLoop(BaseLoop):
logger='current', logger='current',
level=logging.WARNING) level=logging.WARNING)
self.fp16 = fp16 self.fp16 = fp16
self.val_loss: Dict[str, HistoryBuffer] = dict()
def run(self) -> dict: def run(self) -> dict:
"""Launch validation.""" """Launch validation."""
self.runner.call_hook('before_val') self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch') self.runner.call_hook('before_val_epoch')
self.runner.model.eval() self.runner.model.eval()
# clear val loss
self.val_loss.clear()
for idx, data_batch in enumerate(self.dataloader): for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch) self.run_iter(idx, data_batch)
# compute metrics # compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
if self.val_loss:
loss_dict = _parse_losses(self.val_loss, 'val')
metrics.update(loss_dict)
self.runner.call_hook('after_val_epoch', metrics=metrics) self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val') self.runner.call_hook('after_val')
return metrics return metrics
@ -391,6 +402,9 @@ class ValLoop(BaseLoop):
# outputs should be sequence of BaseDataElement # outputs should be sequence of BaseDataElement
with autocast(enabled=self.fp16): with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch) outputs = self.runner.model.val_step(data_batch)
outputs, self.val_loss = _update_losses(outputs, self.val_loss)
self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook( self.runner.call_hook(
'after_val_iter', 'after_val_iter',
@ -435,17 +449,26 @@ class TestLoop(BaseLoop):
logger='current', logger='current',
level=logging.WARNING) level=logging.WARNING)
self.fp16 = fp16 self.fp16 = fp16
self.test_loss: Dict[str, HistoryBuffer] = dict()
def run(self) -> dict: def run(self) -> dict:
"""Launch test.""" """Launch test."""
self.runner.call_hook('before_test') self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch') self.runner.call_hook('before_test_epoch')
self.runner.model.eval() self.runner.model.eval()
# clear test loss
self.test_loss.clear()
for idx, data_batch in enumerate(self.dataloader): for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch) self.run_iter(idx, data_batch)
# compute metrics # compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
if self.test_loss:
loss_dict = _parse_losses(self.test_loss, 'test')
metrics.update(loss_dict)
self.runner.call_hook('after_test_epoch', metrics=metrics) self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test') self.runner.call_hook('after_test')
return metrics return metrics
@ -462,9 +485,66 @@ class TestLoop(BaseLoop):
# predictions should be sequence of BaseDataElement # predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16): with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch) outputs = self.runner.model.test_step(data_batch)
outputs, self.test_loss = _update_losses(outputs, self.test_loss)
self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook( self.runner.call_hook(
'after_test_iter', 'after_test_iter',
batch_idx=idx, batch_idx=idx,
data_batch=data_batch, data_batch=data_batch,
outputs=outputs) outputs=outputs)
def _parse_losses(losses: Dict[str, HistoryBuffer],
stage: str) -> Dict[str, float]:
"""Parses the raw losses of the network.
Args:
losses (dict): raw losses of the network.
stage (str): The stage of loss, e.g., 'val' or 'test'.
Returns:
dict[str, float]: The key is the loss name, and the value is the
average loss.
"""
all_loss = 0
loss_dict: Dict[str, float] = dict()
for loss_name, loss_value in losses.items():
avg_loss = loss_value.mean()
loss_dict[loss_name] = avg_loss
if 'loss' in loss_name:
all_loss += avg_loss
loss_dict[f'{stage}_loss'] = all_loss
return loss_dict
def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
"""Update and record the losses of the network.
Args:
outputs (list): The outputs of the network.
losses (dict): The losses of the network.
Returns:
list: The updated outputs of the network.
dict: The updated losses of the network.
"""
if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()
for loss_name, loss_value in loss.items():
if loss_name not in losses:
losses[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
losses[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
for loss_value_i in loss_value:
losses[loss_name].update(loss_value_i.item())
return outputs, losses

View File

@ -579,7 +579,7 @@ class Runner:
@property @property
def hooks(self): def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks.""" """List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks return self._hooks
@property @property
@ -720,7 +720,7 @@ class Runner:
def build_logger(self, def build_logger(self,
log_level: Union[int, str] = 'INFO', log_level: Union[int, str] = 'INFO',
log_file: str = None, log_file: Optional[str] = None,
**kwargs) -> MMLogger: **kwargs) -> MMLogger:
"""Build a global asscessable MMLogger. """Build a global asscessable MMLogger.
@ -1677,7 +1677,7 @@ class Runner:
return '\n'.join(stage_hook_infos) return '\n'.join(stage_hook_infos)
def load_or_resume(self) -> None: def load_or_resume(self) -> None:
"""load or resume checkpoint.""" """Load or resume checkpoint."""
if self._has_loaded: if self._has_loaded:
return None return None

View File

@ -387,7 +387,7 @@ class BaseDataElement:
return dict(self.metainfo_items()) return dict(self.metainfo_items())
def __setattr__(self, name: str, value: Any): def __setattr__(self, name: str, value: Any):
"""setattr is only used to set data.""" """Setattr is only used to set data."""
if name in ('_metainfo_fields', '_data_fields'): if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name): if not hasattr(self, name):
super().__setattr__(name, value) super().__setattr__(name, value)

View File

@ -135,7 +135,7 @@ class InstanceData(BaseDataElement):
""" """
def __setattr__(self, name: str, value: Sized): def __setattr__(self, name: str, value: Sized):
"""setattr is only used to set data. """Setattr is only used to set data.
The value must have the attribute of `__len__` and have the same length The value must have the attribute of `__len__` and have the same length
of `InstanceData`. of `InstanceData`.

View File

@ -92,10 +92,25 @@ class MultiProcessTestCase(TestCase):
# Constructor patches current instance test method to # Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses, # assume the role of the main process and join its subprocesses,
# or run the underlying test function. # or run the underlying test function.
def __init__(self, method_name: str = 'runTest') -> None: def __init__(self,
method_name: str = 'runTest',
methodName: str = 'runTest') -> None:
# methodName is the correct naming in unittest
# and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != 'runTest':
method_name = methodName
super().__init__(method_name) super().__init__(method_name)
try:
fn = getattr(self, method_name) fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn)) setattr(self, method_name, self.join_or_run(fn))
except AttributeError as e:
if methodName != 'runTest':
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(
f'no such test method in {self.__class__}: {methodName}'
) from e
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()

View File

@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
check_hash=False, check_hash=False,
file_name=None): file_name=None):
r"""Loads the Torch serialized object at the given URL. r"""Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically decompressed If downloaded file is a zip file, it will be automatically decompressed
If the object is already present in `model_dir`, it's deserialized and If the object is already present in `model_dir`, it's deserialized and
returned. returned.

View File

@ -67,7 +67,7 @@ class TimeCounter:
instance.log_interval = log_interval instance.log_interval = log_interval
instance.warmup_interval = warmup_interval instance.warmup_interval = warmup_interval
instance.with_sync = with_sync instance.with_sync = with_sync # type: ignore
instance.tag = tag instance.tag = tag
instance.logger = logger instance.logger = logger
@ -127,7 +127,7 @@ class TimeCounter:
self.print_time(elapsed) self.print_time(elapsed)
def print_time(self, elapsed: Union[int, float]) -> None: def print_time(self, elapsed: Union[int, float]) -> None:
"""print times per count.""" """Print times per count."""
if self.__count >= self.warmup_interval: if self.__count >= self.warmup_interval:
self.__pure_inf_time += elapsed self.__pure_inf_time += elapsed

View File

@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):
def is_seq_of(seq: Any, def is_seq_of(seq: Any,
expected_type: Union[Type, tuple], expected_type: Union[Type, tuple],
seq_type: Type = None) -> bool: seq_type: Optional[Type] = None) -> bool:
"""Check whether it is a sequence of some type. """Check whether it is a sequence of some type.
Args: Args:

View File

@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
else: else:
raise e raise e
possible_path = osp.join(pkg.location, package) possible_path = osp.join(pkg.location, package) # type: ignore
if osp.exists(possible_path): if osp.exists(possible_path):
return possible_path return possible_path
else: else:
return osp.join(pkg.location, package2module(package)) return osp.join(pkg.location, package2module(package)) # type: ignore
def package2module(package: str): def package2module(package: str):

View File

@ -3,7 +3,7 @@ import sys
from collections.abc import Iterable from collections.abc import Iterable
from multiprocessing import Pool from multiprocessing import Pool
from shutil import get_terminal_size from shutil import get_terminal_size
from typing import Callable, Sequence from typing import Callable, Optional, Sequence
from .timer import Timer from .timer import Timer
@ -54,7 +54,7 @@ class ProgressBar:
self.timer = Timer() self.timer = Timer()
def update(self, num_tasks: int = 1): def update(self, num_tasks: int = 1):
"""update progressbar. """Update progressbar.
Args: Args:
num_tasks (int): Update step size. num_tasks (int): Update step size.
@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
def track_parallel_progress(func: Callable, def track_parallel_progress(func: Callable,
tasks: Sequence, tasks: Sequence,
nproc: int, nproc: int,
initializer: Callable = None, initializer: Optional[Callable] = None,
initargs: tuple = None, initargs: Optional[tuple] = None,
bar_width: int = 50, bar_width: int = 50,
chunksize: int = 1, chunksize: int = 1,
skip_first: bool = False, skip_first: bool = False,

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing import Pool from multiprocessing import Pool
from typing import Callable, Iterable, Sized from typing import Callable, Iterable, Optional, Sized
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
TaskProgressColumn, TextColumn, TimeRemainingColumn) TaskProgressColumn, TextColumn, TimeRemainingColumn)
@ -47,7 +47,7 @@ def _tasks_with_index(tasks):
def track_progress_rich(func: Callable, def track_progress_rich(func: Callable,
tasks: Iterable = tuple(), tasks: Iterable = tuple(),
task_num: int = None, task_num: Optional[int] = None,
nproc: int = 1, nproc: int = 1,
chunksize: int = 1, chunksize: int = 1,
description: str = 'Processing', description: str = 'Processing',

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.10.4' __version__ = '0.10.7'
def parse_version_info(version_str): def parse_version_info(version_str):

View File

@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
pass pass
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
pass pass
@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):
def _dump(self, value_dict: dict, file_path: str, def _dump(self, value_dict: dict, file_path: str,
file_format: str) -> None: file_format: str) -> None:
"""dump dict to file. """Dump dict to file.
Args: Args:
value_dict (dict) : The dict data to saved. value_dict (dict) : The dict data to saved.
@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
self._wandb.log(scalar_dict, commit=self._commit) self._wandb.log(scalar_dict, commit=self._commit)
def close(self) -> None: def close(self) -> None:
"""close an opened wandb object.""" """Close an opened wandb object."""
if hasattr(self, '_wandb'): if hasattr(self, '_wandb'):
self._wandb.join() self._wandb.join()
@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
self.add_scalar(key, value, step) self.add_scalar(key, value, step)
def close(self): def close(self):
"""close an opened tensorboard object.""" """Close an opened tensorboard object."""
if hasattr(self, '_tensorboard'): if hasattr(self, '_tensorboard'):
self._tensorboard.close() self._tensorboard.close()
@ -1135,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
self._neptune[k].append(v, step=step) self._neptune[k].append(v, step=step)
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
if hasattr(self, '_neptune'): if hasattr(self, '_neptune'):
self._neptune.stop() self._neptune.stop()
@ -1282,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
self.add_scalar(key, value, step, **kwargs) self.add_scalar(key, value, step, **kwargs)
def close(self) -> None: def close(self) -> None:
"""close an opened dvclive object.""" """Close an opened dvclive object."""
if not hasattr(self, '_dvclive'): if not hasattr(self, '_dvclive'):
return return

View File

@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):
@master_only @master_only
def get_backend(self, name) -> 'BaseVisBackend': def get_backend(self, name) -> 'BaseVisBackend':
"""get vis backend by name. """Get vis backend by name.
Args: Args:
name (str): The name of vis backend name (str): The name of vis backend
@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
pass pass
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
for vis_backend in self._vis_backends.values(): for vis_backend in self._vis_backends.values():
vis_backend.close() vis_backend.close()

View File

@ -843,8 +843,8 @@ class TestConfig:
assert cfg_dict['item4'] == 'test' assert cfg_dict['item4'] == 'test'
assert '_delete_' not in cfg_dict['item1'] assert '_delete_' not in cfg_dict['item1']
assert type(cfg_dict['item1']) == ConfigDict assert type(cfg_dict['item1']) is ConfigDict
assert type(cfg_dict['item2']) == ConfigDict assert type(cfg_dict['item2']) is ConfigDict
def _merge_intermediate_variable(self): def _merge_intermediate_variable(self):

View File

@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
class TestAveragedModel(TestCase): class TestAveragedModel(TestCase):
"""Test the AveragedModel class. """Test the AveragedModel class.
Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py Some test cases are referenced from
https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
""" # noqa: E501 """ # noqa: E501
def _test_swa_model(self, net_device, avg_device): def _test_swa_model(self, net_device, avg_device):

View File

@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
rtol=0) rtol=0)
def test_scheduler_before_optim_warning(self): def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer.""" """Warns if scheduler is used before optimizer."""
def call_sch_before_optim(): def call_sch_before_optim():
scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3) scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)

View File

@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
rtol=0) rtol=0)
def test_scheduler_before_optim_warning(self): def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer.""" """Warns if scheduler is used before optimizer."""
def call_sch_before_optim(): def call_sch_before_optim():
scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3) scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)

View File

@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
rtol=0) rtol=0)
def test_scheduler_before_optim_warning(self): def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer.""" """Warns if scheduler is used before optimizer."""
def call_sch_before_optim(): def call_sch_before_optim():
scheduler = StepParamScheduler( scheduler = StepParamScheduler(

View File

@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata,
*args, **kwargs): *args, **kwargs):
"""load checkpoints.""" """Load checkpoints."""
# Names of some parameters in has been changed. # Names of some parameters in has been changed.
version = local_metadata.get('version', None) version = local_metadata.get('version', None)

View File

@ -2226,7 +2226,7 @@ class TestRunner(TestCase):
@HOOKS.register_module(force=True) @HOOKS.register_module(force=True)
class TestWarmupHook(Hook): class TestWarmupHook(Hook):
"""test custom train loop.""" """Test custom train loop."""
def before_warmup_iter(self, runner, data_batch=None): def before_warmup_iter(self, runner, data_batch=None):
before_warmup_iter_results.append('before') before_warmup_iter_results.append('before')

View File

@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
return metainfo, data return metainfo, data
def is_equal(self, x, y): def is_equal(self, x, y):
assert type(x) == type(y) assert type(x) is type(y)
if isinstance( if isinstance(
x, (int, float, str, list, tuple, dict, set, BaseDataElement)): x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
return x == y return x == y
@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):
# test new() with no arguments # test new() with no arguments
new_instances = instances.new() new_instances = instances.new()
assert type(new_instances) == type(instances) assert type(new_instances) is type(instances)
# After deepcopy, the address of new data'element will be same as # After deepcopy, the address of new data'element will be same as
# origin, but when change new data' element will not effect the origin # origin, but when change new data' element will not effect the origin
# element and will have new address # element and will have new address
@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
# test new() with arguments # test new() with arguments
metainfo, data = self.setup_data() metainfo, data = self.setup_data()
new_instances = instances.new(metainfo=metainfo, **data) new_instances = instances.new(metainfo=metainfo, **data)
assert type(new_instances) == type(instances) assert type(new_instances) is type(instances)
assert id(new_instances.gt_instances) != id(instances.gt_instances) assert id(new_instances.gt_instances) != id(instances.gt_instances)
_, new_data = self.setup_data() _, new_data = self.setup_data()
new_instances.set_data(new_data) new_instances.set_data(new_data)
@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
metainfo, data = self.setup_data() metainfo, data = self.setup_data()
instances = BaseDataElement(metainfo=metainfo, **data) instances = BaseDataElement(metainfo=metainfo, **data)
new_instances = instances.clone() new_instances = instances.clone()
assert type(new_instances) == type(instances) assert type(new_instances) is type(instances)
def test_set_metainfo(self): def test_set_metainfo(self):
metainfo, _ = self.setup_data() metainfo, _ = self.setup_data()

View File

@ -45,7 +45,7 @@ class MockVisBackend:
self._add_scalars = True self._add_scalars = True
def close(self) -> None: def close(self) -> None:
"""close an opened object.""" """Close an opened object."""
self._close = True self._close = True