Compare commits
17 Commits
Author | SHA1 | Date |
---|---|---|
|
390ba2fbb2 | |
|
d620552c2c | |
|
41fa84a9a9 | |
|
698782f920 | |
|
e60ab1dde3 | |
|
8ec837814e | |
|
a4475f5eea | |
|
a8c74c346d | |
|
9124ebf7a2 | |
|
2e0ab7a922 | |
|
fc59364d64 | |
|
4183cf0829 | |
|
cc3b74b5e8 | |
|
c9b59962d6 | |
|
5e736b143b | |
|
85c83ba616 | |
|
d1f1aabf81 |
.github/workflows
docs/en/notes
mmengine
_strategy
device
dist
evaluator
hooks
logging
model
structures
testing/_internal
visualization
tests
test_config
test_fileio/test_backends
test_model
test_optim/test_scheduler
test_runner
test_structures
test_visualizer
|
@ -1,6 +1,8 @@
|
|||
name: deploy
|
||||
|
||||
on: push
|
||||
on:
|
||||
- push
|
||||
- workflow_dispatch
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
@ -9,13 +11,14 @@ concurrency:
|
|||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
if: |
|
||||
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python 3.10.13
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.7
|
||||
python-version: 3.10.13
|
||||
- name: Install wheel
|
||||
run: pip install wheel
|
||||
- name: Build MMEngine
|
||||
|
@ -27,13 +30,14 @@ jobs:
|
|||
|
||||
build-n-publish-lite:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
if: |
|
||||
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python 3.10.13
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.7
|
||||
python-version: 3.10.13
|
||||
- name: Install wheel
|
||||
run: pip install wheel
|
||||
- name: Build MMEngine-lite
|
||||
|
|
|
@ -11,10 +11,10 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
- name: Set up Python 3.10.15
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
python-version: '3.10.15'
|
||||
- name: Install pre-commit hook
|
||||
run: |
|
||||
pip install pre-commit
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
name: pr_stage_test
|
||||
|
||||
env:
|
||||
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
|
||||
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
|
@ -21,152 +25,114 @@ concurrency:
|
|||
jobs:
|
||||
build_cpu:
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -l {0}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
include:
|
||||
- torch: 1.8.1
|
||||
torchvision: 0.9.1
|
||||
python-version: ['3.9']
|
||||
torch: ['2.0.0']
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
- name: Check out repo
|
||||
uses: actions/checkout@v3
|
||||
- name: Setup conda env
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
with:
|
||||
auto-update-conda: true
|
||||
miniconda-version: "latest"
|
||||
use-only-tar-bz2: true
|
||||
activate-environment: test
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: python -m pip install pip --upgrade
|
||||
- name: Upgrade wheel
|
||||
run: python -m pip install wheel --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
- name: Update pip
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
- name: Run unittests and generate coverage report
|
||||
python -m pip install --upgrade pip wheel
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
coverage xml
|
||||
coverage report -m
|
||||
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
|
||||
- name: Upload coverage to Codecov
|
||||
apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
|
||||
python -m pip install torch==${{matrix.torch}}
|
||||
python -m pip install -e . -v
|
||||
python -m pip install -r requirements/tests.txt
|
||||
python -m pip install openmim
|
||||
mim install mmcv coverage
|
||||
- name: Run unit tests with coverage
|
||||
run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
- name: Upload Coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
env_vars: OS,PYTHON
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
build_cu102:
|
||||
build_gpu:
|
||||
runs-on: ubuntu-22.04
|
||||
container:
|
||||
image: pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -l {0}
|
||||
env:
|
||||
MKL_THREADING_LAYER: GNU
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
python-version: ['3.9','3.10']
|
||||
torch: ['2.0.0','2.3.1','2.5.1']
|
||||
cuda: ['cu118']
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
- name: Check out repo
|
||||
uses: actions/checkout@v3
|
||||
- name: Setup conda env
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
with:
|
||||
auto-update-conda: true
|
||||
miniconda-version: "latest"
|
||||
use-only-tar-bz2: true
|
||||
activate-environment: test
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Fetch GPG keys
|
||||
- name: Update pip
|
||||
run: |
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
- name: Install system dependencies
|
||||
run: apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
python -m pip install --upgrade pip wheel
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
coverage xml
|
||||
coverage report -m
|
||||
apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
|
||||
python -m pip install torch==${{matrix.torch}} --index-url https://download.pytorch.org/whl/${{matrix.cuda}}
|
||||
python -m pip install -e . -v
|
||||
python -m pip install -r requirements/tests.txt
|
||||
python -m pip install openmim
|
||||
mim install mmcv coverage
|
||||
- name: Run unit tests with coverage
|
||||
run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
- name: Upload Coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
|
||||
build_cu117:
|
||||
runs-on: ubuntu-22.04
|
||||
container:
|
||||
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.9]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Fetch GPG keys
|
||||
run: |
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
- name: Install system dependencies
|
||||
run: apt-get update && apt-get install -y git ffmpeg libturbojpeg
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
# Distributed related unit test may randomly error in PyTorch 1.13.0
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
|
||||
build_windows:
|
||||
runs-on: windows-2022
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
platform: [cpu, cu111]
|
||||
torch: [1.8.1]
|
||||
torchvision: [0.9.1]
|
||||
include:
|
||||
- python-version: 3.8
|
||||
platform: cu118
|
||||
torch: 2.1.0
|
||||
torchvision: 0.16.0
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
# Windows CI could fail If we call `pip install pip --upgrade` directly.
|
||||
run: python -m pip install pip --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
- name: Run CPU unittests
|
||||
run: pytest tests/ --ignore tests/test_dist
|
||||
if: ${{ matrix.platform == 'cpu' }}
|
||||
- name: Run GPU unittests
|
||||
# Skip testing distributed related unit tests since the memory of windows CI is limited
|
||||
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
|
||||
if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
|
||||
# build_windows:
|
||||
# runs-on: windows-2022
|
||||
# strategy:
|
||||
# matrix:
|
||||
# python-version: [3.9]
|
||||
# platform: [cpu, cu111]
|
||||
# torch: [1.8.1]
|
||||
# torchvision: [0.9.1]
|
||||
# include:
|
||||
# - python-version: 3.8
|
||||
# platform: cu118
|
||||
# torch: 2.1.0
|
||||
# torchvision: 0.16.0
|
||||
# steps:
|
||||
# - uses: actions/checkout@v3
|
||||
# - name: Set up Python ${{ matrix.python-version }}
|
||||
# uses: actions/setup-python@v4
|
||||
# with:
|
||||
# python-version: ${{ matrix.python-version }}
|
||||
# - name: Upgrade pip
|
||||
# # Windows CI could fail If we call `pip install pip --upgrade` directly.
|
||||
# run: python -m pip install pip wheel --upgrade
|
||||
# - name: Install PyTorch
|
||||
# run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
|
||||
# - name: Build MMEngine from source
|
||||
# run: pip install -e . -v
|
||||
# - name: Install unit tests dependencies
|
||||
# run: |
|
||||
# pip install -r requirements/tests.txt
|
||||
# pip install openmim
|
||||
# mim install mmcv
|
||||
# - name: Run CPU unittests
|
||||
# run: pytest tests/ --ignore tests/test_dist
|
||||
# if: ${{ matrix.platform == 'cpu' }}
|
||||
# - name: Run GPU unittests
|
||||
# # Skip testing distributed related unit tests since the memory of windows CI is limited
|
||||
# run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
|
||||
# if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://gitee.com/openmmlab/mirrors-flake8
|
||||
rev: 5.0.4
|
||||
- repo: https://github.com/pre-commit/pre-commit
|
||||
rev: v4.0.0
|
||||
hooks:
|
||||
- id: validate_manifest
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.1.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://gitee.com/openmmlab/mirrors-isort
|
||||
|
@ -13,7 +17,7 @@ repos:
|
|||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
|
@ -55,7 +59,7 @@ repos:
|
|||
args: ["mmengine", "tests"]
|
||||
- id: remove-improper-eol-in-cn-docs
|
||||
- repo: https://gitee.com/openmmlab/mirrors-mypy
|
||||
rev: v0.812
|
||||
rev: v1.2.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
exclude: |-
|
||||
|
@ -63,3 +67,4 @@ repos:
|
|||
^examples
|
||||
| ^docs
|
||||
)
|
||||
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit
|
||||
rev: v4.0.0
|
||||
hooks:
|
||||
- id: validate_manifest
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 5.0.4
|
||||
rev: 7.1.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
|
@ -13,7 +17,7 @@ repos:
|
|||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
|
@ -34,12 +38,8 @@ repos:
|
|||
- mdformat-openmmlab
|
||||
- mdformat_frontmatter
|
||||
- linkify-it-py
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.2.1
|
||||
hooks:
|
||||
- id: codespell
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
rev: 06907d0
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
|
@ -55,7 +55,7 @@ repos:
|
|||
args: ["mmengine", "tests"]
|
||||
- id: remove-improper-eol-in-cn-docs
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.812
|
||||
rev: v1.2.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
exclude: |-
|
||||
|
@ -63,3 +63,4 @@ repos:
|
|||
^examples
|
||||
| ^docs
|
||||
)
|
||||
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
|
||||
|
|
|
@ -59,7 +59,7 @@ English | [简体中文](README_zh-CN.md)
|
|||
|
||||
## What's New
|
||||
|
||||
v0.10.4 was released on 2024-4-23.
|
||||
v0.10.6 was released on 2025-01-13.
|
||||
|
||||
Highlights:
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@
|
|||
|
||||
## 最近进展
|
||||
|
||||
最新版本 v0.10.4 在 2024.4.23 发布。
|
||||
最新版本 v0.10.5 在 2024.9.11 发布。
|
||||
|
||||
版本亮点:
|
||||
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
# Changelog of v0.x
|
||||
|
||||
## v0.10.5 (11/9/2024)
|
||||
|
||||
- Fix `_is_builtin_module`. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1571
|
||||
|
||||
## v0.10.4 (23/4/2024)
|
||||
|
||||
### New Features & Enhancements
|
||||
|
|
|
@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
|
|||
'"type" and "constructor" are not in '
|
||||
f'optimizer, but got {name}={optim}')
|
||||
optim_wrappers[name] = optim
|
||||
return OptimWrapperDict(**optim_wrappers)
|
||||
return OptimWrapperDict(**optim_wrappers) # type: ignore
|
||||
else:
|
||||
raise TypeError('optimizer wrapper should be an OptimWrapper '
|
||||
f'object or dict, but got {optim_wrapper}')
|
||||
|
|
|
@ -361,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
|
|||
map_location: Union[str, Callable] = 'default',
|
||||
callback: Optional[Callable] = None,
|
||||
) -> dict:
|
||||
"""override this method since colossalai resume optimizer from filename
|
||||
"""Override this method since colossalai resume optimizer from filename
|
||||
directly."""
|
||||
self.logger.info(f'Resume checkpoint from {filename}')
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
|
|||
init_dist(launcher, backend, **kwargs)
|
||||
|
||||
def convert_model(self, model: nn.Module) -> nn.Module:
|
||||
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
|
||||
"""Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
|
||||
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -393,7 +393,7 @@ class Config:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_dict: dict = None,
|
||||
cfg_dict: Optional[dict] = None,
|
||||
cfg_text: Optional[str] = None,
|
||||
filename: Optional[Union[str, Path]] = None,
|
||||
env_variables: Optional[dict] = None,
|
||||
|
@ -1227,7 +1227,8 @@ class Config:
|
|||
if base_code is not None:
|
||||
base_code = ast.Expression( # type: ignore
|
||||
body=base_code.value) # type: ignore
|
||||
base_files = eval(compile(base_code, '', mode='eval'))
|
||||
base_files = eval(compile(base_code, '',
|
||||
mode='eval')) # type: ignore
|
||||
else:
|
||||
base_files = []
|
||||
elif file_format in ('.yml', '.yaml', '.json'):
|
||||
|
@ -1288,7 +1289,7 @@ class Config:
|
|||
def _merge_a_into_b(a: dict,
|
||||
b: dict,
|
||||
allow_list_keys: bool = False) -> dict:
|
||||
"""merge dict ``a`` into dict ``b`` (non-inplace).
|
||||
"""Merge dict ``a`` into dict ``b`` (non-inplace).
|
||||
|
||||
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
|
||||
in-place modifications.
|
||||
|
@ -1358,22 +1359,22 @@ class Config:
|
|||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
"""get file name of config."""
|
||||
"""Get file name of config."""
|
||||
return self._filename
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""get config text."""
|
||||
"""Get config text."""
|
||||
return self._text
|
||||
|
||||
@property
|
||||
def env_variables(self) -> dict:
|
||||
"""get used environment variables."""
|
||||
"""Get used environment variables."""
|
||||
return self._env_variables
|
||||
|
||||
@property
|
||||
def pretty_text(self) -> str:
|
||||
"""get formatted python config text."""
|
||||
"""Get formatted python config text."""
|
||||
|
||||
indent = 4
|
||||
|
||||
|
@ -1727,17 +1728,17 @@ class Config:
|
|||
|
||||
|
||||
class DictAction(Action):
|
||||
"""
|
||||
argparse action to split an argument into KEY=VALUE form
|
||||
on the first = and append to a dictionary. List options can
|
||||
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
|
||||
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
|
||||
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
|
||||
"""Argparse action to split an argument into KEY=VALUE form on the first =
|
||||
and append to a dictionary.
|
||||
|
||||
List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
|
||||
or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
|
||||
brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
|
||||
"""parse int/float/bool value in the string."""
|
||||
"""Parse int/float/bool value in the string."""
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
|
@ -1822,7 +1823,7 @@ class DictAction(Action):
|
|||
parser: ArgumentParser,
|
||||
namespace: Namespace,
|
||||
values: Union[str, Sequence[Any], None],
|
||||
option_string: str = None):
|
||||
option_string: str = None): # type: ignore
|
||||
"""Parse Variables in string and add them into argparser.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ from mmengine.fileio import load
|
|||
from mmengine.utils import check_file_exist
|
||||
|
||||
PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable))
|
||||
SYSTEM_PYTHON_PREFIX = '/usr/lib/python'
|
||||
|
||||
MODULE2PACKAGE = {
|
||||
'mmcls': 'mmcls',
|
||||
|
@ -176,7 +177,8 @@ def _is_builtin_module(module_name: str) -> bool:
|
|||
return False
|
||||
origin_path = osp.abspath(origin_path)
|
||||
if ('site-package' in origin_path or 'dist-package' in origin_path
|
||||
or not origin_path.startswith(PYTHON_ROOT_DIR)):
|
||||
or not origin_path.startswith(
|
||||
(PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
|
|
@ -16,6 +16,12 @@ try:
|
|||
except Exception:
|
||||
IS_NPU_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import torch_mlu # noqa: F401
|
||||
IS_MLU_AVAILABLE = hasattr(torch, 'mlu') and torch.mlu.is_available()
|
||||
except Exception:
|
||||
IS_MLU_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import torch_dipu # noqa: F401
|
||||
IS_DIPU_AVAILABLE = True
|
||||
|
@ -64,7 +70,7 @@ def is_npu_available() -> bool:
|
|||
|
||||
def is_mlu_available() -> bool:
|
||||
"""Returns True if Cambricon PyTorch and mlu devices exist."""
|
||||
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
|
||||
return IS_MLU_AVAILABLE
|
||||
|
||||
|
||||
def is_mps_available() -> bool:
|
||||
|
|
|
@ -563,7 +563,7 @@ def cast_data_device(
|
|||
Tensor or list or dict: ``data`` was casted to ``device``.
|
||||
"""
|
||||
if out is not None:
|
||||
if type(data) != type(out):
|
||||
if type(data) is not type(out):
|
||||
raise TypeError(
|
||||
'out should be the same type with data, but got data is '
|
||||
f'{type(data)} and out is {type(data)}')
|
||||
|
|
|
@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
|
|||
self.out_file_path = out_file_path
|
||||
|
||||
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
|
||||
"""transfer tensors in predictions to CPU."""
|
||||
"""Transfer tensors in predictions to CPU."""
|
||||
self.results.extend(_to_cpu(predictions))
|
||||
|
||||
def compute_metrics(self, results: list) -> dict:
|
||||
"""dump the prediction results to a pickle file."""
|
||||
"""Dump the prediction results to a pickle file."""
|
||||
dump(results, self.out_file_path)
|
||||
print_log(
|
||||
f'Results has been saved to {self.out_file_path}.',
|
||||
|
@ -188,7 +188,7 @@ class DumpResults(BaseMetric):
|
|||
|
||||
|
||||
def _to_cpu(data: Any) -> Any:
|
||||
"""transfer all tensors and BaseDataElement to cpu."""
|
||||
"""Transfer all tensors and BaseDataElement to cpu."""
|
||||
if isinstance(data, (Tensor, BaseDataElement)):
|
||||
return data.to('cpu')
|
||||
elif isinstance(data, list):
|
||||
|
|
|
@ -233,7 +233,7 @@ class ProfilerHook(Hook):
|
|||
self._export_chrome_trace(runner)
|
||||
|
||||
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
|
||||
"""profiler will call `step` method if it is not closed."""
|
||||
"""Profiler will call `step` method if it is not closed."""
|
||||
if not self._closed:
|
||||
self.profiler.step()
|
||||
if runner.iter == self.profile_times - 1 and not self.by_epoch:
|
||||
|
|
|
@ -58,7 +58,7 @@ class HistoryBuffer:
|
|||
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
|
||||
|
||||
def update(self, log_val: Union[int, float], count: int = 1) -> None:
|
||||
"""update the log history.
|
||||
"""Update the log history.
|
||||
|
||||
If the length of the buffer exceeds ``self._max_length``, the oldest
|
||||
element will be removed from the buffer.
|
||||
|
|
|
@ -443,8 +443,7 @@ def _get_host_info() -> str:
|
|||
host = f'{getuser()}@{gethostname()}'
|
||||
except Exception as e:
|
||||
warnings.warn(f'Host or user not found: {str(e)}')
|
||||
finally:
|
||||
return host
|
||||
return host
|
||||
|
||||
|
||||
def _get_logging_file_handlers() -> Dict:
|
||||
|
|
|
@ -253,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||
dict or list: Data in the same format as the model input.
|
||||
"""
|
||||
data = self.cast_data(data) # type: ignore
|
||||
_batch_inputs = data['inputs']
|
||||
_batch_inputs = data['inputs'] # type: ignore
|
||||
# Process data with `pseudo_collate`.
|
||||
if is_seq_of(_batch_inputs, torch.Tensor):
|
||||
batch_inputs = []
|
||||
for _batch_input in _batch_inputs:
|
||||
# channel transform
|
||||
if self._channel_conversion:
|
||||
_batch_input = _batch_input[[2, 1, 0], ...]
|
||||
_batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
|
||||
# Convert to float after channel conversion to ensure
|
||||
# efficiency
|
||||
_batch_input = _batch_input.float()
|
||||
_batch_input = _batch_input.float() # type: ignore
|
||||
# Normalization.
|
||||
if self._enable_normalize:
|
||||
if self.mean.shape[0] == 3:
|
||||
|
@ -302,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||
else:
|
||||
raise TypeError('Output of `cast_data` should be a dict of '
|
||||
'list/tuple with inputs and data_samples, '
|
||||
f'but got {type(data)}: {data}')
|
||||
data['inputs'] = batch_inputs
|
||||
data.setdefault('data_samples', None)
|
||||
return data
|
||||
f'but got {type(data)}: {data}') # type: ignore
|
||||
data['inputs'] = batch_inputs # type: ignore
|
||||
data.setdefault('data_samples', None) # type: ignore
|
||||
return data # type: ignore
|
||||
|
|
|
@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):
|
|||
|
||||
|
||||
def bias_init_with_prob(prior_prob):
|
||||
"""initialize conv/fc bias value according to a given probability value."""
|
||||
"""Initialize conv/fc bias value according to a given probability value."""
|
||||
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
|
||||
return bias_init
|
||||
|
||||
|
@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
|
|||
std: float = 1.,
|
||||
a: float = -2.,
|
||||
b: float = 2.) -> Tensor:
|
||||
r"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \leq \text{mean} \leq b`.
|
||||
r"""Fills the input Tensor with values drawn from a truncated normal
|
||||
distribution. The values are effectively drawn from the normal distribution
|
||||
:math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
|
||||
:math:`[a, b]` redrawn until they are within the bounds. The method used
|
||||
for generating the random values works best when :math:`a \leq \text{mean}
|
||||
\leq b`.
|
||||
|
||||
Modified from
|
||||
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
|
||||
|
|
|
@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
auto_wrap_policy: Union[str, Callable, None] = None,
|
||||
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
|
||||
mixed_precision: Union[dict, MixedPrecision, None] = None,
|
||||
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
|
||||
param_init_fn: Union[str, Callable[
|
||||
[nn.Module], None]] = None, # type: ignore # noqa: E501
|
||||
use_orig_params: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
optim: torch.optim.Optimizer,
|
||||
group: Optional[dist.ProcessGroup] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
|
||||
model)
|
||||
return FullyShardedDataParallel._optim_state_dict_impl(
|
||||
|
@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
state_dict_config: Optional[StateDictConfig] = None,
|
||||
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
|
||||
) -> StateDictSettings:
|
||||
"""copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
import torch.distributed.fsdp._traversal_utils as traversal_utils
|
||||
_state_dict_type_to_config = {
|
||||
StateDictType.FULL_STATE_DICT: FullStateDictConfig,
|
||||
|
|
|
@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
|
|||
self._inner_count += 1
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Get the state dictionary of :attr:`optimizer` and
|
||||
:attr:`apex_amp`.
|
||||
"""Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.
|
||||
|
||||
Based on the state dictionary of the optimizer, the returned state
|
||||
dictionary will add a key named "apex_amp".
|
||||
|
|
|
@ -25,7 +25,11 @@ def register_torch_optimizers() -> List[str]:
|
|||
_optim = getattr(torch.optim, module_name)
|
||||
if inspect.isclass(_optim) and issubclass(_optim,
|
||||
torch.optim.Optimizer):
|
||||
OPTIMIZERS.register_module(module=_optim)
|
||||
if module_name == 'Adafactor':
|
||||
OPTIMIZERS.register_module(
|
||||
name='TorchAdafactor', module=_optim)
|
||||
else:
|
||||
OPTIMIZERS.register_module(module=_optim)
|
||||
torch_optimizers.append(module_name)
|
||||
return torch_optimizers
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
|
|||
self._validate_cfg()
|
||||
|
||||
def _validate_cfg(self) -> None:
|
||||
"""verify the correctness of the config."""
|
||||
"""Verify the correctness of the config."""
|
||||
if not isinstance(self.paramwise_cfg, dict):
|
||||
raise TypeError('paramwise_cfg should be None or a dict, '
|
||||
f'but got {type(self.paramwise_cfg)}')
|
||||
|
@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
|
|||
raise ValueError('base_wd should not be None')
|
||||
|
||||
def _is_in(self, param_group: dict, param_group_list: list) -> bool:
|
||||
"""check whether the `param_group` is in the`param_group_list`"""
|
||||
"""Check whether the `param_group` is in the`param_group_list`"""
|
||||
assert is_list_of(param_group_list, dict)
|
||||
param = set(param_group['params'])
|
||||
param_set = set()
|
||||
|
|
|
@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
|
|||
self.optim_wrappers[name].load_state_dict(_state_dict)
|
||||
|
||||
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
|
||||
"""A generator to get the name and corresponding
|
||||
:obj:`OptimWrapper`"""
|
||||
"""A generator to get the name and corresponding :obj:`OptimWrapper`"""
|
||||
yield from self.optim_wrappers.items()
|
||||
|
||||
def values(self) -> Iterator[OptimWrapper]:
|
||||
|
|
|
@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
|
|||
|
||||
@PARAM_SCHEDULERS.register_module()
|
||||
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
|
||||
r"""Sets the learning rate of each parameter group according to the
|
||||
1cycle learning rate policy. The 1cycle policy anneals the learning
|
||||
rate from an initial learning rate to some maximum learning rate and then
|
||||
from that maximum learning rate to some minimum learning rate much lower
|
||||
than the initial learning rate.
|
||||
This policy was initially described in the paper `Super-Convergence:
|
||||
Very Fast Training of Neural Networks Using Large Learning Rates`_.
|
||||
r"""Sets the learning rate of each parameter group according to the 1cycle
|
||||
learning rate policy. The 1cycle policy anneals the learning rate from an
|
||||
initial learning rate to some maximum learning rate and then from that
|
||||
maximum learning rate to some minimum learning rate much lower than the
|
||||
initial learning rate. This policy was initially described in the paper
|
||||
`Super-Convergence: Very Fast Training of Neural Networks Using Large
|
||||
Learning Rates`_.
|
||||
|
||||
The 1cycle learning rate policy changes the learning rate after every
|
||||
batch. `step` should be called after a batch has been used for training.
|
||||
|
|
|
@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):
|
|||
|
||||
@PARAM_SCHEDULERS.register_module()
|
||||
class CosineAnnealingParamScheduler(_ParamScheduler):
|
||||
r"""Set the parameter value of each parameter group using a cosine
|
||||
annealing schedule, where :math:`\eta_{max}` is set to the initial value
|
||||
and :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
|
||||
r"""Set the parameter value of each parameter group using a cosine annealing
|
||||
schedule, where :math:`\eta_{max}` is set to the initial value and
|
||||
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
|
@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
|
|||
|
||||
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
|
||||
https://arxiv.org/abs/1608.03983
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Union[Optimizer, BaseOptimWrapper],
|
||||
|
@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):
|
|||
|
||||
@PARAM_SCHEDULERS.register_module()
|
||||
class OneCycleParamScheduler(_ParamScheduler):
|
||||
r"""Sets the parameters of each parameter group according to the
|
||||
1cycle learning rate policy. The 1cycle policy anneals the learning
|
||||
rate from an initial learning rate to some maximum learning rate and then
|
||||
from that maximum learning rate to some minimum learning rate much lower
|
||||
than the initial learning rate.
|
||||
This policy was initially described in the paper `Super-Convergence:
|
||||
Very Fast Training of Neural Networks Using Large Learning Rates`_.
|
||||
r"""Sets the parameters of each parameter group according to the 1cycle
|
||||
learning rate policy. The 1cycle policy anneals the learning rate from an
|
||||
initial learning rate to some maximum learning rate and then from that
|
||||
maximum learning rate to some minimum learning rate much lower than the
|
||||
initial learning rate. This policy was initially described in the paper
|
||||
`Super-Convergence: Very Fast Training of Neural Networks Using Large
|
||||
Learning Rates`_.
|
||||
|
||||
The 1cycle learning rate policy changes the learning rate after every
|
||||
batch. `step` should be called after a batch has been used for training.
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.utils import ManagerMixin
|
||||
from mmengine.utils import ManagerMixin, digit_version
|
||||
from .registry import Registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -232,6 +232,21 @@ def build_model_from_cfg(
|
|||
return build_from_cfg(cfg, registry, default_args)
|
||||
|
||||
|
||||
def build_optimizer_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
|
||||
import torch
|
||||
|
||||
from ..logging import print_log
|
||||
if 'type' in cfg \
|
||||
and 'Adafactor' == cfg['type'] \
|
||||
and digit_version(torch.__version__) >= digit_version('2.5.0'):
|
||||
print_log(
|
||||
'the torch version of Adafactor is registered as TorchAdafactor')
|
||||
return build_from_cfg(cfg, registry, default_args)
|
||||
|
||||
|
||||
def build_scheduler_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
|
|
|
@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
|
|||
@classmethod
|
||||
@contextmanager
|
||||
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
|
||||
"""overwrite the current default scope with `scope_name`"""
|
||||
"""Overwrite the current default scope with `scope_name`"""
|
||||
if scope_name is None:
|
||||
yield
|
||||
else:
|
||||
|
|
|
@ -332,7 +332,7 @@ class Registry:
|
|||
return root
|
||||
|
||||
def import_from_location(self) -> None:
|
||||
"""import modules from the pre-defined locations in self._location."""
|
||||
"""Import modules from the pre-defined locations in self._location."""
|
||||
if not self._imported:
|
||||
# Avoid circular import
|
||||
from ..logging import print_log
|
||||
|
|
|
@ -6,8 +6,8 @@ More datails can be found at
|
|||
https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
|
||||
"""
|
||||
|
||||
from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
|
||||
build_scheduler_from_cfg)
|
||||
from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
|
||||
build_runner_from_cfg, build_scheduler_from_cfg)
|
||||
from .registry import Registry
|
||||
|
||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
|
@ -35,7 +35,7 @@ MODEL_WRAPPERS = Registry('model_wrapper')
|
|||
WEIGHT_INITIALIZERS = Registry('weight initializer')
|
||||
|
||||
# mangage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer')
|
||||
OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
|
||||
# manage optimizer wrapper
|
||||
OPTIM_WRAPPERS = Registry('optim_wrapper')
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
|
|
|
@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
|
|||
if current_scope.scope_name != scope: # type: ignore
|
||||
print_log(
|
||||
'The current default scope ' # type: ignore
|
||||
f'"{current_scope.scope_name}" is not "{scope}", '
|
||||
f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore
|
||||
'`init_default_scope` will force set the current'
|
||||
f'default scope to "{scope}".',
|
||||
logger='current',
|
||||
|
|
|
@ -540,7 +540,7 @@ class FlexibleRunner:
|
|||
|
||||
@property
|
||||
def hooks(self):
|
||||
"""list[:obj:`Hook`]: A list of registered hooks."""
|
||||
"""List[:obj:`Hook`]: A list of registered hooks."""
|
||||
return self._hooks
|
||||
|
||||
@property
|
||||
|
@ -1117,7 +1117,7 @@ class FlexibleRunner:
|
|||
return '\n'.join(stage_hook_infos)
|
||||
|
||||
def load_or_resume(self):
|
||||
"""load or resume checkpoint."""
|
||||
"""Load or resume checkpoint."""
|
||||
if self._has_loaded:
|
||||
return None
|
||||
|
||||
|
@ -1539,7 +1539,7 @@ class FlexibleRunner:
|
|||
file_client_args: Optional[dict] = None,
|
||||
save_optimizer: bool = True,
|
||||
save_param_scheduler: bool = True,
|
||||
meta: dict = None,
|
||||
meta: Optional[dict] = None,
|
||||
by_epoch: bool = True,
|
||||
backend_args: Optional[dict] = None,
|
||||
):
|
||||
|
|
|
@ -309,7 +309,7 @@ class CheckpointLoader:
|
|||
|
||||
@classmethod
|
||||
def load_checkpoint(cls, filename, map_location=None, logger='current'):
|
||||
"""load checkpoint through URL scheme path.
|
||||
"""Load checkpoint through URL scheme path.
|
||||
|
||||
Args:
|
||||
filename (str): checkpoint file name with given prefix
|
||||
|
@ -332,7 +332,7 @@ class CheckpointLoader:
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes='')
|
||||
def load_from_local(filename, map_location):
|
||||
"""load checkpoint by local file path.
|
||||
"""Load checkpoint by local file path.
|
||||
|
||||
Args:
|
||||
filename (str): local checkpoint file path
|
||||
|
@ -353,7 +353,7 @@ def load_from_http(filename,
|
|||
map_location=None,
|
||||
model_dir=None,
|
||||
progress=os.isatty(0)):
|
||||
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
|
||||
"""Load checkpoint through HTTP or HTTPS scheme path. In distributed
|
||||
setting, this function only download checkpoint at local rank 0.
|
||||
|
||||
Args:
|
||||
|
@ -386,7 +386,7 @@ def load_from_http(filename,
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes='pavi://')
|
||||
def load_from_pavi(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with pavi. In distributed
|
||||
"""Load checkpoint through the file path prefixed with pavi. In distributed
|
||||
setting, this function download ckpt at all ranks to different temporary
|
||||
directories.
|
||||
|
||||
|
@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):
|
|||
@CheckpointLoader.register_scheme(
|
||||
prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
|
||||
def load_from_ceph(filename, map_location=None, backend='petrel'):
|
||||
"""load checkpoint through the file path prefixed with s3. In distributed
|
||||
"""Load checkpoint through the file path prefixed with s3. In distributed
|
||||
setting, this function download ckpt at all ranks to different temporary
|
||||
directories.
|
||||
|
||||
|
@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
|
||||
def load_from_torchvision(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with modelzoo or
|
||||
"""Load checkpoint through the file path prefixed with modelzoo or
|
||||
torchvision.
|
||||
|
||||
Args:
|
||||
|
@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
|
||||
def load_from_openmmlab(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with open-mmlab or
|
||||
"""Load checkpoint through the file path prefixed with open-mmlab or
|
||||
openmmlab.
|
||||
|
||||
Args:
|
||||
|
@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes='mmcls://')
|
||||
def load_from_mmcls(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with mmcls.
|
||||
"""Load checkpoint through the file path prefixed with mmcls.
|
||||
|
||||
Args:
|
||||
filename (str): checkpoint file path with mmcls prefix
|
||||
|
|
|
@ -8,8 +8,10 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.logging import HistoryBuffer, print_log
|
||||
from mmengine.registry import LOOPS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import is_list_of
|
||||
from .amp import autocast
|
||||
from .base_loop import BaseLoop
|
||||
from .utils import calc_dynamic_intervals
|
||||
|
@ -363,17 +365,26 @@ class ValLoop(BaseLoop):
|
|||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.fp16 = fp16
|
||||
self.val_loss: Dict[str, HistoryBuffer] = dict()
|
||||
|
||||
def run(self) -> dict:
|
||||
"""Launch validation."""
|
||||
self.runner.call_hook('before_val')
|
||||
self.runner.call_hook('before_val_epoch')
|
||||
self.runner.model.eval()
|
||||
|
||||
# clear val loss
|
||||
self.val_loss.clear()
|
||||
for idx, data_batch in enumerate(self.dataloader):
|
||||
self.run_iter(idx, data_batch)
|
||||
|
||||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
|
||||
if self.val_loss:
|
||||
loss_dict = _parse_losses(self.val_loss, 'val')
|
||||
metrics.update(loss_dict)
|
||||
|
||||
self.runner.call_hook('after_val_epoch', metrics=metrics)
|
||||
self.runner.call_hook('after_val')
|
||||
return metrics
|
||||
|
@ -391,6 +402,9 @@ class ValLoop(BaseLoop):
|
|||
# outputs should be sequence of BaseDataElement
|
||||
with autocast(enabled=self.fp16):
|
||||
outputs = self.runner.model.val_step(data_batch)
|
||||
|
||||
outputs, self.val_loss = _update_losses(outputs, self.val_loss)
|
||||
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'after_val_iter',
|
||||
|
@ -435,17 +449,26 @@ class TestLoop(BaseLoop):
|
|||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.fp16 = fp16
|
||||
self.test_loss: Dict[str, HistoryBuffer] = dict()
|
||||
|
||||
def run(self) -> dict:
|
||||
"""Launch test."""
|
||||
self.runner.call_hook('before_test')
|
||||
self.runner.call_hook('before_test_epoch')
|
||||
self.runner.model.eval()
|
||||
|
||||
# clear test loss
|
||||
self.test_loss.clear()
|
||||
for idx, data_batch in enumerate(self.dataloader):
|
||||
self.run_iter(idx, data_batch)
|
||||
|
||||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
|
||||
if self.test_loss:
|
||||
loss_dict = _parse_losses(self.test_loss, 'test')
|
||||
metrics.update(loss_dict)
|
||||
|
||||
self.runner.call_hook('after_test_epoch', metrics=metrics)
|
||||
self.runner.call_hook('after_test')
|
||||
return metrics
|
||||
|
@ -462,9 +485,66 @@ class TestLoop(BaseLoop):
|
|||
# predictions should be sequence of BaseDataElement
|
||||
with autocast(enabled=self.fp16):
|
||||
outputs = self.runner.model.test_step(data_batch)
|
||||
|
||||
outputs, self.test_loss = _update_losses(outputs, self.test_loss)
|
||||
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'after_test_iter',
|
||||
batch_idx=idx,
|
||||
data_batch=data_batch,
|
||||
outputs=outputs)
|
||||
|
||||
|
||||
def _parse_losses(losses: Dict[str, HistoryBuffer],
|
||||
stage: str) -> Dict[str, float]:
|
||||
"""Parses the raw losses of the network.
|
||||
|
||||
Args:
|
||||
losses (dict): raw losses of the network.
|
||||
stage (str): The stage of loss, e.g., 'val' or 'test'.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The key is the loss name, and the value is the
|
||||
average loss.
|
||||
"""
|
||||
all_loss = 0
|
||||
loss_dict: Dict[str, float] = dict()
|
||||
|
||||
for loss_name, loss_value in losses.items():
|
||||
avg_loss = loss_value.mean()
|
||||
loss_dict[loss_name] = avg_loss
|
||||
if 'loss' in loss_name:
|
||||
all_loss += avg_loss
|
||||
|
||||
loss_dict[f'{stage}_loss'] = all_loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
|
||||
"""Update and record the losses of the network.
|
||||
|
||||
Args:
|
||||
outputs (list): The outputs of the network.
|
||||
losses (dict): The losses of the network.
|
||||
|
||||
Returns:
|
||||
list: The updated outputs of the network.
|
||||
dict: The updated losses of the network.
|
||||
"""
|
||||
if isinstance(outputs[-1],
|
||||
BaseDataElement) and outputs[-1].keys() == ['loss']:
|
||||
loss = outputs[-1].loss # type: ignore
|
||||
outputs = outputs[:-1]
|
||||
else:
|
||||
loss = dict()
|
||||
|
||||
for loss_name, loss_value in loss.items():
|
||||
if loss_name not in losses:
|
||||
losses[loss_name] = HistoryBuffer()
|
||||
if isinstance(loss_value, torch.Tensor):
|
||||
losses[loss_name].update(loss_value.item())
|
||||
elif is_list_of(loss_value, torch.Tensor):
|
||||
for loss_value_i in loss_value:
|
||||
losses[loss_name].update(loss_value_i.item())
|
||||
return outputs, losses
|
||||
|
|
|
@ -579,7 +579,7 @@ class Runner:
|
|||
|
||||
@property
|
||||
def hooks(self):
|
||||
"""list[:obj:`Hook`]: A list of registered hooks."""
|
||||
"""List[:obj:`Hook`]: A list of registered hooks."""
|
||||
return self._hooks
|
||||
|
||||
@property
|
||||
|
@ -720,7 +720,7 @@ class Runner:
|
|||
|
||||
def build_logger(self,
|
||||
log_level: Union[int, str] = 'INFO',
|
||||
log_file: str = None,
|
||||
log_file: Optional[str] = None,
|
||||
**kwargs) -> MMLogger:
|
||||
"""Build a global asscessable MMLogger.
|
||||
|
||||
|
@ -1677,7 +1677,7 @@ class Runner:
|
|||
return '\n'.join(stage_hook_infos)
|
||||
|
||||
def load_or_resume(self) -> None:
|
||||
"""load or resume checkpoint."""
|
||||
"""Load or resume checkpoint."""
|
||||
if self._has_loaded:
|
||||
return None
|
||||
|
||||
|
|
|
@ -387,7 +387,7 @@ class BaseDataElement:
|
|||
return dict(self.metainfo_items())
|
||||
|
||||
def __setattr__(self, name: str, value: Any):
|
||||
"""setattr is only used to set data."""
|
||||
"""Setattr is only used to set data."""
|
||||
if name in ('_metainfo_fields', '_data_fields'):
|
||||
if not hasattr(self, name):
|
||||
super().__setattr__(name, value)
|
||||
|
|
|
@ -135,7 +135,7 @@ class InstanceData(BaseDataElement):
|
|||
"""
|
||||
|
||||
def __setattr__(self, name: str, value: Sized):
|
||||
"""setattr is only used to set data.
|
||||
"""Setattr is only used to set data.
|
||||
|
||||
The value must have the attribute of `__len__` and have the same length
|
||||
of `InstanceData`.
|
||||
|
|
|
@ -92,10 +92,25 @@ class MultiProcessTestCase(TestCase):
|
|||
# Constructor patches current instance test method to
|
||||
# assume the role of the main process and join its subprocesses,
|
||||
# or run the underlying test function.
|
||||
def __init__(self, method_name: str = 'runTest') -> None:
|
||||
def __init__(self,
|
||||
method_name: str = 'runTest',
|
||||
methodName: str = 'runTest') -> None:
|
||||
# methodName is the correct naming in unittest
|
||||
# and testslide uses keyword arguments.
|
||||
# So we need to use both to 1) not break BC and, 2) support testslide.
|
||||
if methodName != 'runTest':
|
||||
method_name = methodName
|
||||
super().__init__(method_name)
|
||||
fn = getattr(self, method_name)
|
||||
setattr(self, method_name, self.join_or_run(fn))
|
||||
try:
|
||||
fn = getattr(self, method_name)
|
||||
setattr(self, method_name, self.join_or_run(fn))
|
||||
except AttributeError as e:
|
||||
if methodName != 'runTest':
|
||||
# we allow instantiation with no explicit method name
|
||||
# but not an *incorrect* or missing method name
|
||||
raise ValueError(
|
||||
f'no such test method in {self.__class__}: {methodName}'
|
||||
) from e
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
|
|
@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
|
|||
check_hash=False,
|
||||
file_name=None):
|
||||
r"""Loads the Torch serialized object at the given URL.
|
||||
|
||||
If downloaded file is a zip file, it will be automatically decompressed
|
||||
If the object is already present in `model_dir`, it's deserialized and
|
||||
returned.
|
||||
|
|
|
@ -67,7 +67,7 @@ class TimeCounter:
|
|||
|
||||
instance.log_interval = log_interval
|
||||
instance.warmup_interval = warmup_interval
|
||||
instance.with_sync = with_sync
|
||||
instance.with_sync = with_sync # type: ignore
|
||||
instance.tag = tag
|
||||
instance.logger = logger
|
||||
|
||||
|
@ -127,7 +127,7 @@ class TimeCounter:
|
|||
self.print_time(elapsed)
|
||||
|
||||
def print_time(self, elapsed: Union[int, float]) -> None:
|
||||
"""print times per count."""
|
||||
"""Print times per count."""
|
||||
if self.__count >= self.warmup_interval:
|
||||
self.__pure_inf_time += elapsed
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):
|
|||
|
||||
def is_seq_of(seq: Any,
|
||||
expected_type: Union[Type, tuple],
|
||||
seq_type: Type = None) -> bool:
|
||||
seq_type: Optional[Type] = None) -> bool:
|
||||
"""Check whether it is a sequence of some type.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
|
|||
else:
|
||||
raise e
|
||||
|
||||
possible_path = osp.join(pkg.location, package)
|
||||
possible_path = osp.join(pkg.location, package) # type: ignore
|
||||
if osp.exists(possible_path):
|
||||
return possible_path
|
||||
else:
|
||||
return osp.join(pkg.location, package2module(package))
|
||||
return osp.join(pkg.location, package2module(package)) # type: ignore
|
||||
|
||||
|
||||
def package2module(package: str):
|
||||
|
|
|
@ -3,7 +3,7 @@ import sys
|
|||
from collections.abc import Iterable
|
||||
from multiprocessing import Pool
|
||||
from shutil import get_terminal_size
|
||||
from typing import Callable, Sequence
|
||||
from typing import Callable, Optional, Sequence
|
||||
|
||||
from .timer import Timer
|
||||
|
||||
|
@ -54,7 +54,7 @@ class ProgressBar:
|
|||
self.timer = Timer()
|
||||
|
||||
def update(self, num_tasks: int = 1):
|
||||
"""update progressbar.
|
||||
"""Update progressbar.
|
||||
|
||||
Args:
|
||||
num_tasks (int): Update step size.
|
||||
|
@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
|
|||
def track_parallel_progress(func: Callable,
|
||||
tasks: Sequence,
|
||||
nproc: int,
|
||||
initializer: Callable = None,
|
||||
initargs: tuple = None,
|
||||
initializer: Optional[Callable] = None,
|
||||
initargs: Optional[tuple] = None,
|
||||
bar_width: int = 50,
|
||||
chunksize: int = 1,
|
||||
skip_first: bool = False,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from multiprocessing import Pool
|
||||
from typing import Callable, Iterable, Sized
|
||||
from typing import Callable, Iterable, Optional, Sized
|
||||
|
||||
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
|
||||
TaskProgressColumn, TextColumn, TimeRemainingColumn)
|
||||
|
@ -47,7 +47,7 @@ def _tasks_with_index(tasks):
|
|||
|
||||
def track_progress_rich(func: Callable,
|
||||
tasks: Iterable = tuple(),
|
||||
task_num: int = None,
|
||||
task_num: Optional[int] = None,
|
||||
nproc: int = 1,
|
||||
chunksize: int = 1,
|
||||
description: str = 'Processing',
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
__version__ = '0.10.4'
|
||||
__version__ = '0.10.7'
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
|
|
|
@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
|
|||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
"""Close an opened object."""
|
||||
pass
|
||||
|
||||
|
||||
|
@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):
|
|||
|
||||
def _dump(self, value_dict: dict, file_path: str,
|
||||
file_format: str) -> None:
|
||||
"""dump dict to file.
|
||||
"""Dump dict to file.
|
||||
|
||||
Args:
|
||||
value_dict (dict) : The dict data to saved.
|
||||
|
@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
|
|||
self._wandb.log(scalar_dict, commit=self._commit)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened wandb object."""
|
||||
"""Close an opened wandb object."""
|
||||
if hasattr(self, '_wandb'):
|
||||
self._wandb.join()
|
||||
|
||||
|
@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
|
|||
self.add_scalar(key, value, step)
|
||||
|
||||
def close(self):
|
||||
"""close an opened tensorboard object."""
|
||||
"""Close an opened tensorboard object."""
|
||||
if hasattr(self, '_tensorboard'):
|
||||
self._tensorboard.close()
|
||||
|
||||
|
@ -1135,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
|
|||
self._neptune[k].append(v, step=step)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
"""Close an opened object."""
|
||||
if hasattr(self, '_neptune'):
|
||||
self._neptune.stop()
|
||||
|
||||
|
@ -1282,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
|
|||
self.add_scalar(key, value, step, **kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened dvclive object."""
|
||||
"""Close an opened dvclive object."""
|
||||
if not hasattr(self, '_dvclive'):
|
||||
return
|
||||
|
||||
|
|
|
@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):
|
|||
|
||||
@master_only
|
||||
def get_backend(self, name) -> 'BaseVisBackend':
|
||||
"""get vis backend by name.
|
||||
"""Get vis backend by name.
|
||||
|
||||
Args:
|
||||
name (str): The name of vis backend
|
||||
|
@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
|
|||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
"""Close an opened object."""
|
||||
for vis_backend in self._vis_backends.values():
|
||||
vis_backend.close()
|
||||
|
||||
|
|
|
@ -843,8 +843,8 @@ class TestConfig:
|
|||
assert cfg_dict['item4'] == 'test'
|
||||
assert '_delete_' not in cfg_dict['item1']
|
||||
|
||||
assert type(cfg_dict['item1']) == ConfigDict
|
||||
assert type(cfg_dict['item2']) == ConfigDict
|
||||
assert type(cfg_dict['item1']) is ConfigDict
|
||||
assert type(cfg_dict['item2']) is ConfigDict
|
||||
|
||||
def _merge_intermediate_variable(self):
|
||||
|
||||
|
|
|
@ -300,8 +300,8 @@ except ImportError:
|
|||
get_inputs.append(filepath)
|
||||
|
||||
with build_temporary_directory() as tmp_dir, \
|
||||
patch.object(backend, 'put', side_effect=put),\
|
||||
patch.object(backend, 'get', side_effect=get),\
|
||||
patch.object(backend, 'put', side_effect=put), \
|
||||
patch.object(backend, 'get', side_effect=get), \
|
||||
patch.object(backend, 'exists', return_value=False):
|
||||
tmp_dir = tmp_dir.replace('\\', '/')
|
||||
dst = f'{tmp_dir}/dir'
|
||||
|
@ -351,7 +351,7 @@ except ImportError:
|
|||
|
||||
with build_temporary_directory() as tmp_dir, \
|
||||
patch.object(backend, 'copyfile_from_local',
|
||||
side_effect=copyfile_from_local),\
|
||||
side_effect=copyfile_from_local), \
|
||||
patch.object(backend, 'exists', return_value=False):
|
||||
backend.copytree_from_local(tmp_dir, self.petrel_dir)
|
||||
|
||||
|
@ -427,7 +427,7 @@ except ImportError:
|
|||
def remove(filepath):
|
||||
inputs.append(filepath)
|
||||
|
||||
with build_temporary_directory() as tmp_dir,\
|
||||
with build_temporary_directory() as tmp_dir, \
|
||||
patch.object(backend, 'remove', side_effect=remove):
|
||||
backend.rmtree(tmp_dir)
|
||||
|
||||
|
|
|
@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
|
|||
class TestAveragedModel(TestCase):
|
||||
"""Test the AveragedModel class.
|
||||
|
||||
Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
|
||||
Some test cases are referenced from
|
||||
https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
|
||||
""" # noqa: E501
|
||||
|
||||
def _test_swa_model(self, net_device, avg_device):
|
||||
|
|
|
@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
|
|||
rtol=0)
|
||||
|
||||
def test_scheduler_before_optim_warning(self):
|
||||
"""warns if scheduler is used before optimizer."""
|
||||
"""Warns if scheduler is used before optimizer."""
|
||||
|
||||
def call_sch_before_optim():
|
||||
scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)
|
||||
|
|
|
@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
|
|||
rtol=0)
|
||||
|
||||
def test_scheduler_before_optim_warning(self):
|
||||
"""warns if scheduler is used before optimizer."""
|
||||
"""Warns if scheduler is used before optimizer."""
|
||||
|
||||
def call_sch_before_optim():
|
||||
scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)
|
||||
|
|
|
@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
|
|||
rtol=0)
|
||||
|
||||
def test_scheduler_before_optim_warning(self):
|
||||
"""warns if scheduler is used before optimizer."""
|
||||
"""Warns if scheduler is used before optimizer."""
|
||||
|
||||
def call_sch_before_optim():
|
||||
scheduler = StepParamScheduler(
|
||||
|
|
|
@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():
|
|||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
*args, **kwargs):
|
||||
"""load checkpoints."""
|
||||
"""Load checkpoints."""
|
||||
|
||||
# Names of some parameters in has been changed.
|
||||
version = local_metadata.get('version', None)
|
||||
|
|
|
@ -2226,7 +2226,7 @@ class TestRunner(TestCase):
|
|||
|
||||
@HOOKS.register_module(force=True)
|
||||
class TestWarmupHook(Hook):
|
||||
"""test custom train loop."""
|
||||
"""Test custom train loop."""
|
||||
|
||||
def before_warmup_iter(self, runner, data_batch=None):
|
||||
before_warmup_iter_results.append('before')
|
||||
|
|
|
@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
|
|||
return metainfo, data
|
||||
|
||||
def is_equal(self, x, y):
|
||||
assert type(x) == type(y)
|
||||
assert type(x) is type(y)
|
||||
if isinstance(
|
||||
x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
|
||||
return x == y
|
||||
|
@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):
|
|||
|
||||
# test new() with no arguments
|
||||
new_instances = instances.new()
|
||||
assert type(new_instances) == type(instances)
|
||||
assert type(new_instances) is type(instances)
|
||||
# After deepcopy, the address of new data'element will be same as
|
||||
# origin, but when change new data' element will not effect the origin
|
||||
# element and will have new address
|
||||
|
@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
|
|||
# test new() with arguments
|
||||
metainfo, data = self.setup_data()
|
||||
new_instances = instances.new(metainfo=metainfo, **data)
|
||||
assert type(new_instances) == type(instances)
|
||||
assert type(new_instances) is type(instances)
|
||||
assert id(new_instances.gt_instances) != id(instances.gt_instances)
|
||||
_, new_data = self.setup_data()
|
||||
new_instances.set_data(new_data)
|
||||
|
@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
|
|||
metainfo, data = self.setup_data()
|
||||
instances = BaseDataElement(metainfo=metainfo, **data)
|
||||
new_instances = instances.clone()
|
||||
assert type(new_instances) == type(instances)
|
||||
assert type(new_instances) is type(instances)
|
||||
|
||||
def test_set_metainfo(self):
|
||||
metainfo, _ = self.setup_data()
|
||||
|
|
|
@ -45,7 +45,7 @@ class MockVisBackend:
|
|||
self._add_scalars = True
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
"""Close an opened object."""
|
||||
self._close = True
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue