Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)

Compare commits (17 commits)
Commit SHAs:

390ba2fbb2
d620552c2c
41fa84a9a9
698782f920
e60ab1dde3
8ec837814e
a4475f5eea
a8c74c346d
9124ebf7a2
2e0ab7a922
fc59364d64
4183cf0829
cc3b74b5e8
c9b59962d6
5e736b143b
85c83ba616
d1f1aabf81
.github/workflows/deploy.yml (vendored): 26 lines changed
@@ -1,6 +1,8 @@
 name: deploy

-on: push
+on:
+  - push
+  - workflow_dispatch

 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -9,13 +11,14 @@ concurrency:
 jobs:
   build-n-publish:
     runs-on: ubuntu-latest
-    if: startsWith(github.event.ref, 'refs/tags')
+    if: |
+      startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
-      - name: Set up Python 3.7
+      - name: Set up Python 3.10.13
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.10.13
       - name: Install wheel
         run: pip install wheel
       - name: Build MMEngine
@@ -27,13 +30,14 @@ jobs:
   build-n-publish-lite:
     runs-on: ubuntu-latest
-    if: startsWith(github.event.ref, 'refs/tags')
+    if: |
+      startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
-      - name: Set up Python 3.7
+      - name: Set up Python 3.10.13
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.7
+          python-version: 3.10.13
       - name: Install wheel
         run: pip install wheel
       - name: Build MMEngine-lite
.github/workflows/lint.yml (vendored): 4 lines changed
@@ -11,10 +11,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
-      - name: Set up Python 3.7
+      - name: Set up Python 3.10.15
         uses: actions/setup-python@v2
         with:
-          python-version: 3.7
+          python-version: '3.10.15'
       - name: Install pre-commit hook
         run: |
           pip install pre-commit
.github/workflows/pr_stage_test.yml (vendored): 224 lines changed
@@ -1,5 +1,9 @@
 name: pr_stage_test

+env:
+  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
+
 on:
   pull_request:
     paths-ignore:
@@ -21,152 +25,114 @@ concurrency:
 jobs:
   build_cpu:
     runs-on: ubuntu-22.04
+    defaults:
+      run:
+        shell: bash -l {0}
     strategy:
       matrix:
-        python-version: [3.7]
-        include:
-          - torch: 1.8.1
-            torchvision: 0.9.1
+        python-version: ['3.9']
+        torch: ['2.0.0']
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        run: python -m pip install pip --upgrade
-      - name: Upgrade wheel
-        run: python -m pip install wheel --upgrade
-      - name: Install PyTorch
-        run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
-        run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      - name: Run unittests and generate coverage report
-        run: |
-          coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
-          coverage xml
-          coverage report -m
-      # Upload coverage report for python3.7 && pytorch1.8.1 cpu
-      - name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v3
-        with:
-          file: ./coverage.xml
-          flags: unittests
-          env_vars: OS,PYTHON
-          name: codecov-umbrella
-          fail_ci_if_error: false
+      - name: Check out repo
+        uses: actions/checkout@v3
+      - name: Setup conda env
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          auto-update-conda: true
+          miniconda-version: "latest"
+          use-only-tar-bz2: true
+          activate-environment: test
+          python-version: ${{ matrix.python-version }}
+      - name: Update pip
+        run: |
+          python -m pip install --upgrade pip wheel
+      - name: Install dependencies
+        run: |
+          apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
+          python -m pip install torch==${{matrix.torch}}
+          python -m pip install -e . -v
+          python -m pip install -r requirements/tests.txt
+          python -m pip install openmim
+          mim install mmcv coverage
+      - name: Run unit tests with coverage
+        run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
+      - name: Upload Coverage to Codecov
+        uses: codecov/codecov-action@v3

-  build_cu102:
+  build_gpu:
     runs-on: ubuntu-22.04
-    container:
-      image: pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel
+    defaults:
+      run:
+        shell: bash -l {0}
     env:
       MKL_THREADING_LAYER: GNU
     strategy:
       matrix:
-        python-version: [3.7]
+        python-version: ['3.9','3.10']
+        torch: ['2.0.0','2.3.1','2.5.1']
+        cuda: ['cu118']
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        run: pip install pip --upgrade
-      - name: Fetch GPG keys
-        run: |
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
-      - name: Install system dependencies
-        run: apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
-        run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      - name: Run unittests and generate coverage report
-        run: |
-          coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
-          coverage xml
-          coverage report -m
+      - name: Check out repo
+        uses: actions/checkout@v3
+      - name: Setup conda env
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          auto-update-conda: true
+          miniconda-version: "latest"
+          use-only-tar-bz2: true
+          activate-environment: test
+          python-version: ${{ matrix.python-version }}
+      - name: Update pip
+        run: |
+          python -m pip install --upgrade pip wheel
+      - name: Install dependencies
+        run: |
+          apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
+          python -m pip install torch==${{matrix.torch}} --index-url https://download.pytorch.org/whl/${{matrix.cuda}}
+          python -m pip install -e . -v
+          python -m pip install -r requirements/tests.txt
+          python -m pip install openmim
+          mim install mmcv coverage
+      - name: Run unit tests with coverage
+        run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
+      - name: Upload Coverage to Codecov
+        uses: codecov/codecov-action@v3

-  build_cu117:
-    runs-on: ubuntu-22.04
-    container:
-      image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
-    strategy:
-      matrix:
-        python-version: [3.9]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        run: pip install pip --upgrade
-      - name: Fetch GPG keys
-        run: |
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
-          apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
-      - name: Install system dependencies
-        run: apt-get update && apt-get install -y git ffmpeg libturbojpeg
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
-        run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      # Distributed related unit test may randomly error in PyTorch 1.13.0
-      - name: Run unittests and generate coverage report
-        run: |
-          coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist/
-          coverage xml
-          coverage report -m

-  build_windows:
-    runs-on: windows-2022
-    strategy:
-      matrix:
-        python-version: [3.7]
-        platform: [cpu, cu111]
-        torch: [1.8.1]
-        torchvision: [0.9.1]
-        include:
-          - python-version: 3.8
-            platform: cu118
-            torch: 2.1.0
-            torchvision: 0.16.0
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Upgrade pip
-        # Windows CI could fail If we call `pip install pip --upgrade` directly.
-        run: python -m pip install pip --upgrade
-      - name: Install PyTorch
-        run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
-      - name: Build MMEngine from source
-        run: pip install -e . -v
-      - name: Install unit tests dependencies
-        run: |
-          pip install -r requirements/tests.txt
-          pip install openmim
-          mim install mmcv
-      - name: Run CPU unittests
-        run: pytest tests/ --ignore tests/test_dist
-        if: ${{ matrix.platform == 'cpu' }}
-      - name: Run GPU unittests
-        # Skip testing distributed related unit tests since the memory of windows CI is limited
-        run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
-        if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
+  # build_windows:
+  #   runs-on: windows-2022
+  #   strategy:
+  #     matrix:
+  #       python-version: [3.9]
+  #       platform: [cpu, cu111]
+  #       torch: [1.8.1]
+  #       torchvision: [0.9.1]
+  #       include:
+  #         - python-version: 3.8
+  #           platform: cu118
+  #           torch: 2.1.0
+  #           torchvision: 0.16.0
+  #   steps:
+  #     - uses: actions/checkout@v3
+  #     - name: Set up Python ${{ matrix.python-version }}
+  #       uses: actions/setup-python@v4
+  #       with:
+  #         python-version: ${{ matrix.python-version }}
+  #     - name: Upgrade pip
+  #       # Windows CI could fail If we call `pip install pip --upgrade` directly.
+  #       run: python -m pip install pip wheel --upgrade
+  #     - name: Install PyTorch
+  #       run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
+  #     - name: Build MMEngine from source
+  #       run: pip install -e . -v
+  #     - name: Install unit tests dependencies
+  #       run: |
+  #         pip install -r requirements/tests.txt
+  #         pip install openmim
+  #         mim install mmcv
+  #     - name: Run CPU unittests
+  #       run: pytest tests/ --ignore tests/test_dist
+  #       if: ${{ matrix.platform == 'cpu' }}
+  #     - name: Run GPU unittests
+  #       # Skip testing distributed related unit tests since the memory of windows CI is limited
+  #       run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
+  #       if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
.pre-commit-config-zh-cn.yaml (file name inferred from the Gitee mirror URLs)

@@ -1,7 +1,11 @@
 exclude: ^tests/data/
 repos:
-  - repo: https://gitee.com/openmmlab/mirrors-flake8
-    rev: 5.0.4
+  - repo: https://github.com/pre-commit/pre-commit
+    rev: v4.0.0
+    hooks:
+      - id: validate_manifest
+  - repo: https://github.com/PyCQA/flake8
+    rev: 7.1.1
     hooks:
       - id: flake8
   - repo: https://gitee.com/openmmlab/mirrors-isort
@@ -13,7 +17,7 @@ repos:
     hooks:
       - id: yapf
   - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
-    rev: v4.3.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -55,7 +59,7 @@ repos:
         args: ["mmengine", "tests"]
       - id: remove-improper-eol-in-cn-docs
   - repo: https://gitee.com/openmmlab/mirrors-mypy
-    rev: v0.812
+    rev: v1.2.0
     hooks:
       - id: mypy
         exclude: |-
@@ -63,3 +67,4 @@ repos:
           ^examples
           | ^docs
         )
+        additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
.pre-commit-config.yaml

@@ -1,7 +1,11 @@
 exclude: ^tests/data/
 repos:
+  - repo: https://github.com/pre-commit/pre-commit
+    rev: v4.0.0
+    hooks:
+      - id: validate_manifest
   - repo: https://github.com/PyCQA/flake8
-    rev: 5.0.4
+    rev: 7.1.1
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/isort
@@ -13,7 +17,7 @@ repos:
     hooks:
       - id: yapf
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: check-yaml
@@ -34,12 +38,8 @@ repos:
         - mdformat-openmmlab
         - mdformat_frontmatter
         - linkify-it-py
-  - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.1
-    hooks:
-      - id: codespell
   - repo: https://github.com/myint/docformatter
-    rev: v1.3.1
+    rev: 06907d0
     hooks:
       - id: docformatter
         args: ["--in-place", "--wrap-descriptions", "79"]
@@ -55,7 +55,7 @@ repos:
         args: ["mmengine", "tests"]
       - id: remove-improper-eol-in-cn-docs
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.812
+    rev: v1.2.0
     hooks:
       - id: mypy
         exclude: |-
@@ -63,3 +63,4 @@ repos:
           ^examples
           | ^docs
         )
+        additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
README.md

@@ -59,7 +59,7 @@ English | [简体中文](README_zh-CN.md)

 ## What's New

-v0.10.4 was released on 2024-4-23.
+v0.10.6 was released on 2025-01-13.

 Highlights:
README_zh-CN.md (the "What's New" section: latest release bumped from v0.10.4, released 2024-04-23, to v0.10.5, released 2024-09-11)

@@ -59,7 +59,7 @@

 ## 最近进展

-最新版本 v0.10.4 在 2024.4.23 发布。
+最新版本 v0.10.5 在 2024.9.11 发布。

 版本亮点:
docs/en/notes/changelog.md

@@ -1,5 +1,9 @@
 # Changelog of v0.x

+## v0.10.5 (11/9/2024)
+
+- Fix `_is_builtin_module`. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1571
+
 ## v0.10.4 (23/4/2024)

 ### New Features & Enhancements
mmengine/_strategy/base.py

@@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
                     '"type" and "constructor" are not in '
                     f'optimizer, but got {name}={optim}')
                 optim_wrappers[name] = optim
-            return OptimWrapperDict(**optim_wrappers)
+            return OptimWrapperDict(**optim_wrappers)  # type: ignore
         else:
             raise TypeError('optimizer wrapper should be an OptimWrapper '
                             f'object or dict, but got {optim_wrapper}')
mmengine/_strategy/colossalai.py

@@ -361,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
         map_location: Union[str, Callable] = 'default',
         callback: Optional[Callable] = None,
     ) -> dict:
-        """override this method since colossalai resume optimizer from filename
+        """Override this method since colossalai resume optimizer from filename
         directly."""
         self.logger.info(f'Resume checkpoint from {filename}')
mmengine/_strategy/distributed.py

@@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
         init_dist(launcher, backend, **kwargs)

     def convert_model(self, model: nn.Module) -> nn.Module:
-        """convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
+        """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
         (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

         Args:
mmengine/config/config.py

@@ -393,7 +393,7 @@ class Config:

     def __init__(
         self,
-        cfg_dict: dict = None,
+        cfg_dict: Optional[dict] = None,
         cfg_text: Optional[str] = None,
         filename: Optional[Union[str, Path]] = None,
         env_variables: Optional[dict] = None,
@@ -1227,7 +1227,8 @@ class Config:
             if base_code is not None:
                 base_code = ast.Expression(  # type: ignore
                     body=base_code.value)  # type: ignore
-                base_files = eval(compile(base_code, '', mode='eval'))
+                base_files = eval(compile(base_code, '',
+                                          mode='eval'))  # type: ignore
             else:
                 base_files = []
         elif file_format in ('.yml', '.yaml', '.json'):
@@ -1288,7 +1289,7 @@ class Config:
     def _merge_a_into_b(a: dict,
                         b: dict,
                         allow_list_keys: bool = False) -> dict:
-        """merge dict ``a`` into dict ``b`` (non-inplace).
+        """Merge dict ``a`` into dict ``b`` (non-inplace).

         Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
         in-place modifications.
@@ -1358,22 +1359,22 @@ class Config:

     @property
     def filename(self) -> str:
-        """get file name of config."""
+        """Get file name of config."""
         return self._filename

     @property
     def text(self) -> str:
-        """get config text."""
+        """Get config text."""
         return self._text

     @property
     def env_variables(self) -> dict:
-        """get used environment variables."""
+        """Get used environment variables."""
         return self._env_variables

     @property
     def pretty_text(self) -> str:
-        """get formatted python config text."""
+        """Get formatted python config text."""

         indent = 4

@@ -1727,17 +1728,17 @@ class Config:


 class DictAction(Action):
-    """
-    argparse action to split an argument into KEY=VALUE form
-    on the first = and append to a dictionary. List options can
-    be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
-    brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
-    list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
-    """
+    """Argparse action to split an argument into KEY=VALUE form on the first =
+    and append to a dictionary.
+
+    List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
+    or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
+    brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
+    """

     @staticmethod
     def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
-        """parse int/float/bool value in the string."""
+        """Parse int/float/bool value in the string."""
         try:
             return int(val)
         except ValueError:
@@ -1822,7 +1823,7 @@ class DictAction(Action):
                  parser: ArgumentParser,
                  namespace: Namespace,
                  values: Union[str, Sequence[Any], None],
-                 option_string: str = None):
+                 option_string: str = None):  # type: ignore
         """Parse Variables in string and add them into argparser.

         Args:
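The `DictAction` docstring above describes the whole contract; as a quick illustration, here is a minimal sketch of how it is typically wired into an argparse parser (the option name `--cfg-options` is only an example):

from argparse import ArgumentParser

from mmengine.config import DictAction

parser = ArgumentParser()
# Each KEY=VALUE pair is split on the first '=' and collected into a dict.
parser.add_argument('--cfg-options', nargs='+', action=DictAction)

args = parser.parse_args(['--cfg-options', 'lr=0.01', 'ratios=[(1,2),(3,4)]'])
print(args.cfg_options)  # {'lr': 0.01, 'ratios': [(1, 2), (3, 4)]}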
mmengine/config/utils.py

@@ -12,6 +12,7 @@ from mmengine.fileio import load
 from mmengine.utils import check_file_exist

 PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable))
+SYSTEM_PYTHON_PREFIX = '/usr/lib/python'

 MODULE2PACKAGE = {
     'mmcls': 'mmcls',
@@ -176,7 +177,8 @@ def _is_builtin_module(module_name: str) -> bool:
         return False
     origin_path = osp.abspath(origin_path)
     if ('site-package' in origin_path or 'dist-package' in origin_path
-            or not origin_path.startswith(PYTHON_ROOT_DIR)):
+            or not origin_path.startswith(
+                (PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))):
         return False
     else:
         return True
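The fix relies on `str.startswith` accepting a tuple of prefixes, so a module now counts as builtin if its path sits under either the interpreter prefix or the system-wide prefix. A minimal sketch of the idiom (paths are illustrative):

prefixes = ('/usr/lib/python', '/opt/python')
print('/usr/lib/python3.10/os.py'.startswith(prefixes))   # True
print('/home/user/project/mod.py'.startswith(prefixes))   # False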
mmengine/device/utils.py

@@ -16,6 +16,12 @@ try:
 except Exception:
     IS_NPU_AVAILABLE = False

+try:
+    import torch_mlu  # noqa: F401
+    IS_MLU_AVAILABLE = hasattr(torch, 'mlu') and torch.mlu.is_available()
+except Exception:
+    IS_MLU_AVAILABLE = False
+
 try:
     import torch_dipu  # noqa: F401
     IS_DIPU_AVAILABLE = True
@@ -64,7 +70,7 @@ def is_npu_available() -> bool:

 def is_mlu_available() -> bool:
     """Returns True if Cambricon PyTorch and mlu devices exist."""
-    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+    return IS_MLU_AVAILABLE
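The change moves the MLU probe to import time: any failure (missing vendor package, broken driver) simply marks the backend unavailable instead of raising later. A small usage sketch of the resulting helper, with the device fallback chain as an assumed example:

import torch

from mmengine.device import is_mlu_available

# The probes behind these helpers ran once at import time, so the calls
# themselves are cheap and safe even when the vendor plugin is absent.
if is_mlu_available():
    device = 'mlu'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
x = torch.zeros(2, 3).to(device)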
mmengine/dist/utils.py (vendored): 2 lines changed
@@ -563,7 +563,7 @@ def cast_data_device(
         Tensor or list or dict: ``data`` was casted to ``device``.
     """
     if out is not None:
-        if type(data) != type(out):
+        if type(data) is not type(out):
             raise TypeError(
                 'out should be the same type with data, but got data is '
                 f'{type(data)} and out is {type(data)}')
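Comparing types with `is not` follows the flake8 E721 rule: type objects are singletons, so an identity check is both faster and immune to custom `__eq__` overloads. A one-line illustration:

a, b = [1], (1,)
print(type(a) is not type(b))  # True: list and tuple are different type objects
print(type(a) is type([]))     # True: type objects are singletons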
mmengine/evaluator/metric.py

@@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
         self.out_file_path = out_file_path

     def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
-        """transfer tensors in predictions to CPU."""
+        """Transfer tensors in predictions to CPU."""
         self.results.extend(_to_cpu(predictions))

     def compute_metrics(self, results: list) -> dict:
-        """dump the prediction results to a pickle file."""
+        """Dump the prediction results to a pickle file."""
         dump(results, self.out_file_path)
         print_log(
             f'Results has been saved to {self.out_file_path}.',
@@ -188,7 +188,7 @@ class DumpResults(BaseMetric):


 def _to_cpu(data: Any) -> Any:
-    """transfer all tensors and BaseDataElement to cpu."""
+    """Transfer all tensors and BaseDataElement to cpu."""
     if isinstance(data, (Tensor, BaseDataElement)):
         return data.to('cpu')
     elif isinstance(data, list):
mmengine/hooks/profiler_hook.py

@@ -233,7 +233,7 @@ class ProfilerHook(Hook):
             self._export_chrome_trace(runner)

     def after_train_iter(self, runner, batch_idx, data_batch, outputs):
-        """profiler will call `step` method if it is not closed."""
+        """Profiler will call `step` method if it is not closed."""
         if not self._closed:
             self.profiler.step()
         if runner.iter == self.profile_times - 1 and not self.by_epoch:
mmengine/logging/history_buffer.py

@@ -58,7 +58,7 @@ class HistoryBuffer:
         self._statistics_methods.setdefault('mean', HistoryBuffer.mean)

     def update(self, log_val: Union[int, float], count: int = 1) -> None:
-        """update the log history.
+        """Update the log history.

         If the length of the buffer exceeds ``self._max_length``, the oldest
         element will be removed from the buffer.
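For context, `HistoryBuffer` keeps a bounded log of scalar values plus counts; a short usage sketch with illustrative values:

from mmengine.logging import HistoryBuffer

buf = HistoryBuffer(max_length=4)   # oldest entries are evicted past 4
for loss in [0.9, 0.8, 0.7, 0.6, 0.5]:
    buf.update(loss, count=1)       # count can weight a value, e.g. batch size
print(buf.current())                # 0.5, the most recent value
print(buf.mean())                   # mean over the kept window only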
mmengine/logging/logger.py

@@ -443,7 +443,6 @@ def _get_host_info() -> str:
         host = f'{getuser()}@{gethostname()}'
     except Exception as e:
         warnings.warn(f'Host or user not found: {str(e)}')
-    finally:
-        return host
+    return host
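Dropping the `finally:` matters beyond style: a `return` inside a `finally` block silently discards any in-flight exception. A tiny demonstration of that pitfall:

def swallowed():
    try:
        raise RuntimeError('boom')
    finally:
        return 'ok'  # the in-flight RuntimeError is discarded here

print(swallowed())  # prints 'ok'; no exception propagates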
mmengine/model/base_model/data_preprocessor.py

@@ -253,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
             dict or list: Data in the same format as the model input.
         """
         data = self.cast_data(data)  # type: ignore
-        _batch_inputs = data['inputs']
+        _batch_inputs = data['inputs']  # type: ignore
         # Process data with `pseudo_collate`.
         if is_seq_of(_batch_inputs, torch.Tensor):
             batch_inputs = []
             for _batch_input in _batch_inputs:
                 # channel transform
                 if self._channel_conversion:
-                    _batch_input = _batch_input[[2, 1, 0], ...]
+                    _batch_input = _batch_input[[2, 1, 0], ...]  # type: ignore
                 # Convert to float after channel conversion to ensure
                 # efficiency
-                _batch_input = _batch_input.float()
+                _batch_input = _batch_input.float()  # type: ignore
                 # Normalization.
                 if self._enable_normalize:
                     if self.mean.shape[0] == 3:
@@ -302,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
         else:
             raise TypeError('Output of `cast_data` should be a dict of '
                             'list/tuple with inputs and data_samples, '
-                            f'but got {type(data)}: {data}')
+                            f'but got {type(data)}: {data}')  # type: ignore
-        data['inputs'] = batch_inputs
+        data['inputs'] = batch_inputs  # type: ignore
-        data.setdefault('data_samples', None)
+        data.setdefault('data_samples', None)  # type: ignore
-        return data
+        return data  # type: ignore
mmengine/model/weight_init.py

@@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):


 def bias_init_with_prob(prior_prob):
-    """initialize conv/fc bias value according to a given probability value."""
+    """Initialize conv/fc bias value according to a given probability value."""
     bias_init = float(-np.log((1 - prior_prob) / prior_prob))
     return bias_init

@@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
                   std: float = 1.,
                   a: float = -2.,
                   b: float = 2.) -> Tensor:
-    r"""Fills the input Tensor with values drawn from a truncated
-    normal distribution. The values are effectively drawn from the
-    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
-    with values outside :math:`[a, b]` redrawn until they are within
-    the bounds. The method used for generating the random values works
-    best when :math:`a \leq \text{mean} \leq b`.
+    r"""Fills the input Tensor with values drawn from a truncated normal
+    distribution. The values are effectively drawn from the normal distribution
+    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
+    :math:`[a, b]` redrawn until they are within the bounds. The method used
+    for generating the random values works best when :math:`a \leq \text{mean}
+    \leq b`.

     Modified from
     https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
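`bias_init_with_prob` is just the inverse sigmoid (logit): with bias b = -log((1 - p) / p) and zero input weights, sigmoid(b) = p, so a freshly initialized classifier predicts the target prior probability. A quick check:

import numpy as np

def bias_init_with_prob(prior_prob):
    return float(-np.log((1 - prior_prob) / prior_prob))

p = 0.01
b = bias_init_with_prob(p)   # about -4.595
print(1 / (1 + np.exp(-b)))  # about 0.01: sigmoid(b) recovers p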
mmengine/model/wrappers/fully_sharded_distributed.py

@@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
         auto_wrap_policy: Union[str, Callable, None] = None,
         backward_prefetch: Union[str, BackwardPrefetch, None] = None,
         mixed_precision: Union[dict, MixedPrecision, None] = None,
-        param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
+        param_init_fn: Union[str, Callable[
+            [nn.Module], None]] = None,  # type: ignore # noqa: E501
         use_orig_params: bool = True,
         **kwargs,
     ):
@@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
         optim: torch.optim.Optimizer,
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
-        """copied from pytorch 2.0.1 which has fixed some bugs."""
+        """Copied from pytorch 2.0.1 which has fixed some bugs."""
         state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
             model)
         return FullyShardedDataParallel._optim_state_dict_impl(
@@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
         state_dict_config: Optional[StateDictConfig] = None,
         optim_state_dict_config: Optional[OptimStateDictConfig] = None,
     ) -> StateDictSettings:
-        """copied from pytorch 2.0.1 which has fixed some bugs."""
+        """Copied from pytorch 2.0.1 which has fixed some bugs."""
         import torch.distributed.fsdp._traversal_utils as traversal_utils
         _state_dict_type_to_config = {
             StateDictType.FULL_STATE_DICT: FullStateDictConfig,
mmengine/optim/optimizer/apex_optimizer_wrapper.py

@@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
         self._inner_count += 1

     def state_dict(self) -> dict:
-        """Get the state dictionary of :attr:`optimizer` and
-        :attr:`apex_amp`.
+        """Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.

         Based on the state dictionary of the optimizer, the returned state
         dictionary will add a key named "apex_amp".
mmengine/optim/optimizer/builder.py

@@ -25,6 +25,10 @@ def register_torch_optimizers() -> List[str]:
         _optim = getattr(torch.optim, module_name)
         if inspect.isclass(_optim) and issubclass(_optim,
                                                   torch.optim.Optimizer):
-            OPTIMIZERS.register_module(module=_optim)
+            if module_name == 'Adafactor':
+                OPTIMIZERS.register_module(
+                    name='TorchAdafactor', module=_optim)
+            else:
+                OPTIMIZERS.register_module(module=_optim)
             torch_optimizers.append(module_name)
     return torch_optimizers
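`Registry.register_module` with an explicit `name=` decouples the lookup key from the class name, which is what lets two different `Adafactor` implementations coexist in the same registry. A minimal sketch of the mechanism (the toy registry is illustrative):

from mmengine.registry import Registry

TOYS = Registry('toy')

@TOYS.register_module()  # registered under the class name 'Widget'
class Widget:
    pass

TOYS.register_module(name='WidgetV2', module=Widget)  # same class, new key
print(TOYS.get('Widget') is TOYS.get('WidgetV2'))  # True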
mmengine/optim/optimizer/default_constructor.py

@@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
         self._validate_cfg()

     def _validate_cfg(self) -> None:
-        """verify the correctness of the config."""
+        """Verify the correctness of the config."""
         if not isinstance(self.paramwise_cfg, dict):
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
@@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
             raise ValueError('base_wd should not be None')

     def _is_in(self, param_group: dict, param_group_list: list) -> bool:
-        """check whether the `param_group` is in the`param_group_list`"""
+        """Check whether the `param_group` is in the`param_group_list`"""
         assert is_list_of(param_group_list, dict)
         param = set(param_group['params'])
         param_set = set()
mmengine/optim/optimizer/optimizer_wrapper_dict.py

@@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
             self.optim_wrappers[name].load_state_dict(_state_dict)

     def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
-        """A generator to get the name and corresponding
-        :obj:`OptimWrapper`"""
+        """A generator to get the name and corresponding :obj:`OptimWrapper`"""
         yield from self.optim_wrappers.items()

     def values(self) -> Iterator[OptimWrapper]:
mmengine/optim/scheduler/lr_scheduler.py

@@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):

 @PARAM_SCHEDULERS.register_module()
 class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
-    r"""Sets the learning rate of each parameter group according to the
-    1cycle learning rate policy. The 1cycle policy anneals the learning
-    rate from an initial learning rate to some maximum learning rate and then
-    from that maximum learning rate to some minimum learning rate much lower
-    than the initial learning rate.
-    This policy was initially described in the paper `Super-Convergence:
-    Very Fast Training of Neural Networks Using Large Learning Rates`_.
+    r"""Sets the learning rate of each parameter group according to the 1cycle
+    learning rate policy. The 1cycle policy anneals the learning rate from an
+    initial learning rate to some maximum learning rate and then from that
+    maximum learning rate to some minimum learning rate much lower than the
+    initial learning rate. This policy was initially described in the paper
+    `Super-Convergence: Very Fast Training of Neural Networks Using Large
+    Learning Rates`_.

     The 1cycle learning rate policy changes the learning rate after every
     batch. `step` should be called after a batch has been used for training.

mmengine/optim/scheduler/param_scheduler.py

@@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):

 @PARAM_SCHEDULERS.register_module()
 class CosineAnnealingParamScheduler(_ParamScheduler):
-    r"""Set the parameter value of each parameter group using a cosine
-    annealing schedule, where :math:`\eta_{max}` is set to the initial value
-    and :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+    r"""Set the parameter value of each parameter group using a cosine annealing
+    schedule, where :math:`\eta_{max}` is set to the initial value and
+    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

     .. math::
         \begin{aligned}
@@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):

     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
-    """
+    """  # noqa: E501

     def __init__(self,
                  optimizer: Union[Optimizer, BaseOptimWrapper],
@@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):

 @PARAM_SCHEDULERS.register_module()
 class OneCycleParamScheduler(_ParamScheduler):
-    r"""Sets the parameters of each parameter group according to the
-    1cycle learning rate policy. The 1cycle policy anneals the learning
-    rate from an initial learning rate to some maximum learning rate and then
-    from that maximum learning rate to some minimum learning rate much lower
-    than the initial learning rate.
-    This policy was initially described in the paper `Super-Convergence:
-    Very Fast Training of Neural Networks Using Large Learning Rates`_.
+    r"""Sets the parameters of each parameter group according to the 1cycle
+    learning rate policy. The 1cycle policy anneals the learning rate from an
+    initial learning rate to some maximum learning rate and then from that
+    maximum learning rate to some minimum learning rate much lower than the
+    initial learning rate. This policy was initially described in the paper
+    `Super-Convergence: Very Fast Training of Neural Networks Using Large
+    Learning Rates`_.

     The 1cycle learning rate policy changes the learning rate after every
     batch. `step` should be called after a batch has been used for training.
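Since the 1cycle policy steps per batch, a short usage sketch may help orient readers; hyperparameters and the per-batch loop are illustrative, and the `eta_max`/`by_epoch` arguments are assumptions based on mmengine's scheduler interface:

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OneCycleLR

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.01)
# by_epoch=False makes step() advance the schedule once per batch.
scheduler = OneCycleLR(optimizer, eta_max=0.1, total_steps=100,
                       by_epoch=False)
for _ in range(100):
    optimizer.step()
    scheduler.step()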
mmengine/registry/build_functions.py

@@ -4,7 +4,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Optional, Union

 from mmengine.config import Config, ConfigDict
-from mmengine.utils import ManagerMixin
+from mmengine.utils import ManagerMixin, digit_version
 from .registry import Registry

 if TYPE_CHECKING:
@@ -232,6 +232,21 @@ def build_model_from_cfg(
     return build_from_cfg(cfg, registry, default_args)


+def build_optimizer_from_cfg(
+        cfg: Union[dict, ConfigDict, Config],
+        registry: Registry,
+        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
+    import torch
+
+    from ..logging import print_log
+    if 'type' in cfg \
+            and 'Adafactor' == cfg['type'] \
+            and digit_version(torch.__version__) >= digit_version('2.5.0'):
+        print_log(
+            'the torch version of Adafactor is registered as TorchAdafactor')
+    return build_from_cfg(cfg, registry, default_args)
+
+
 def build_scheduler_from_cfg(
     cfg: Union[dict, ConfigDict, Config],
     registry: Registry,
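In practice this only changes which registry key resolves to the PyTorch implementation; a hedged usage sketch, assuming torch >= 2.5 where `torch.optim.Adafactor` exists:

import torch.nn as nn

from mmengine.registry import OPTIMIZERS

model = nn.Linear(4, 2)
# Plain 'Adafactor' keeps resolving to the implementation already registered
# under that name; the custom build function above only prints a hint that
# the torch-native optimizer is reachable via 'TorchAdafactor'.
optim = OPTIMIZERS.build(
    dict(type='TorchAdafactor', params=model.parameters(), lr=1e-2))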
mmengine/registry/default_scope.py

@@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
     @classmethod
     @contextmanager
     def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
-        """overwrite the current default scope with `scope_name`"""
+        """Overwrite the current default scope with `scope_name`"""
         if scope_name is None:
             yield
         else:
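A short sketch of how that context manager is used to temporarily redirect registry lookups (the scope name 'mmdet' is illustrative):

from mmengine.registry import DefaultScope

# Registry lookups inside the block resolve against the 'mmdet' scope;
# the previous default scope takes effect again on exit.
with DefaultScope.overwrite_default_scope('mmdet'):
    ...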
mmengine/registry/registry.py

@@ -332,7 +332,7 @@ class Registry:
         return root

     def import_from_location(self) -> None:
-        """import modules from the pre-defined locations in self._location."""
+        """Import modules from the pre-defined locations in self._location."""
         if not self._imported:
             # Avoid circular import
             from ..logging import print_log
mmengine/registry/root.py

@@ -6,8 +6,8 @@ More datails can be found at
 https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
 """

-from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
-                              build_scheduler_from_cfg)
+from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
+                              build_runner_from_cfg, build_scheduler_from_cfg)
 from .registry import Registry

 # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@@ -35,7 +35,7 @@ MODEL_WRAPPERS = Registry('model_wrapper')
 WEIGHT_INITIALIZERS = Registry('weight initializer')

 # mangage all kinds of optimizers like `SGD` and `Adam`
-OPTIMIZERS = Registry('optimizer')
+OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
 # manage optimizer wrapper
 OPTIM_WRAPPERS = Registry('optim_wrapper')
 # manage constructors that customize the optimization hyperparameters.
mmengine/registry/utils.py

@@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
     if current_scope.scope_name != scope:  # type: ignore
         print_log(
             'The current default scope '  # type: ignore
-            f'"{current_scope.scope_name}" is not "{scope}", '
+            f'"{current_scope.scope_name}" is not "{scope}", '  # type: ignore
             '`init_default_scope` will force set the current'
             f'default scope to "{scope}".',
             logger='current',
mmengine/runner/_flexible_runner.py

@@ -540,7 +540,7 @@ class FlexibleRunner:

     @property
     def hooks(self):
-        """list[:obj:`Hook`]: A list of registered hooks."""
+        """List[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks

     @property
@@ -1117,7 +1117,7 @@ class FlexibleRunner:
         return '\n'.join(stage_hook_infos)

     def load_or_resume(self):
-        """load or resume checkpoint."""
+        """Load or resume checkpoint."""
         if self._has_loaded:
             return None

@@ -1539,7 +1539,7 @@ class FlexibleRunner:
                         file_client_args: Optional[dict] = None,
                         save_optimizer: bool = True,
                         save_param_scheduler: bool = True,
-                        meta: dict = None,
+                        meta: Optional[dict] = None,
                         by_epoch: bool = True,
                         backend_args: Optional[dict] = None,
                         ):
mmengine/runner/checkpoint.py

@@ -309,7 +309,7 @@ class CheckpointLoader:

     @classmethod
     def load_checkpoint(cls, filename, map_location=None, logger='current'):
-        """load checkpoint through URL scheme path.
+        """Load checkpoint through URL scheme path.

         Args:
             filename (str): checkpoint file name with given prefix
@@ -332,7 +332,7 @@ class CheckpointLoader:

 @CheckpointLoader.register_scheme(prefixes='')
 def load_from_local(filename, map_location):
-    """load checkpoint by local file path.
+    """Load checkpoint by local file path.

     Args:
         filename (str): local checkpoint file path
@@ -353,7 +353,7 @@ def load_from_http(filename,
                    map_location=None,
                    model_dir=None,
                    progress=os.isatty(0)):
-    """load checkpoint through HTTP or HTTPS scheme path. In distributed
+    """Load checkpoint through HTTP or HTTPS scheme path. In distributed
     setting, this function only download checkpoint at local rank 0.

     Args:
@@ -386,7 +386,7 @@ def load_from_http(filename,

 @CheckpointLoader.register_scheme(prefixes='pavi://')
 def load_from_pavi(filename, map_location=None):
-    """load checkpoint through the file path prefixed with pavi. In distributed
+    """Load checkpoint through the file path prefixed with pavi. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.

@@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):

 @CheckpointLoader.register_scheme(
     prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
 def load_from_ceph(filename, map_location=None, backend='petrel'):
-    """load checkpoint through the file path prefixed with s3. In distributed
+    """Load checkpoint through the file path prefixed with s3. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.

@@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):

 @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
 def load_from_torchvision(filename, map_location=None):
-    """load checkpoint through the file path prefixed with modelzoo or
+    """Load checkpoint through the file path prefixed with modelzoo or
     torchvision.

     Args:
@@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):

 @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
 def load_from_openmmlab(filename, map_location=None):
-    """load checkpoint through the file path prefixed with open-mmlab or
+    """Load checkpoint through the file path prefixed with open-mmlab or
     openmmlab.

     Args:
@@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):

 @CheckpointLoader.register_scheme(prefixes='mmcls://')
 def load_from_mmcls(filename, map_location=None):
-    """load checkpoint through the file path prefixed with mmcls.
+    """Load checkpoint through the file path prefixed with mmcls.

     Args:
         filename (str): checkpoint file path with mmcls prefix
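These loaders all hang off the same prefix-dispatch mechanism shown in the `register_scheme` decorators above; a short usage sketch (the torchvision checkpoint name is illustrative):

from mmengine.runner.checkpoint import CheckpointLoader

# The filename prefix selects the registered loader: '' -> local path,
# 'http(s)://' -> download, 'torchvision://' -> torchvision model zoo, ...
ckpt = CheckpointLoader.load_checkpoint(
    'torchvision://resnet50', map_location='cpu')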
@ -8,8 +8,10 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from mmengine.evaluator import Evaluator
|
from mmengine.evaluator import Evaluator
|
||||||
from mmengine.logging import print_log
|
from mmengine.logging import HistoryBuffer, print_log
|
||||||
from mmengine.registry import LOOPS
|
from mmengine.registry import LOOPS
|
||||||
|
from mmengine.structures import BaseDataElement
|
||||||
|
from mmengine.utils import is_list_of
|
||||||
from .amp import autocast
|
from .amp import autocast
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
from .utils import calc_dynamic_intervals
|
from .utils import calc_dynamic_intervals
|
||||||
@@ -363,17 +365,26 @@ class ValLoop(BaseLoop):
                 logger='current',
                 level=logging.WARNING)
         self.fp16 = fp16
+        self.val_loss: Dict[str, HistoryBuffer] = dict()

     def run(self) -> dict:
         """Launch validation."""
         self.runner.call_hook('before_val')
         self.runner.call_hook('before_val_epoch')
         self.runner.model.eval()
+
+        # clear val loss
+        self.val_loss.clear()
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)

         # compute metrics
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
+
+        if self.val_loss:
+            loss_dict = _parse_losses(self.val_loss, 'val')
+            metrics.update(loss_dict)
+
         self.runner.call_hook('after_val_epoch', metrics=metrics)
         self.runner.call_hook('after_val')
         return metrics
@@ -391,6 +402,9 @@ class ValLoop(BaseLoop):
         # outputs should be sequence of BaseDataElement
         with autocast(enabled=self.fp16):
             outputs = self.runner.model.val_step(data_batch)
+
+        outputs, self.val_loss = _update_losses(outputs, self.val_loss)
+
         self.evaluator.process(data_samples=outputs, data_batch=data_batch)
         self.runner.call_hook(
             'after_val_iter',
@@ -435,17 +449,26 @@ class TestLoop(BaseLoop):
                 logger='current',
                 level=logging.WARNING)
         self.fp16 = fp16
+        self.test_loss: Dict[str, HistoryBuffer] = dict()

     def run(self) -> dict:
         """Launch test."""
         self.runner.call_hook('before_test')
         self.runner.call_hook('before_test_epoch')
         self.runner.model.eval()
+
+        # clear test loss
+        self.test_loss.clear()
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)

         # compute metrics
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
+
+        if self.test_loss:
+            loss_dict = _parse_losses(self.test_loss, 'test')
+            metrics.update(loss_dict)
+
         self.runner.call_hook('after_test_epoch', metrics=metrics)
         self.runner.call_hook('after_test')
         return metrics
@@ -462,9 +485,66 @@ class TestLoop(BaseLoop):
         # predictions should be sequence of BaseDataElement
         with autocast(enabled=self.fp16):
             outputs = self.runner.model.test_step(data_batch)
+
+        outputs, self.test_loss = _update_losses(outputs, self.test_loss)
+
         self.evaluator.process(data_samples=outputs, data_batch=data_batch)
         self.runner.call_hook(
             'after_test_iter',
             batch_idx=idx,
             data_batch=data_batch,
             outputs=outputs)
+
+
+def _parse_losses(losses: Dict[str, HistoryBuffer],
+                  stage: str) -> Dict[str, float]:
+    """Parses the raw losses of the network.
+
+    Args:
+        losses (dict): raw losses of the network.
+        stage (str): The stage of loss, e.g., 'val' or 'test'.
+
+    Returns:
+        dict[str, float]: The key is the loss name, and the value is the
+        average loss.
+    """
+    all_loss = 0
+    loss_dict: Dict[str, float] = dict()
+
+    for loss_name, loss_value in losses.items():
+        avg_loss = loss_value.mean()
+        loss_dict[loss_name] = avg_loss
+        if 'loss' in loss_name:
+            all_loss += avg_loss
+
+    loss_dict[f'{stage}_loss'] = all_loss
+    return loss_dict
+
+
+def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
+    """Update and record the losses of the network.
+
+    Args:
+        outputs (list): The outputs of the network.
+        losses (dict): The losses of the network.
+
+    Returns:
+        list: The updated outputs of the network.
+        dict: The updated losses of the network.
+    """
+    if isinstance(outputs[-1],
+                  BaseDataElement) and outputs[-1].keys() == ['loss']:
+        loss = outputs[-1].loss  # type: ignore
+        outputs = outputs[:-1]
+    else:
+        loss = dict()
+
+    for loss_name, loss_value in loss.items():
+        if loss_name not in losses:
+            losses[loss_name] = HistoryBuffer()
+        if isinstance(loss_value, torch.Tensor):
+            losses[loss_name].update(loss_value.item())
+        elif is_list_of(loss_value, torch.Tensor):
+            for loss_value_i in loss_value:
+                losses[loss_name].update(loss_value_i.item())
+    return outputs, losses
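The two helpers above implement a simple convention: if the last element returned by `val_step`/`test_step` is a `BaseDataElement` whose only key is `loss`, it is stripped from the outputs and each tensor in it is accumulated into a per-name `HistoryBuffer`; at epoch end the buffers are averaged and every entry whose name contains `loss` is summed into a `val_loss`/`test_loss` metric. A minimal sketch of that bookkeeping, using the real `HistoryBuffer` API but made-up loss names and values:

```python
from mmengine.logging import HistoryBuffer

# Simulate two validation iterations' worth of recorded losses. In the real
# loop, _update_losses fills these buffers from the trailing BaseDataElement
# returned by val_step; the names and values here are invented.
val_loss = {'loss_cls': HistoryBuffer(), 'loss_bbox': HistoryBuffer()}
for step_losses in ({'loss_cls': 0.9, 'loss_bbox': 0.4},
                    {'loss_cls': 0.7, 'loss_bbox': 0.2}):
    for name, value in step_losses.items():
        val_loss[name].update(value)

# Mirrors _parse_losses(self.val_loss, 'val'): average each buffer, then sum
# every entry whose name contains 'loss' into '<stage>_loss'.
metrics = {name: float(buf.mean()) for name, buf in val_loss.items()}
metrics['val_loss'] = sum(v for k, v in metrics.items() if 'loss' in k)
print(metrics)  # {'loss_cls': 0.8, 'loss_bbox': 0.3, 'val_loss': 1.1}
```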
@@ -579,7 +579,7 @@ class Runner:

     @property
     def hooks(self):
-        """list[:obj:`Hook`]: A list of registered hooks."""
+        """List[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks

     @property
@@ -720,7 +720,7 @@ class Runner:

     def build_logger(self,
                      log_level: Union[int, str] = 'INFO',
-                     log_file: str = None,
+                     log_file: Optional[str] = None,
                      **kwargs) -> MMLogger:
         """Build a global asscessable MMLogger.

@@ -1677,7 +1677,7 @@ class Runner:
         return '\n'.join(stage_hook_infos)

     def load_or_resume(self) -> None:
-        """load or resume checkpoint."""
+        """Load or resume checkpoint."""
         if self._has_loaded:
             return None

@@ -387,7 +387,7 @@ class BaseDataElement:
         return dict(self.metainfo_items())

     def __setattr__(self, name: str, value: Any):
-        """setattr is only used to set data."""
+        """Setattr is only used to set data."""
         if name in ('_metainfo_fields', '_data_fields'):
             if not hasattr(self, name):
                 super().__setattr__(name, value)
@@ -135,7 +135,7 @@ class InstanceData(BaseDataElement):
     """

     def __setattr__(self, name: str, value: Sized):
-        """setattr is only used to set data.
+        """Setattr is only used to set data.

         The value must have the attribute of `__len__` and have the same length
         of `InstanceData`.
@@ -92,10 +92,25 @@ class MultiProcessTestCase(TestCase):
     # Constructor patches current instance test method to
     # assume the role of the main process and join its subprocesses,
     # or run the underlying test function.
-    def __init__(self, method_name: str = 'runTest') -> None:
+    def __init__(self,
+                 method_name: str = 'runTest',
+                 methodName: str = 'runTest') -> None:
+        # methodName is the correct naming in unittest
+        # and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != 'runTest':
+            method_name = methodName
         super().__init__(method_name)
-        fn = getattr(self, method_name)
-        setattr(self, method_name, self.join_or_run(fn))
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != 'runTest':
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f'no such test method in {self.__class__}: {methodName}'
+                ) from e

     def setUp(self) -> None:
         super().setUp()
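The patched constructor accepts both spellings: the legacy snake_case `method_name` keeps backward compatibility, while the unittest/testslide-style `methodName` now takes precedence whenever it is set. A sketch of the two call paths, assuming `MultiProcessTestCase` is importable from `mmengine.testing` and using a made-up subclass:

```python
from mmengine.testing import MultiProcessTestCase


class MyDistTest(MultiProcessTestCase):  # hypothetical subclass
    def test_something(self):
        pass


# Legacy snake_case keyword still works (backward compatible) ...
case_a = MyDistTest(method_name='test_something')
# ... and the unittest/testslide spelling is now accepted too; a camelCase
# value other than the 'runTest' default takes precedence over method_name.
case_b = MyDistTest(methodName='test_something')
```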
@@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
                  check_hash=False,
                  file_name=None):
         r"""Loads the Torch serialized object at the given URL.
+
         If downloaded file is a zip file, it will be automatically decompressed
         If the object is already present in `model_dir`, it's deserialized and
         returned.
@@ -67,7 +67,7 @@ class TimeCounter:

         instance.log_interval = log_interval
         instance.warmup_interval = warmup_interval
-        instance.with_sync = with_sync
+        instance.with_sync = with_sync  # type: ignore
         instance.tag = tag
         instance.logger = logger

@@ -127,7 +127,7 @@ class TimeCounter:
             self.print_time(elapsed)

     def print_time(self, elapsed: Union[int, float]) -> None:
-        """print times per count."""
+        """Print times per count."""
         if self.__count >= self.warmup_interval:
             self.__pure_inf_time += elapsed

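`TimeCounter` stores its settings as attributes on the instance created in `__new__`, which is why the assignment above needed a `# type: ignore`. A brief usage sketch with a hypothetical workload:

```python
import time

from mmengine.utils.dl_utils import TimeCounter


@TimeCounter(log_interval=2, warmup_interval=1, tag='fake_step')
def fake_step():  # hypothetical workload
    time.sleep(0.01)


for _ in range(4):
    fake_step()  # logs the average time per call every `log_interval` calls
```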
@@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):

 def is_seq_of(seq: Any,
               expected_type: Union[Type, tuple],
-              seq_type: Type = None) -> bool:
+              seq_type: Optional[Type] = None) -> bool:
     """Check whether it is a sequence of some type.

     Args:
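The annotation change matters for type checkers: a parameter annotated `Type` with a `None` default is an implicit `Optional`, which mypy rejects by default since 0.990. A minimal illustration:

```python
from typing import Optional


def before(x: int = None) -> int:  # implicit Optional: flagged by mypy>=0.990
    return 0 if x is None else x


def after(x: Optional[int] = None) -> int:  # explicit: passes strict checks
    return 0 if x is None else x
```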
@@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
         else:
             raise e

-    possible_path = osp.join(pkg.location, package)
+    possible_path = osp.join(pkg.location, package)  # type: ignore
     if osp.exists(possible_path):
         return possible_path
     else:
-        return osp.join(pkg.location, package2module(package))
+        return osp.join(pkg.location, package2module(package))  # type: ignore


 def package2module(package: str):
@@ -3,7 +3,7 @@ import sys
 from collections.abc import Iterable
 from multiprocessing import Pool
 from shutil import get_terminal_size
-from typing import Callable, Sequence
+from typing import Callable, Optional, Sequence

 from .timer import Timer

@@ -54,7 +54,7 @@ class ProgressBar:
         self.timer = Timer()

     def update(self, num_tasks: int = 1):
-        """update progressbar.
+        """Update progressbar.

         Args:
             num_tasks (int): Update step size.
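For reference, `update` advances the bar by `num_tasks` steps per call; a short usage sketch:

```python
from mmengine.utils import ProgressBar

prog_bar = ProgressBar(task_num=10)
for _ in range(10):
    # ... one unit of work ...
    prog_bar.update()  # num_tasks defaults to advancing by one step
```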
@@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
 def track_parallel_progress(func: Callable,
                             tasks: Sequence,
                             nproc: int,
-                            initializer: Callable = None,
-                            initargs: tuple = None,
+                            initializer: Optional[Callable] = None,
+                            initargs: Optional[tuple] = None,
                             bar_width: int = 50,
                             chunksize: int = 1,
                             skip_first: bool = False,
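With the explicit `Optional` defaults, calling the function without an initializer now type-checks cleanly. A usage sketch with a made-up worker function:

```python
from mmengine.utils import track_parallel_progress


def square(x):  # made-up worker
    return x * x


if __name__ == '__main__':  # guard required for multiprocessing on spawn
    results = track_parallel_progress(square, list(range(100)), nproc=4)
```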
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from multiprocessing import Pool
-from typing import Callable, Iterable, Sized
+from typing import Callable, Iterable, Optional, Sized

 from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
                            TaskProgressColumn, TextColumn, TimeRemainingColumn)

@@ -47,7 +47,7 @@ def _tasks_with_index(tasks):

 def track_progress_rich(func: Callable,
                         tasks: Iterable = tuple(),
-                        task_num: int = None,
+                        task_num: Optional[int] = None,
                         nproc: int = 1,
                         chunksize: int = 1,
                         description: str = 'Processing',
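`task_num` exists for the case where `tasks` is a plain iterable whose length cannot be inferred, so `Optional[int]` is the accurate annotation. A sketch feeding a generator, assuming `track_progress_rich` is exported from `mmengine.utils`:

```python
from mmengine.utils import track_progress_rich


def double(x):  # made-up worker
    return 2 * x


# tasks is a generator, so its length must be supplied via task_num.
results = track_progress_rich(
    double, tasks=(i for i in range(50)), task_num=50)
```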
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.

-__version__ = '0.10.4'
+__version__ = '0.10.7'


 def parse_version_info(version_str):
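For reference, the helper defined below the bumped string turns it into a comparable tuple; a quick sketch of the expected parse (the rc behaviour is read off the helper's handling of `rc` suffixes):

```python
from mmengine.version import parse_version_info

assert parse_version_info('0.10.7') == (0, 10, 7)
# rc suffixes become a trailing string component:
# parse_version_info('0.10.7rc1') == (0, 10, 7, 'rc1')
```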
@@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
         pass

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         pass


@@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):

     def _dump(self, value_dict: dict, file_path: str,
               file_format: str) -> None:
-        """dump dict to file.
+        """Dump dict to file.

         Args:
             value_dict (dict) : The dict data to saved.

@@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
         self._wandb.log(scalar_dict, commit=self._commit)

     def close(self) -> None:
-        """close an opened wandb object."""
+        """Close an opened wandb object."""
         if hasattr(self, '_wandb'):
             self._wandb.join()

@@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
         self.add_scalar(key, value, step)

     def close(self):
-        """close an opened tensorboard object."""
+        """Close an opened tensorboard object."""
         if hasattr(self, '_tensorboard'):
             self._tensorboard.close()

@@ -1135,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
         self._neptune[k].append(v, step=step)

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         if hasattr(self, '_neptune'):
             self._neptune.stop()

@@ -1282,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
         self.add_scalar(key, value, step, **kwargs)

     def close(self) -> None:
-        """close an opened dvclive object."""
+        """Close an opened dvclive object."""
         if not hasattr(self, '_dvclive'):
             return

@@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):

     @master_only
     def get_backend(self, name) -> 'BaseVisBackend':
-        """get vis backend by name.
+        """Get vis backend by name.

         Args:
             name (str): The name of vis backend

@@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
         pass

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         for vis_backend in self._vis_backends.values():
             vis_backend.close()
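`Visualizer.close` simply fans out to each configured backend's `close`, so a single call tears everything down. A usage sketch with a real backend type and a placeholder save directory:

```python
from mmengine.visualization import Visualizer

vis = Visualizer(
    vis_backends=[dict(type='LocalVisBackend')], save_dir='./vis_out')
vis.add_scalar('acc', 0.9, step=1)
vis.close()  # fans out to LocalVisBackend.close() and any other backend
```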
@@ -843,8 +843,8 @@ class TestConfig:
         assert cfg_dict['item4'] == 'test'
         assert '_delete_' not in cfg_dict['item1']

-        assert type(cfg_dict['item1']) == ConfigDict
-        assert type(cfg_dict['item2']) == ConfigDict
+        assert type(cfg_dict['item1']) is ConfigDict
+        assert type(cfg_dict['item2']) is ConfigDict

     def _merge_intermediate_variable(self):

@@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
 class TestAveragedModel(TestCase):
     """Test the AveragedModel class.

-    Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
+    Some test cases are referenced from
+    https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
     """  # noqa: E501

     def _test_swa_model(self, net_device, avg_device):
@@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
             rtol=0)

     def test_scheduler_before_optim_warning(self):
-        """warns if scheduler is used before optimizer."""
+        """Warns if scheduler is used before optimizer."""

         def call_sch_before_optim():
             scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)

@@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
             rtol=0)

     def test_scheduler_before_optim_warning(self):
-        """warns if scheduler is used before optimizer."""
+        """Warns if scheduler is used before optimizer."""

         def call_sch_before_optim():
             scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)

@@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
             rtol=0)

     def test_scheduler_before_optim_warning(self):
-        """warns if scheduler is used before optimizer."""
+        """Warns if scheduler is used before optimizer."""

         def call_sch_before_optim():
             scheduler = StepParamScheduler(
@@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():

     def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                               *args, **kwargs):
-        """load checkpoints."""
+        """Load checkpoints."""

         # Names of some parameters in has been changed.
         version = local_metadata.get('version', None)
@@ -2226,7 +2226,7 @@ class TestRunner(TestCase):

     @HOOKS.register_module(force=True)
     class TestWarmupHook(Hook):
-        """test custom train loop."""
+        """Test custom train loop."""

         def before_warmup_iter(self, runner, data_batch=None):
             before_warmup_iter_results.append('before')
@@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
         return metainfo, data

     def is_equal(self, x, y):
-        assert type(x) == type(y)
+        assert type(x) is type(y)
         if isinstance(
                 x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
             return x == y
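The `==` to `is` switch here and in the hunks below silences flake8's E721 and is also the semantically exact check: `type()` returns class objects, and identity compares them directly without going through a potentially overridden `__eq__`. A small illustration:

```python
class A:
    pass


class B(A):
    pass


assert type(A()) is A         # exact-type check, no __eq__ involved
assert type(B()) is not A     # a subclass instance has a different type
assert isinstance(B(), A)     # use isinstance() when subclasses should pass
```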
@@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):

         # test new() with no arguments
         new_instances = instances.new()
-        assert type(new_instances) == type(instances)
+        assert type(new_instances) is type(instances)
         # After deepcopy, the address of new data'element will be same as
         # origin, but when change new data' element will not effect the origin
         # element and will have new address

@@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
         # test new() with arguments
         metainfo, data = self.setup_data()
         new_instances = instances.new(metainfo=metainfo, **data)
-        assert type(new_instances) == type(instances)
+        assert type(new_instances) is type(instances)
         assert id(new_instances.gt_instances) != id(instances.gt_instances)
         _, new_data = self.setup_data()
         new_instances.set_data(new_data)

@@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
         metainfo, data = self.setup_data()
         instances = BaseDataElement(metainfo=metainfo, **data)
         new_instances = instances.clone()
-        assert type(new_instances) == type(instances)
+        assert type(new_instances) is type(instances)

     def test_set_metainfo(self):
         metainfo, _ = self.setup_data()
@@ -45,7 +45,7 @@ class MockVisBackend:
         self._add_scalars = True

     def close(self) -> None:
-        """close an opened object."""
+        """Close an opened object."""
         self._close = True