Compare commits
45 Commits
Author | SHA1 | Date |
---|---|---|
|
390ba2fbb2 | |
|
d620552c2c | |
|
41fa84a9a9 | |
|
698782f920 | |
|
e60ab1dde3 | |
|
8ec837814e | |
|
a4475f5eea | |
|
a8c74c346d | |
|
9124ebf7a2 | |
|
2e0ab7a922 | |
|
fc59364d64 | |
|
4183cf0829 | |
|
cc3b74b5e8 | |
|
c9b59962d6 | |
|
5e736b143b | |
|
85c83ba616 | |
|
d1f1aabf81 | |
|
66fb81f7b3 | |
|
acbc5e46dc | |
|
9ecced821b | |
|
39ed23fae8 | |
|
e258c84824 | |
|
2c4516c622 | |
|
447d3bba2c | |
|
2fe0ecec3d | |
|
c423d0c1da | |
|
9b98405672 | |
|
4df682ba2d | |
|
ba5eed8409 | |
|
f79111ecc0 | |
|
b5f2d5860d | |
|
02f80e8bdd | |
|
cd298e3086 | |
|
396cac19cd | |
|
3d8a611eec | |
|
109cd44c7e | |
|
b51bf60964 | |
|
4a50213c69 | |
|
e4600a6993 | |
|
369f15e27a | |
|
1398e4200e | |
|
8e6fb12b1f | |
|
671f3bcdf4 | |
|
efcd364124 | |
|
504fa4f5cb |
|
@ -1,6 +1,8 @@
|
|||
name: deploy
|
||||
|
||||
on: push
|
||||
on:
|
||||
- push
|
||||
- workflow_dispatch
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
@ -9,13 +11,14 @@ concurrency:
|
|||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
if: |
|
||||
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python 3.10.13
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.7
|
||||
python-version: 3.10.13
|
||||
- name: Install wheel
|
||||
run: pip install wheel
|
||||
- name: Build MMEngine
|
||||
|
@ -27,13 +30,14 @@ jobs:
|
|||
|
||||
build-n-publish-lite:
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
if: |
|
||||
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python 3.10.13
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.7
|
||||
python-version: 3.10.13
|
||||
- name: Install wheel
|
||||
run: pip install wheel
|
||||
- name: Build MMEngine-lite
|
||||
|
|
|
@ -11,10 +11,10 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
- name: Set up Python 3.10.15
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.7
|
||||
python-version: '3.10.15'
|
||||
- name: Install pre-commit hook
|
||||
run: |
|
||||
pip install pre-commit
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
name: pr_stage_test
|
||||
|
||||
env:
|
||||
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
|
||||
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths-ignore:
|
||||
|
@ -21,152 +25,114 @@ concurrency:
|
|||
jobs:
|
||||
build_cpu:
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -l {0}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
include:
|
||||
- torch: 1.8.1
|
||||
torchvision: 0.9.1
|
||||
python-version: ['3.9']
|
||||
torch: ['2.0.0']
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
- name: Check out repo
|
||||
uses: actions/checkout@v3
|
||||
- name: Setup conda env
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
with:
|
||||
auto-update-conda: true
|
||||
miniconda-version: "latest"
|
||||
use-only-tar-bz2: true
|
||||
activate-environment: test
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: python -m pip install pip --upgrade
|
||||
- name: Upgrade wheel
|
||||
run: python -m pip install wheel --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
- name: Update pip
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
- name: Run unittests and generate coverage report
|
||||
python -m pip install --upgrade pip wheel
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
coverage xml
|
||||
coverage report -m
|
||||
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
|
||||
- name: Upload coverage to Codecov
|
||||
apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
|
||||
python -m pip install torch==${{matrix.torch}}
|
||||
python -m pip install -e . -v
|
||||
python -m pip install -r requirements/tests.txt
|
||||
python -m pip install openmim
|
||||
mim install mmcv coverage
|
||||
- name: Run unit tests with coverage
|
||||
run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
- name: Upload Coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
env_vars: OS,PYTHON
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
build_cu102:
|
||||
build_gpu:
|
||||
runs-on: ubuntu-22.04
|
||||
container:
|
||||
image: pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -l {0}
|
||||
env:
|
||||
MKL_THREADING_LAYER: GNU
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
python-version: ['3.9','3.10']
|
||||
torch: ['2.0.0','2.3.1','2.5.1']
|
||||
cuda: ['cu118']
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
- name: Check out repo
|
||||
uses: actions/checkout@v3
|
||||
- name: Setup conda env
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
with:
|
||||
auto-update-conda: true
|
||||
miniconda-version: "latest"
|
||||
use-only-tar-bz2: true
|
||||
activate-environment: test
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Fetch GPG keys
|
||||
- name: Update pip
|
||||
run: |
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
- name: Install system dependencies
|
||||
run: apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
python -m pip install --upgrade pip wheel
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
coverage xml
|
||||
coverage report -m
|
||||
apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
|
||||
python -m pip install torch==${{matrix.torch}} --index-url https://download.pytorch.org/whl/${{matrix.cuda}}
|
||||
python -m pip install -e . -v
|
||||
python -m pip install -r requirements/tests.txt
|
||||
python -m pip install openmim
|
||||
mim install mmcv coverage
|
||||
- name: Run unit tests with coverage
|
||||
run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
|
||||
- name: Upload Coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
|
||||
build_cu117:
|
||||
runs-on: ubuntu-22.04
|
||||
container:
|
||||
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.9]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: pip install pip --upgrade
|
||||
- name: Fetch GPG keys
|
||||
run: |
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
- name: Install system dependencies
|
||||
run: apt-get update && apt-get install -y git ffmpeg libturbojpeg
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
# Distributed related unit test may randomly error in PyTorch 1.13.0
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
|
||||
build_windows:
|
||||
runs-on: windows-2022
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
platform: [cpu, cu111]
|
||||
torch: [1.8.1]
|
||||
torchvision: [0.9.1]
|
||||
include:
|
||||
- python-version: 3.8
|
||||
platform: cu118
|
||||
torch: 2.1.0
|
||||
torchvision: 0.16.0
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
# Windows CI could fail If we call `pip install pip --upgrade` directly.
|
||||
run: python -m pip install pip --upgrade
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
|
||||
- name: Build MMEngine from source
|
||||
run: pip install -e . -v
|
||||
- name: Install unit tests dependencies
|
||||
run: |
|
||||
pip install -r requirements/tests.txt
|
||||
pip install openmim
|
||||
mim install mmcv
|
||||
- name: Run CPU unittests
|
||||
run: pytest tests/ --ignore tests/test_dist
|
||||
if: ${{ matrix.platform == 'cpu' }}
|
||||
- name: Run GPU unittests
|
||||
# Skip testing distributed related unit tests since the memory of windows CI is limited
|
||||
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
|
||||
if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
|
||||
# build_windows:
|
||||
# runs-on: windows-2022
|
||||
# strategy:
|
||||
# matrix:
|
||||
# python-version: [3.9]
|
||||
# platform: [cpu, cu111]
|
||||
# torch: [1.8.1]
|
||||
# torchvision: [0.9.1]
|
||||
# include:
|
||||
# - python-version: 3.8
|
||||
# platform: cu118
|
||||
# torch: 2.1.0
|
||||
# torchvision: 0.16.0
|
||||
# steps:
|
||||
# - uses: actions/checkout@v3
|
||||
# - name: Set up Python ${{ matrix.python-version }}
|
||||
# uses: actions/setup-python@v4
|
||||
# with:
|
||||
# python-version: ${{ matrix.python-version }}
|
||||
# - name: Upgrade pip
|
||||
# # Windows CI could fail If we call `pip install pip --upgrade` directly.
|
||||
# run: python -m pip install pip wheel --upgrade
|
||||
# - name: Install PyTorch
|
||||
# run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
|
||||
# - name: Build MMEngine from source
|
||||
# run: pip install -e . -v
|
||||
# - name: Install unit tests dependencies
|
||||
# run: |
|
||||
# pip install -r requirements/tests.txt
|
||||
# pip install openmim
|
||||
# mim install mmcv
|
||||
# - name: Run CPU unittests
|
||||
# run: pytest tests/ --ignore tests/test_dist
|
||||
# if: ${{ matrix.platform == 'cpu' }}
|
||||
# - name: Run GPU unittests
|
||||
# # Skip testing distributed related unit tests since the memory of windows CI is limited
|
||||
# run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
|
||||
# if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
|
||||
|
|
10
.owners.yml
10
.owners.yml
|
@ -1,10 +0,0 @@
|
|||
assign:
|
||||
strategy:
|
||||
# random
|
||||
daily-shift-based
|
||||
scedule:
|
||||
'*/1 * * * *'
|
||||
assignees:
|
||||
- zhouzaida
|
||||
- HAOCHENYE
|
||||
- C1rN09
|
|
@ -1,7 +1,11 @@
|
|||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://gitee.com/openmmlab/mirrors-flake8
|
||||
rev: 5.0.4
|
||||
- repo: https://github.com/pre-commit/pre-commit
|
||||
rev: v4.0.0
|
||||
hooks:
|
||||
- id: validate_manifest
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.1.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://gitee.com/openmmlab/mirrors-isort
|
||||
|
@ -13,7 +17,7 @@ repos:
|
|||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
|
@ -55,7 +59,7 @@ repos:
|
|||
args: ["mmengine", "tests"]
|
||||
- id: remove-improper-eol-in-cn-docs
|
||||
- repo: https://gitee.com/openmmlab/mirrors-mypy
|
||||
rev: v0.812
|
||||
rev: v1.2.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
exclude: |-
|
||||
|
@ -63,3 +67,4 @@ repos:
|
|||
^examples
|
||||
| ^docs
|
||||
)
|
||||
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
exclude: ^tests/data/
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit
|
||||
rev: v4.0.0
|
||||
hooks:
|
||||
- id: validate_manifest
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 5.0.4
|
||||
rev: 7.1.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
|
@ -13,7 +17,7 @@ repos:
|
|||
hooks:
|
||||
- id: yapf
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-yaml
|
||||
|
@ -34,12 +38,8 @@ repos:
|
|||
- mdformat-openmmlab
|
||||
- mdformat_frontmatter
|
||||
- linkify-it-py
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.2.1
|
||||
hooks:
|
||||
- id: codespell
|
||||
- repo: https://github.com/myint/docformatter
|
||||
rev: v1.3.1
|
||||
rev: 06907d0
|
||||
hooks:
|
||||
- id: docformatter
|
||||
args: ["--in-place", "--wrap-descriptions", "79"]
|
||||
|
@ -55,7 +55,7 @@ repos:
|
|||
args: ["mmengine", "tests"]
|
||||
- id: remove-improper-eol-in-cn-docs
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.812
|
||||
rev: v1.2.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
exclude: |-
|
||||
|
@ -63,3 +63,4 @@ repos:
|
|||
^examples
|
||||
| ^docs
|
||||
)
|
||||
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
|
||||
|
|
84
CODEOWNERS
84
CODEOWNERS
|
@ -1,84 +0,0 @@
|
|||
# IMPORTANT:
|
||||
# This file is ONLY used to subscribe for notifications for PRs
|
||||
# related to a specific file path, and each line is a file pattern followed by
|
||||
# one or more owners.
|
||||
|
||||
# Order is important; the last matching pattern takes the most
|
||||
# precedence.
|
||||
|
||||
# These owners will be the default owners for everything in
|
||||
# the repo. Unless a later match takes precedence,
|
||||
# @global-owner1 and @global-owner2 will be requested for
|
||||
# review when someone opens a pull request.
|
||||
* @zhouzaida @HAOCHENYE
|
||||
|
||||
# Docs
|
||||
/docs/ @C1rN09
|
||||
*.rst @zhouzaida @HAOCHENYE
|
||||
|
||||
# mmengine file
|
||||
# config
|
||||
/mmengine/config/ @HAOCHENYE
|
||||
|
||||
# dataset
|
||||
/mmengine/dataset/ @HAOCHENYE
|
||||
|
||||
# device
|
||||
/mmengine/device/ @zhouzaida
|
||||
|
||||
# dist
|
||||
/mmengine/dist/ @zhouzaida @C1rN09
|
||||
|
||||
# evaluator
|
||||
/mmengine/evaluator/ @RangiLyu @C1rN09
|
||||
|
||||
# fileio
|
||||
/mmengine/fileio/ @zhouzaida
|
||||
|
||||
# hooks
|
||||
/mmengine/hooks/ @zhouzaida @HAOCHENYE
|
||||
/mmengine/hooks/ema_hook.py @RangiLyu
|
||||
|
||||
# hub
|
||||
/mmengine/hub/ @HAOCHENYE @zhouzaida
|
||||
|
||||
# logging
|
||||
/mmengine/logging/ @HAOCHENYE
|
||||
|
||||
# model
|
||||
/mmengine/model/ @HAOCHENYE @C1rN09
|
||||
/mmengine/model/averaged_model.py @RangiLyu
|
||||
/mmengine/model/wrappers/fully_sharded_distributed.py @C1rN09
|
||||
|
||||
# optim
|
||||
/mmengine/optim/ @HAOCHENYE
|
||||
/mmengine/optim/scheduler/ @RangiLyu
|
||||
|
||||
# registry
|
||||
/mmengine/registry/ @C1rN09 @HAOCHENYE
|
||||
|
||||
# runner
|
||||
/mmengine/runner/ @zhouzaida @RangiLyu @HAOCHENYE
|
||||
/mmengine/runner/amp.py @HAOCHENYE
|
||||
/mmengine/runner/log_processor.py @HAOCHENYE
|
||||
/mmengine/runner/checkpoint.py @zhouzaida @C1rN09
|
||||
/mmengine/runner/priority.py @zhouzaida
|
||||
/mmengine/runner/utils.py @zhouzaida @HAOCHENYE
|
||||
|
||||
# structure
|
||||
/mmengine/structures/ @Harold-lkk @HAOCHENYE
|
||||
|
||||
# testing
|
||||
/mmengine/testing/ @zhouzaida
|
||||
|
||||
# utils
|
||||
/mmengine/utils/ @HAOCHENYE @zhouzaida
|
||||
|
||||
# visualization
|
||||
/mmengine/visualization/ @Harold-lkk @HAOCHENYE
|
||||
|
||||
# version
|
||||
/mmengine/__version__.py @zhouzaida
|
||||
|
||||
# unit test
|
||||
/tests/ @zhouzaida @HAOCHENYE
|
70
README.md
70
README.md
|
@ -19,13 +19,14 @@
|
|||
<div> </div>
|
||||
|
||||
[](https://pypi.org/project/mmengine/)
|
||||
[](#installation)
|
||||
[](https://pypi.org/project/mmengine)
|
||||
[](https://github.com/open-mmlab/mmengine/blob/main/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmengine/issues)
|
||||
[](https://github.com/open-mmlab/mmengine/issues)
|
||||
|
||||
[Introduction](#introduction) |
|
||||
[Installation](#installation) |
|
||||
[Get Started](#get-started) |
|
||||
[📘Documentation](https://mmengine.readthedocs.io/en/latest/) |
|
||||
[🛠️Installation](https://mmengine.readthedocs.io/en/latest/get_started/installation.html) |
|
||||
[🤔Reporting Issues](https://github.com/open-mmlab/mmengine/issues/new/choose)
|
||||
|
||||
</div>
|
||||
|
@ -58,58 +59,53 @@ English | [简体中文](README_zh-CN.md)
|
|||
|
||||
## What's New
|
||||
|
||||
v0.10.1 was released on 2023-11-22.
|
||||
v0.10.6 was released on 2025-01-13.
|
||||
|
||||
Highlights:
|
||||
|
||||
- Support installing mmengine-lite with no dependency on opencv. Refer to the [Installation](https://mmengine.readthedocs.io/en/latest/get_started/installation.html#install-mmengine) for more details.
|
||||
- Support custom `artifact_location` in MLflowVisBackend [#1505](#1505)
|
||||
- Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` [#1517](#1517)
|
||||
|
||||
- Support training with [ColossalAI](https://colossalai.org/). Refer to the [Training Large Models](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#colossalai) for more detailed usages.
|
||||
|
||||
- Support gradient checkpointing. Refer to the [Save Memory on GPU](https://mmengine.readthedocs.io/en/latest/common_usage/save_gpu_memory.html#gradient-checkpointing) for more details.
|
||||
|
||||
- Supports multiple visualization backends, including `NeptuneVisBackend`, `DVCLiveVisBackend` and `AimVisBackend`. Refer to [Visualization Backends](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html) for more details.
|
||||
|
||||
Read [Changelog](./docs/en/notes/changelog.md#v0101-22112023) for more details.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Introduction](#introduction)
|
||||
- [Installation](#installation)
|
||||
- [Get Started](#get-started)
|
||||
- [Learn More](#learn-more)
|
||||
- [Contributing](#contributing)
|
||||
- [Citation](#citation)
|
||||
- [License](#license)
|
||||
- [Ecosystem](#ecosystem)
|
||||
- [Projects in OpenMMLab](#projects-in-openmmlab)
|
||||
Read [Changelog](./docs/en/notes/changelog.md#v0104-2342024) for more details.
|
||||
|
||||
## Introduction
|
||||
|
||||
MMEngine is a foundational library for training deep learning models based on PyTorch. It provides a solid engineering foundation and frees developers from writing redundant codes on workflows. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. Moreover, MMEngine is also generic to be applied to non-OpenMMLab projects.
|
||||
MMEngine is a foundational library for training deep learning models based on PyTorch. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. Moreover, MMEngine is also generic to be applied to non-OpenMMLab projects. Its highlights are as follows:
|
||||
|
||||
Major features:
|
||||
**Integrate mainstream large-scale model training frameworks**
|
||||
|
||||
1. **A universal and powerful runner**:
|
||||
- [ColossalAI](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#colossalai)
|
||||
- [DeepSpeed](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#deepspeed)
|
||||
- [FSDP](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#fullyshardeddataparallel-fsdp)
|
||||
|
||||
- Supports training different tasks with a small amount of code, e.g., ImageNet can be trained with only 80 lines of code (400 lines of the original PyTorch example).
|
||||
- Easily compatible with models from popular algorithm libraries such as TIMM, TorchVision, and Detectron2.
|
||||
**Supports a variety of training strategies**
|
||||
|
||||
2. **Open architecture with unified interfaces**:
|
||||
- [Mixed Precision Training](https://mmengine.readthedocs.io/en/latest/common_usage/speed_up_training.html#mixed-precision-training)
|
||||
- [Gradient Accumulation](https://mmengine.readthedocs.io/en/latest/common_usage/save_gpu_memory.html#gradient-accumulation)
|
||||
- [Gradient Checkpointing](https://mmengine.readthedocs.io/en/latest/common_usage/save_gpu_memory.html#gradient-checkpointing)
|
||||
|
||||
- Handles different algorithm tasks with unified APIs, e.g., implement a method and apply it to all compatible models.
|
||||
- Provides a unified abstraction for upper-level algorithm libraries, which supports various back-end devices such as Nvidia CUDA, Mac MPS, AMD, MLU, and more for model training.
|
||||
**Provides a user-friendly configuration system**
|
||||
|
||||
3. **Customizable training process**:
|
||||
- [Pure Python-style configuration files, easy to navigate](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta)
|
||||
- [Plain-text-style configuration files, supporting JSON and YAML](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html)
|
||||
|
||||
- Defines the training process just like playing with Legos.
|
||||
- Provides rich components and strategies.
|
||||
- Complete controls on the training process with different levels of APIs.
|
||||
**Covers mainstream training monitoring platforms**
|
||||
|
||||

|
||||
- [TensorBoard](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#tensorboard) | [WandB](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#wandb) | [MLflow](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#mlflow-wip)
|
||||
- [ClearML](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#clearml) | [Neptune](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#neptune) | [DVCLive](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#dvclive) | [Aim](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#aim)
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Supported PyTorch Versions</summary>
|
||||
|
||||
| MMEngine | PyTorch | Python |
|
||||
| ------------------ | ------------ | -------------- |
|
||||
| main | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
|
||||
| >=0.9.0, \<=0.10.4 | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
|
||||
|
||||
</details>
|
||||
|
||||
Before installing MMEngine, please ensure that PyTorch has been successfully installed following the [official guide](https://pytorch.org/get-started/locally/).
|
||||
|
||||
Install MMEngine
|
||||
|
|
|
@ -19,13 +19,14 @@
|
|||
<div> </div>
|
||||
|
||||
[](https://pypi.org/project/mmengine/)
|
||||
[](#安装)
|
||||
[](https://pypi.org/project/mmengine)
|
||||
[](https://github.com/open-mmlab/mmengine/blob/main/LICENSE)
|
||||
[](https://github.com/open-mmlab/mmengine/issues)
|
||||
[](https://github.com/open-mmlab/mmengine/issues)
|
||||
|
||||
[📘使用文档](https://mmengine.readthedocs.io/zh_CN/latest/) |
|
||||
[🛠️安装教程](https://mmengine.readthedocs.io/zh_CN/latest/get_started/installation.html) |
|
||||
[简介](#简介) |
|
||||
[安装](#安装) |
|
||||
[快速上手](#快速上手) |
|
||||
[📘用户文档](https://mmengine.readthedocs.io/zh_CN/latest/) |
|
||||
[🤔报告问题](https://github.com/open-mmlab/mmengine/issues/new/choose)
|
||||
|
||||
</div>
|
||||
|
@ -58,59 +59,58 @@
|
|||
|
||||
## 最近进展
|
||||
|
||||
最新版本 v0.10.1 在 2023.11.22 发布。
|
||||
最新版本 v0.10.5 在 2024.9.11 发布。
|
||||
|
||||
亮点:
|
||||
版本亮点:
|
||||
|
||||
- 支持安装不依赖于 opencv 的 mmengine-lite 版本。可阅读[安装文档](https://mmengine.readthedocs.io/zh-cn/latest/get_started/installation.html#mmengine)了解用法。
|
||||
- 支持在 MLFlowVisBackend 中自定义 `artifact_location` [#1505](#1505)
|
||||
- 支持在 `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` 使用 `exclude_frozen_parameters` [#1517](#1517)
|
||||
|
||||
- 支持使用 [ColossalAI](https://colossalai.org/) 进行训练。可阅读[大模型训练](https://mmengine.readthedocs.io/zh_CN/latest/common_usage/large_model_training.html#colossalai)了解用法。
|
||||
|
||||
- 支持梯度检查点。详见[用法](https://mmengine.readthedocs.io/zh_CN/latest/common_usage/save_gpu_memory.html#id3)。
|
||||
|
||||
- 支持多种可视化后端,包括`NeptuneVisBackend`、`DVCLiveVisBackend` 和 `AimVisBackend`。可阅读[可视化后端](https://mmengine.readthedocs.io/zh_CN/latest/common_usage/visualize_training_log.html)了解用法。
|
||||
|
||||
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](./docs/en/notes/changelog.md#v0101-22112023)
|
||||
|
||||
## 目录
|
||||
|
||||
- [简介](#简介)
|
||||
- [安装](#安装)
|
||||
- [快速上手](#快速上手)
|
||||
- [了解更多](#了解更多)
|
||||
- [贡献指南](#贡献指南)
|
||||
- [引用](#引用)
|
||||
- [开源许可证](#开源许可证)
|
||||
- [生态项目](#生态项目)
|
||||
- [OpenMMLab 的其他项目](#openmmlab-的其他项目)
|
||||
- [欢迎加入 OpenMMLab 社区](#欢迎加入-openmmlab-社区)
|
||||
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](./docs/en/notes/changelog.md#v0104-2342024)。
|
||||
|
||||
## 简介
|
||||
|
||||
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库。它为开发人员提供了坚实的工程基础,以此避免在工作流上编写冗余代码。作为 OpenMMLab 所有代码库的训练引擎,其在不同研究领域支持了上百个算法。此外,MMEngine 也可以用于非 OpenMMLab 项目中。
|
||||
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库。它作为 OpenMMLab 所有代码库的训练引擎,其在不同研究领域支持了上百个算法。此外,MMEngine 也可以用于非 OpenMMLab 项目中。它的亮点如下:
|
||||
|
||||
主要特性:
|
||||
**集成主流的大模型训练框架**
|
||||
|
||||
1. **通用且强大的执行器**:
|
||||
- [ColossalAI](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/large_model_training.html#colossalai)
|
||||
- [DeepSpeed](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/large_model_training.html#deepspeed)
|
||||
- [FSDP](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/large_model_training.html#fullyshardeddataparallel-fsdp)
|
||||
|
||||
- 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 ImageNet(原始 PyTorch 示例需要 400 行)。
|
||||
- 轻松兼容流行的算法库(如 TIMM、TorchVision 和 Detectron2)中的模型。
|
||||
**支持丰富的训练策略**
|
||||
|
||||
2. **接口统一的开放架构**:
|
||||
- [混合精度训练(Mixed Precision Training)](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/speed_up_training.html#id3)
|
||||
- [梯度累积(Gradient Accumulation)](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/save_gpu_memory.html#id2)
|
||||
- [梯度检查点(Gradient Checkpointing)](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/save_gpu_memory.html#id3)
|
||||
|
||||
- 使用统一的接口处理不同的算法任务,例如,实现一个方法并应用于所有的兼容性模型。
|
||||
- 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。
|
||||
**提供易用的配置系统**
|
||||
|
||||
3. **可定制的训练流程**:
|
||||
- [纯 Python 风格的配置文件,易于跳转](https://mmengine.readthedocs.io/zh-cn/latest/advanced_tutorials/config.html#python-beta)
|
||||
- [纯文本风格的配置文件,支持 JSON 和 YAML](https://mmengine.readthedocs.io/zh-cn/latest/advanced_tutorials/config.html#id1)
|
||||
|
||||
- 定义了“乐高”式的训练流程。
|
||||
- 提供了丰富的组件和策略。
|
||||
- 使用不同等级的 API 控制训练过程。
|
||||
**覆盖主流的训练监测平台**
|
||||
|
||||

|
||||
- [TensorBoard](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#tensorboard) | [WandB](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#wandb) | [MLflow](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#mlflow-wip)
|
||||
- [ClearML](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#clearml) | [Neptune](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#neptune) | [DVCLive](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#dvclive) | [Aim](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#aim)
|
||||
|
||||
**兼容主流的训练芯片**
|
||||
|
||||
- 英伟达 CUDA | 苹果 MPS
|
||||
- 华为 Ascend | 寒武纪 MLU | 摩尔线程 MUSA
|
||||
|
||||
## 安装
|
||||
|
||||
<details>
|
||||
<summary>支持的 PyTorch 版本</summary>
|
||||
|
||||
| MMEngine | PyTorch | Python |
|
||||
| ------------------ | ------------ | -------------- |
|
||||
| main | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
|
||||
| >=0.9.0, \<=0.10.4 | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
|
||||
|
||||
</details>
|
||||
|
||||
在安装 MMEngine 之前,请确保 PyTorch 已成功安装在环境中,可以参考 [PyTorch 官方安装文档](https://pytorch.org/get-started/locally/)。
|
||||
|
||||
安装 MMEngine
|
||||
|
|
|
@ -3,6 +3,9 @@ version: 2
|
|||
formats:
|
||||
- epub
|
||||
|
||||
sphinx:
|
||||
configuration: docs/en/conf.py
|
||||
|
||||
# Set the version of Python and other tools you might need
|
||||
build:
|
||||
os: ubuntu-22.04
|
|
@ -1008,7 +1008,7 @@ In this section, we use MMDetection to demonstrate how to migrate the abstract d
|
|||
|
||||
### 1. Simplify the module interface
|
||||
|
||||
Detector's external interfaces can be significantly simplified and unified. In the training process of a single-stage detection and segmentation algorithm in MMDet 2.X, `SingleStageDetector` requires `img`, `img_metas`, `gt_bboxes`, `gt_labels` and `gt_bboxes_ignore` as the inputs, but `SingleStageInstanceSegmentor` requires `gt_masks` as well. This causes inconsistency in the training interface and affects flexibility.
|
||||
Detector's external interfaces can be significantly simplified and unified. In the training process of a single-stage detection and segmentation algorithm in MMDet 2.X, `SingleStageDetector` requires `img`, `img_metas`, `gt_bboxes`, `gt_labels` and `gt_bboxes_ignore` as the inputs, but `SingleStageInstanceSegmentor` requires `gt_masks` as well. This causes inconsistency in the training interface and affects flexibility.
|
||||
|
||||
```python
|
||||
class SingleStageDetector(BaseDetector):
|
||||
|
|
|
@ -4,7 +4,7 @@ This document provides some third-party optimizers supported by MMEngine, which
|
|||
|
||||
## D-Adaptation
|
||||
|
||||
[D-Adaptation](https://github.com/facebookresearch/dadaptation) provides `DAdaptAdaGrad`, `DAdaptAdam` and `DAdaptSGD` optimziers。
|
||||
[D-Adaptation](https://github.com/facebookresearch/dadaptation) provides `DAdaptAdaGrad`, `DAdaptAdam` and `DAdaptSGD` optimizers.
|
||||
|
||||
```{note}
|
||||
If you use the optimizer provided by D-Adaptation, you need to upgrade mmengine to `0.6.0`.
|
||||
|
@ -35,7 +35,7 @@ runner.train()
|
|||
|
||||
## Lion-Pytorch
|
||||
|
||||
[lion-pytorch](https://github.com/lucidrains/lion-pytorch) provides the `Lion` optimizer。
|
||||
[lion-pytorch](https://github.com/lucidrains/lion-pytorch) provides the `Lion` optimizer.
|
||||
|
||||
```{note}
|
||||
If you use the optimizer provided by Lion-Pytorch, you need to upgrade mmengine to `0.6.0`.
|
||||
|
@ -93,7 +93,7 @@ runner.train()
|
|||
|
||||
## bitsandbytes
|
||||
|
||||
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) provides `AdamW8bit`, `Adam8bit`, `Adagrad8bit`, `PagedAdam8bit`, `PagedAdamW8bit`, `LAMB8bit`, `LARS8bit`, `RMSprop8bit`, `Lion8bit`, `PagedLion8bit` and `SGD8bit` optimziers。
|
||||
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) provides `AdamW8bit`, `Adam8bit`, `Adagrad8bit`, `PagedAdam8bit`, `PagedAdamW8bit`, `LAMB8bit`, `LARS8bit`, `RMSprop8bit`, `Lion8bit`, `PagedLion8bit` and `SGD8bit` optimizers.
|
||||
|
||||
```{note}
|
||||
If you use the optimizer provided by bitsandbytes, you need to upgrade mmengine to `0.9.0`.
|
||||
|
@ -124,7 +124,7 @@ runner.train()
|
|||
|
||||
## transformers
|
||||
|
||||
[transformers](https://github.com/huggingface/transformers) provides `Adafactor` optimzier。
|
||||
[transformers](https://github.com/huggingface/transformers) provides `Adafactor` optimzier.
|
||||
|
||||
```{note}
|
||||
If you use the optimizer provided by transformers, you need to upgrade mmengine to `0.9.0`.
|
||||
|
|
|
@ -30,7 +30,7 @@ train_dataloader = dict(
|
|||
type=dataset_type,
|
||||
data_prefix='data/cifar10',
|
||||
test_mode=False,
|
||||
indices=5000, # set indices=5000,represent every epoch only iterator 5000 samples
|
||||
indices=5000, # set indices=5000, represent every epoch only iterator 5000 samples
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ On the first machine:
|
|||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nnodes 8 \
|
||||
--nnodes 2 \
|
||||
--node_rank 0 \
|
||||
--master_addr 127.0.0.1 \
|
||||
--master_port 29500 \
|
||||
|
@ -38,9 +38,9 @@ On the second machine:
|
|||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nnodes 8 \
|
||||
--nnodes 2 \
|
||||
--node_rank 1 \
|
||||
--master_addr 127.0.0.1 \
|
||||
--master_addr "ip_of_the_first_machine" \
|
||||
--master_port 29500 \
|
||||
--nproc_per_node=8 \
|
||||
examples/distributed_training.py --launcher pytorch
|
||||
|
|
|
@ -87,10 +87,10 @@ OpenMMLab requires the `inferencer(img)` to output a `dict` containing two field
|
|||
|
||||
When performing inference, the following steps are typically executed:
|
||||
|
||||
1. preprocess:Input data preprocessing, including data reading, data preprocessing, data format conversion, etc.
|
||||
1. preprocess: Input data preprocessing, including data reading, data preprocessing, data format conversion, etc.
|
||||
2. forward: Execute `model.forwward`
|
||||
3. visualize:Visualization of predicted results.
|
||||
4. postprocess:Post-processing of predicted results, including result format conversion, exporting predicted results, etc.
|
||||
3. visualize: Visualization of predicted results.
|
||||
4. postprocess: Post-processing of predicted results, including result format conversion, exporting predicted results, etc.
|
||||
|
||||
To improve the user experience of the inferencer, we do not want users to have to configure parameters for each step when performing inference. In other words, we hope that users can simply configure parameters for the `__call__` interface without being aware of the above process and complete the inference.
|
||||
|
||||
|
@ -173,8 +173,8 @@ Initializes and returns the `visualizer` required by the inferencer, which is eq
|
|||
|
||||
Input arguments:
|
||||
|
||||
- inputs:Input data, passed into `__call__`, usually a list of image paths or image data.
|
||||
- batch_size:batch size, passed in by the user when calling `__call__`.
|
||||
- inputs: Input data, passed into `__call__`, usually a list of image paths or image data.
|
||||
- batch_size: batch size, passed in by the user when calling `__call__`.
|
||||
- Other parameters: Passed in by the user and specified in `preprocess_kwargs`.
|
||||
|
||||
Return:
|
||||
|
@ -187,7 +187,7 @@ The `preprocess` function is a generator function by default, which applies the
|
|||
|
||||
Input arguments:
|
||||
|
||||
- inputs:The batch data processed by `preprocess` function.
|
||||
- inputs: The batch data processed by `preprocess` function.
|
||||
- Other parameters: Passed in by the user and specified in `forward_kwargs`.
|
||||
|
||||
Return:
|
||||
|
@ -204,9 +204,9 @@ This is an abstract method that must be implemented by the subclass.
|
|||
|
||||
Input arguments:
|
||||
|
||||
- inputs:The input data, which is the raw data without preprocessing.
|
||||
- preds:Predicted results of the model.
|
||||
- show:Whether to visualize.
|
||||
- inputs: The input data, which is the raw data without preprocessing.
|
||||
- preds: Predicted results of the model.
|
||||
- show: Whether to visualize.
|
||||
- Other parameters: Passed in by the user and specified in `visualize_kwargs`.
|
||||
|
||||
Return:
|
||||
|
@ -221,12 +221,12 @@ This is an abstract method that must be implemented by the subclass.
|
|||
|
||||
Input arguments:
|
||||
|
||||
- preds:The predicted results of the model, which is a `list` type. Each element in the list represents the prediction result for a single data item. In the OpenMMLab series of algorithm libraries, the type of each element in the prediction result is `BaseDataElement`.
|
||||
- visualization:Visualization results
|
||||
- preds: The predicted results of the model, which is a `list` type. Each element in the list represents the prediction result for a single data item. In the OpenMMLab series of algorithm libraries, the type of each element in the prediction result is `BaseDataElement`.
|
||||
- visualization: Visualization results
|
||||
- return_datasample: Whether to maintain datasample for return. When set to `False`, the returned result is converted to a `dict`.
|
||||
- Other parameters: Passed in by the user and specified in `postprocess_kwargs`.
|
||||
|
||||
Return:
|
||||
Return:
|
||||
|
||||
- The type of the returned value is a dictionary containing both the visualization and prediction results. OpenMMLab requires the returned dictionary to have two keys: `predictions` and `visualization`.
|
||||
|
||||
|
@ -234,9 +234,9 @@ Return:
|
|||
|
||||
Input arguments:
|
||||
|
||||
- inputs:The input data, usually a list of image paths or image data. Each element in `inputs` can also be other types of data as long as it can be processed by the `pipeline` returned by [init_pipeline](#_init_pipeline). When there is only one inference data in `inputs`, it does not have to be a `list`, `__call__` will internally wrap it into a list for further processing.
|
||||
- inputs: The input data, usually a list of image paths or image data. Each element in `inputs` can also be other types of data as long as it can be processed by the `pipeline` returned by [init_pipeline](#_init_pipeline). When there is only one inference data in `inputs`, it does not have to be a `list`, `__call__` will internally wrap it into a list for further processing.
|
||||
- return_datasample: Whether to convert datasample to dict for return.
|
||||
- batch_size:Batch size for inference, which will be further passed to the `preprocess` function.
|
||||
- batch_size: Batch size for inference, which will be further passed to the `preprocess` function.
|
||||
- Other parameters: Additional parameters assigned to `preprocess`, `forward`, `visualize`, and `postprocess` methods.
|
||||
|
||||
Return:
|
||||
|
|
|
@ -74,11 +74,11 @@ history_buffer.min()
|
|||
# 1, the global minimum
|
||||
|
||||
history_buffer.max(2)
|
||||
# 3,the maximum in [2, 3]
|
||||
# 3, the maximum in [2, 3]
|
||||
history_buffer.min()
|
||||
# 3, the global maximum
|
||||
history_buffer.mean(2)
|
||||
# 2.5,the mean value in [2, 3], (2 + 3) / (1 + 1)
|
||||
# 2.5, the mean value in [2, 3], (2 + 3) / (1 + 1)
|
||||
history_buffer.mean()
|
||||
# 2, the global mean, (1 + 2 + 3) / (1 + 1 + 1)
|
||||
history_buffer = HistoryBuffer([1, 2, 3], [2, 2, 2]) # Cases when counts are not 1
|
||||
|
@ -431,7 +431,7 @@ In the case of multiple processes in multiple nodes without storage, logs are or
|
|||
|
||||
```text
|
||||
# without shared storage
|
||||
# node 0:
|
||||
# node 0:
|
||||
work_dir/20230228_141908
|
||||
├── 20230306_183634_${hostname}_device0_rank0.log
|
||||
├── 20230306_183634_${hostname}_device1_rank1.log
|
||||
|
@ -442,7 +442,7 @@ work_dir/20230228_141908
|
|||
├── 20230306_183634_${hostname}_device6_rank6.log
|
||||
├── 20230306_183634_${hostname}_device7_rank7.log
|
||||
|
||||
# node 7:
|
||||
# node 7:
|
||||
work_dir/20230228_141908
|
||||
├── 20230306_183634_${hostname}_device0_rank56.log
|
||||
├── 20230306_183634_${hostname}_device1_rank57.log
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# 15 minutes to get started with MMEngine
|
||||
|
||||
In this tutorial, we'll take training a ResNet-50 model on CIFAR-10 dataset as an example. We will build a complete and configurable pipeline for both training and validation in only 80 lines of code with `MMEgnine`.
|
||||
In this tutorial, we'll take training a ResNet-50 model on CIFAR-10 dataset as an example. We will build a complete and configurable pipeline for both training and validation in only 80 lines of code with `MMEngine`.
|
||||
The whole process includes the following steps:
|
||||
|
||||
- [15 minutes to get started with MMEngine](#15-minutes-to-get-started-with-mmengine)
|
||||
|
|
|
@ -1,30 +1,29 @@
|
|||
# Introduction
|
||||
|
||||
MMEngine is a foundational library for training deep learning models based on
|
||||
PyTorch. It supports running on Linux, Windows, and macOS. It has the
|
||||
following three features:
|
||||
PyTorch. It supports running on Linux, Windows, and macOS. Its highlights are as follows:
|
||||
|
||||
1. **Universal and powerful executor**:
|
||||
**Integrate mainstream large-scale model training frameworks**
|
||||
|
||||
- Supports training different tasks with minimal code, such as training
|
||||
ImageNet with just 80 lines of code (original PyTorch examples require
|
||||
400 lines).
|
||||
- Easily compatible with models from popular algorithm libraries like TIMM,
|
||||
TorchVision, and Detectron2.
|
||||
- [ColossalAI](../common_usage/large_model_training.md#colossalai)
|
||||
- [DeepSpeed](../common_usage/large_model_training.md#deepspeed)
|
||||
- [FSDP](../common_usage/large_model_training.md#fullyshardeddataparallel-fsdp)
|
||||
|
||||
2. **Open architecture with unified interfaces**:
|
||||
**Supports a variety of training strategies**
|
||||
|
||||
- Handles different tasks with a unified API: you can implement a method
|
||||
once and apply it to all compatible models.
|
||||
- Supports various backend devices through a simple, high-level
|
||||
abstraction. Currently, MMEngine supports model training on Nvidia CUDA,
|
||||
Mac MPS, AMD, MLU, and other devices.
|
||||
- [Mixed Precision Training](../common_usage/speed_up_training.md#mixed-precision-training)
|
||||
- [Gradient Accumulation](../common_usage/save_gpu_memory.md#gradient-accumulation)
|
||||
- [Gradient Checkpointing](../common_usage/save_gpu_memory.md#gradient-checkpointing)
|
||||
|
||||
3. **Customizable training process**:
|
||||
**Provides a user-friendly configuration system**
|
||||
|
||||
- Defines a highly modular training engine with "Lego"-like composability.
|
||||
- Offers a rich set of components and strategies.
|
||||
- Total control over the training process with different levels of APIs.
|
||||
- [Pure Python-style configuration files, easy to navigate](../advanced_tutorials/config.md#a-pure-python-style-configuration-file-beta)
|
||||
- [Plain-text-style configuration files, supporting JSON and YAML](../advanced_tutorials/config.html)
|
||||
|
||||
**Covers mainstream training monitoring platforms**
|
||||
|
||||
- [TensorBoard](../common_usage/visualize_training_log.md#tensorboard) | [WandB](../common_usage/visualize_training_log.md#wandb) | [MLflow](../common_usage/visualize_training_log.md#mlflow-wip)
|
||||
- [ClearML](../common_usage/visualize_training_log.md#clearml) | [Neptune](../common_usage/visualize_training_log.md#neptune) | [DVCLive](../common_usage/visualize_training_log.md#dvclive) | [Aim](../common_usage/visualize_training_log.md#aim)
|
||||
|
||||
## Architecture
|
||||
|
||||
|
|
|
@ -156,7 +156,7 @@ This tutorial compares the difference in function, mount point, usage and implem
|
|||
<tr>
|
||||
<td>after each iteration</td>
|
||||
<td>after_train_iter</td>
|
||||
<td>after_train_iter, with additional args: batch_idx、data_batch, and outputs</td>
|
||||
<td>after_train_iter, with additional args: batch_idx, data_batch, and outputs</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="6">Validation related</td>
|
||||
|
@ -187,7 +187,7 @@ This tutorial compares the difference in function, mount point, usage and implem
|
|||
<tr>
|
||||
<td>after each iteration</td>
|
||||
<td>after_val_iter</td>
|
||||
<td>after_val_iter, with additional args: batch_idx、data_batch and outputs</td>
|
||||
<td>after_val_iter, with additional args: batch_idx, data_batch and outputs</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="6">Test related</td>
|
||||
|
@ -218,7 +218,7 @@ This tutorial compares the difference in function, mount point, usage and implem
|
|||
<tr>
|
||||
<td>after each iteration</td>
|
||||
<td>None</td>
|
||||
<td>after_test_iter, with additional args: batch_idx、data_batch and outputs</td>
|
||||
<td>after_test_iter, with additional args: batch_idx, data_batch and outputs</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
|
|
@ -393,7 +393,7 @@ MMCV will wrap the model with distributed wrapper before building the runner, wh
|
|||
cfg = dict(model_wrapper_cfg='MMSeparateDistributedDataParallel')
|
||||
runner = Runner(
|
||||
model=model,
|
||||
..., # 其他配置
|
||||
...,
|
||||
launcher='pytorch',
|
||||
cfg=cfg)
|
||||
```
|
||||
|
|
|
@ -1,5 +1,68 @@
|
|||
# Changelog of v0.x
|
||||
|
||||
## v0.10.5 (11/9/2024)
|
||||
|
||||
- Fix `_is_builtin_module`. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1571
|
||||
|
||||
## v0.10.4 (23/4/2024)
|
||||
|
||||
### New Features & Enhancements
|
||||
|
||||
- Support custom `artifact_location` in MLflowVisBackend. by [@daavoo](https://github.com/daavoo) in https://github.com/open-mmlab/mmengine/pull/1505
|
||||
- Add the supported pytorch versions in README by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1512
|
||||
- Perform evaluation upon training completion by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1529
|
||||
- Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1517
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fix warning capture by [@fanqiNO1](https://github.com/fanqiNO1) in https://github.com/open-mmlab/mmengine/pull/1494
|
||||
- Remove codeowners file by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1496
|
||||
- Fix config of readthedocs by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1511
|
||||
- Delete frozen parameters when using `paramwise_cfg` by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1441
|
||||
|
||||
### Docs
|
||||
|
||||
- Refine mmengine intro by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1479
|
||||
- Fix typo by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1481
|
||||
- Fix typos and remove fullwidth unicode chars by [@evdcush](https://github.com/evdcush) in https://github.com/open-mmlab/mmengine/pull/1488
|
||||
- Fix docstring of Config by [@MambaWong](https://github.com/MambaWong) in https://github.com/open-mmlab/mmengine/pull/1506
|
||||
- Fix typo by [@hiramf](https://github.com/hiramf) in https://github.com/open-mmlab/mmengine/pull/1532
|
||||
|
||||
## v0.10.3 (24/1/2024)
|
||||
|
||||
### New Features & Enhancements
|
||||
|
||||
- Add the support for musa device support by [@hanhaowen-mt](https://github.com/hanhaowen-mt) in https://github.com/open-mmlab/mmengine/pull/1453
|
||||
- Support `save_optimizer=False` for DeepSpeed by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1474
|
||||
- Update visualizer.py by [@Anm-pinellia](https://github.com/Anm-pinellia) in https://github.com/open-mmlab/mmengine/pull/1476
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fix `Config.to_dict` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1465
|
||||
- Fix the resume of iteration by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1471
|
||||
- Fix `dist.collect_results` to keep all ranks' elements by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1469
|
||||
|
||||
### Docs
|
||||
|
||||
- Add the usage of ProfilerHook by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1466
|
||||
- Fix the nnodes in the doc of ddp training by [@XiwuChen](https://github.com/XiwuChen) in https://github.com/open-mmlab/mmengine/pull/1462
|
||||
|
||||
## v0.10.2 (26/12/2023)
|
||||
|
||||
### New Features & Enhancements
|
||||
|
||||
- Support multi-node distributed training with NPU backend by [@shun001](https://github.com/shun001) in https://github.com/open-mmlab/mmengine/pull/1459
|
||||
- Use `ImportError` to cover `ModuleNotFoundError` by [@del-zhenwu](https://github.com/del-zhenwu) in https://github.com/open-mmlab/mmengine/pull/1438
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fix bug in `load_model_state_dict` of `BaseStrategy` by [@SCZwangxiao](https://github.com/SCZwangxiao) in https://github.com/open-mmlab/mmengine/pull/1447
|
||||
- Fix placement policy in ColossalAIStrategy by [@fanqiNO1](https://github.com/fanqiNO1) in https://github.com/open-mmlab/mmengine/pull/1440
|
||||
|
||||
### Contributors
|
||||
|
||||
A total of 4 developers contributed to this release. Thanks [@shun001](https://github.com/shun001), [@del-zhenwu](https://github.com/del-zhenwu), [@SCZwangxiao](https://github.com/SCZwangxiao), [@fanqiNO1](https://github.com/fanqiNO1)
|
||||
|
||||
## v0.10.1 (22/11/2023)
|
||||
|
||||
### Bug Fixes
|
||||
|
@ -652,7 +715,7 @@ A total of 16 developers contributed to this release. Thanks [@BayMaxBHL](https:
|
|||
### Bug Fixes
|
||||
|
||||
- Fix error calculation of `eta_min` in `CosineRestartParamScheduler` by [@Z-Fran](https://github.com/Z-Fran) in https://github.com/open-mmlab/mmengine/pull/639
|
||||
- Fix `BaseDataPreprocessor.cast_data` could not handle string data by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/602
|
||||
- Fix `BaseDataPreprocessor.cast_data` could not handle string data by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/602
|
||||
- Make `autocast` compatible with mps by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/587
|
||||
- Fix error format of log message by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/508
|
||||
- Fix error implementation of `is_model_wrapper` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/640
|
||||
|
|
|
@ -31,11 +31,12 @@ Each hook has a corresponding priority. At each mount point, hooks with higher p
|
|||
|
||||
**custom hooks**
|
||||
|
||||
| Name | Function | Priority |
|
||||
| :---------------------------------: | :----------------------------------------------------------------------: | :---------: |
|
||||
| [EMAHook](#emahook) | apply Exponential Moving Average (EMA) on the model during training | NORMAL (50) |
|
||||
| [EmptyCacheHook](#emptycachehook) | Releases all unoccupied cached GPU memory during the process of training | NORMAL (50) |
|
||||
| [SyncBuffersHook](#syncbuffershook) | Synchronize model buffers at the end of each epoch | NORMAL (50) |
|
||||
| Name | Function | Priority |
|
||||
| :---------------------------------: | :----------------------------------------------------------------------: | :-----------: |
|
||||
| [EMAHook](#emahook) | Apply Exponential Moving Average (EMA) on the model during training | NORMAL (50) |
|
||||
| [EmptyCacheHook](#emptycachehook) | Releases all unoccupied cached GPU memory during the process of training | NORMAL (50) |
|
||||
| [SyncBuffersHook](#syncbuffershook) | Synchronize model buffers at the end of each epoch | NORMAL (50) |
|
||||
| [ProfilerHook](#profilerhook) | Analyze the execution time and GPU memory usage of model operators | VERY_LOW (90) |
|
||||
|
||||
```{note}
|
||||
It is not recommended to modify the priority of the default hooks, as hooks with lower priority may depend on hooks with higher priority. For example, `CheckpointHook` needs to have a lower priority than ParamSchedulerHook so that the saved optimizer state is correct. Also, the priority of custom hooks defaults to `NORMAL (50)`.
|
||||
|
@ -211,6 +212,20 @@ runner = Runner(custom_hooks=custom_hooks, ...)
|
|||
runner.train()
|
||||
```
|
||||
|
||||
### ProfilerHook
|
||||
|
||||
The [ProfilerHook](mmengine.hooks.ProfilerHook) is used to analyze the execution time and GPU memory occupancy of model operators.
|
||||
|
||||
```python
|
||||
custom_hooks = [dict(type='ProfilerHook', on_trace_ready=dict(type='tb_trace'))]
|
||||
runner = Runner(custom_hooks=custom_hooks, ...)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
The profiling results will be saved in the tf_tracing_logs directory under `work_dirs/{timestamp}`, and can be visualized using TensorBoard with the command `tensorboard --logdir work_dirs/{timestamp}/tf_tracing_logs`.
|
||||
|
||||
For more information on the usage of the ProfilerHook, please refer to the [ProfilerHook](mmengine.hooks.ProfilerHook) documentation.
|
||||
|
||||
## Customize Your Hooks
|
||||
|
||||
If the built-in hooks provided by MMEngine do not cover your demands, you are encouraged to customize your own hooks by simply inheriting the base [hook](mmengine.hooks.Hook) class and overriding the corresponding mount point methods.
|
||||
|
|
|
@ -43,7 +43,7 @@ Usually, we should define a model to implement the body of the algorithm. In MME
|
|||
Benefits from the `BaseModel`, we only need to make the model inherit from `BaseModel`, and implement the `forward` function to perform the training, testing, and validation process.
|
||||
|
||||
```{note}
|
||||
BaseModel inherits from [BaseModule](../advanced_tutorials/initialize.md),which can be used to initialize the model parameters dynamically.
|
||||
BaseModel inherits from [BaseModule](../advanced_tutorials/initialize.md), which can be used to initialize the model parameters dynamically.
|
||||
```
|
||||
|
||||
[**forward**](mmengine.model.BaseModel.forward): The arguments of `forward` need to match with the data given by [DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). If the DataLoader samples a tuple `data`, `forward` needs to accept the value of unpacked `*data`. If DataLoader returns a dict `data`, `forward` needs to accept the key-value of unpacked `**data`. `forward` also accepts `mode` parameter, which is used to control the running branch:
|
||||
|
|
|
@ -40,7 +40,7 @@ for epoch in range(10):
|
|||
|
||||
`mmengine.optim.scheduler` supports most of PyTorch's learning rate schedulers such as `ExponentialLR`, `LinearLR`, `StepLR`, `MultiStepLR`, etc. Please refer to [parameter scheduler API documentation](https://mmengine.readthedocs.io/en/latest/api/optim.html#scheduler) for all of the supported schedulers.
|
||||
|
||||
MMEngine also supports adjusting momentum with parameter schedulers. To use momentum schedulers, replace `LR` in the class name to `Momentum`, such as `ExponentialMomentum`,`LinearMomentum`. Further, we implement the general parameter scheduler ParamScheduler, which is used to adjust the specified hyperparameters in the optimizer, such as weight_decay, etc. This feature makes it easier to apply some complex hyperparameter tuning strategies.
|
||||
MMEngine also supports adjusting momentum with parameter schedulers. To use momentum schedulers, replace `LR` in the class name to `Momentum`, such as `ExponentialMomentum`, `LinearMomentum`. Further, we implement the general parameter scheduler ParamScheduler, which is used to adjust the specified hyperparameters in the optimizer, such as weight_decay, etc. This feature makes it easier to apply some complex hyperparameter tuning strategies.
|
||||
|
||||
Different from the above example, MMEngine usually does not need to manually implement the training loop and call `optimizer.step()`. The runner will automatically manage the training progress and control the execution of the parameter scheduler through `ParamSchedulerHook`.
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ Pros and cons lie in both approaches. For the former one, beginners may be lost
|
|||
|
||||
We argue that the key to learning runner is using it as a memo. You should remember its most commonly used arguments and only focus on those less used when in need, since default values usually work fine. In the following, we will provide a beginner-friendly example to illustrate the most commonly used arguments of the runner, along with advanced guidelines for those less used.
|
||||
|
||||
### A beginer-friendly example
|
||||
### A beginner-friendly example
|
||||
|
||||
```{hint}
|
||||
In this tutorial, we hope you can focus more on overall architecture instead of implementation details. This "top-down" way of thinking is exactly what we advocate. Don't worry, you will definitely have plenty of opportunities and guidance afterward to focus on modules you want to improve.
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
version: 2
|
||||
|
||||
formats:
|
||||
- epub
|
||||
|
||||
sphinx:
|
||||
configuration: docs/zh_cn/conf.py
|
||||
|
||||
# Set the version of Python and other tools you might need
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.10"
|
||||
|
||||
python:
|
||||
install:
|
||||
- requirements: requirements/runtime.txt
|
||||
- requirements: requirements/docs.txt
|
||||
- requirements: requirements/docs_extra.txt
|
|
@ -26,7 +26,7 @@ CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 e
|
|||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nnodes 8 \
|
||||
--nnodes 2 \
|
||||
--node_rank 0 \
|
||||
--master_addr 127.0.0.1 \
|
||||
--master_port 29500 \
|
||||
|
@ -38,9 +38,9 @@ python -m torch.distributed.launch \
|
|||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nnodes 8 \
|
||||
--nnodes 2 \
|
||||
--node_rank 1 \
|
||||
--master_addr 127.0.0.1 \
|
||||
--master_addr "ip_of_the_first_machine" \
|
||||
--master_port 29500 \
|
||||
--nproc_per_node=8 \
|
||||
examples/distributed_training.py --launcher pytorch
|
||||
|
|
|
@ -1,22 +1,33 @@
|
|||
# 介绍
|
||||
|
||||
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库,支持在 Linux、Windows、macOS 上运行。它具有如下三个特性:
|
||||
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库,支持在 Linux、Windows、macOS 上运行。它的亮点如下:
|
||||
|
||||
1. **通用且强大的执行器**:
|
||||
**集成主流的大模型训练框架**
|
||||
|
||||
- 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 ImageNet(原始 PyTorch 示例需要 400 行)。
|
||||
- 轻松兼容流行的算法库(如 TIMM、TorchVision 和 Detectron2)中的模型。
|
||||
- [ColossalAI](../common_usage/large_model_training.md#colossalai)
|
||||
- [DeepSpeed](../common_usage/large_model_training.md#deepspeed)
|
||||
- [FSDP](../common_usage/large_model_training.md#fullyshardeddataparallel-fsdp)
|
||||
|
||||
2. **接口统一的开放架构**:
|
||||
**支持丰富的训练策略**
|
||||
|
||||
- 使用统一的接口处理不同的算法任务,例如,实现一个方法并应用于所有的兼容性模型。
|
||||
- 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。
|
||||
- [混合精度训练(Mixed Precision Training)](../common_usage/speed_up_training.md#混合精度训练)
|
||||
- [梯度累积(Gradient Accumulation)](../common_usage/save_gpu_memory.md#梯度累加)
|
||||
- [梯度检查点(Gradient Checkpointing)](../common_usage/save_gpu_memory.md#梯度检查点)
|
||||
|
||||
3. **可定制的训练流程**:
|
||||
**提供易用的配置系统**
|
||||
|
||||
- 定义了“乐高”式的训练流程。
|
||||
- 提供了丰富的组件和策略。
|
||||
- 使用不同等级的 API 控制训练过程。
|
||||
- [纯 Python 风格的配置文件,易于跳转](../advanced_tutorials/config.md#纯-python-风格的配置文件beta)
|
||||
- [纯文本风格的配置文件,支持 JSON 和 YAML](../advanced_tutorials/config.md)
|
||||
|
||||
**覆盖主流的训练监测平台**
|
||||
|
||||
- [TensorBoard](../common_usage/visualize_training_log.md#tensorboard) | [WandB](../common_usage/visualize_training_log.md#wandb) | [MLflow](../common_usage/visualize_training_log.md#mlflow-wip)
|
||||
- [ClearML](../common_usage/visualize_training_log.md#clearml) | [Neptune](../common_usage/visualize_training_log.md#neptune) | [DVCLive](../common_usage/visualize_training_log.md#dvclive) | [Aim](../common_usage/visualize_training_log.md#aim)
|
||||
|
||||
**兼容主流的训练芯片**
|
||||
|
||||
- 英伟达 CUDA | 苹果 MPS
|
||||
- 华为 Ascend | 寒武纪 MLU | 摩尔线程 MUSA
|
||||
|
||||
## 架构
|
||||
|
||||
|
|
|
@ -31,11 +31,12 @@ MMEngine 提供了很多内置的钩子,将钩子分为两类,分别是默
|
|||
|
||||
**自定义钩子**
|
||||
|
||||
| 名称 | 用途 | 优先级 |
|
||||
| :---------------------------------: | :-------------------: | :---------: |
|
||||
| [EMAHook](#emahook) | 模型参数指数滑动平均 | NORMAL (50) |
|
||||
| [EmptyCacheHook](#emptycachehook) | PyTorch CUDA 缓存清理 | NORMAL (50) |
|
||||
| [SyncBuffersHook](#syncbuffershook) | 同步模型的 buffer | NORMAL (50) |
|
||||
| 名称 | 用途 | 优先级 |
|
||||
| :---------------------------------: | :--------------------------------: | :-----------: |
|
||||
| [EMAHook](#emahook) | 模型参数指数滑动平均 | NORMAL (50) |
|
||||
| [EmptyCacheHook](#emptycachehook) | PyTorch CUDA 缓存清理 | NORMAL (50) |
|
||||
| [SyncBuffersHook](#syncbuffershook) | 同步模型的 buffer | NORMAL (50) |
|
||||
| [ProfilerHook](#profilerhook) | 分析算子的执行时间以及显存占用情况 | VERY_LOW (90) |
|
||||
|
||||
```{note}
|
||||
不建议修改默认钩子的优先级,因为优先级低的钩子可能会依赖优先级高的钩子。例如 CheckpointHook 的优先级需要比 ParamSchedulerHook 低,这样保存的优化器状态才是正确的状态。另外,自定义钩子的优先级默认为 `NORMAL (50)`。
|
||||
|
@ -206,6 +207,20 @@ runner = Runner(custom_hooks=custom_hooks, ...)
|
|||
runner.train()
|
||||
```
|
||||
|
||||
### ProfilerHook
|
||||
|
||||
[ProfilerHook](mmengine.hooks.ProfilerHook) 用于分析模型算子的执行时间以及显存占用情况。
|
||||
|
||||
```python
|
||||
custom_hooks = [dict(type='ProfilerHook', on_trace_ready=dict(type='tb_trace'))]
|
||||
runner = Runner(custom_hooks=custom_hooks, ...)
|
||||
runner.train()
|
||||
```
|
||||
|
||||
profile 的结果会保存在 `work_dirs/{timestamp}` 下的 `tf_tracing_logs` 目录,通过 `tensorboard --logdir work_dirs/{timestamp}tf_tracing_logs`。
|
||||
|
||||
更多关于 ProfilerHook 的用法请阅读 [ProfilerHook](mmengine.hooks.ProfilerHook) 文档。
|
||||
|
||||
## 自定义钩子
|
||||
|
||||
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
|
||||
|
|
|
@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
|
|||
'"type" and "constructor" are not in '
|
||||
f'optimizer, but got {name}={optim}')
|
||||
optim_wrappers[name] = optim
|
||||
return OptimWrapperDict(**optim_wrappers)
|
||||
return OptimWrapperDict(**optim_wrappers) # type: ignore
|
||||
else:
|
||||
raise TypeError('optimizer wrapper should be an OptimWrapper '
|
||||
f'object or dict, but got {optim_wrapper}')
|
||||
|
@ -799,7 +799,8 @@ class BaseStrategy(metaclass=ABCMeta):
|
|||
else:
|
||||
model = self.model
|
||||
|
||||
_load_checkpoint_to_model(model, state_dict, strict, revise_keys)
|
||||
_load_checkpoint_to_model(
|
||||
model, state_dict, strict=strict, revise_keys=revise_keys)
|
||||
|
||||
def load_optim_state_dict(self, state_dict: dict) -> None:
|
||||
"""Load optimizer state from dict."""
|
||||
|
|
|
@ -120,8 +120,9 @@ class ColossalAIOptimWrapper(OptimWrapper):
|
|||
self.optimizer.backward(loss, **kwargs)
|
||||
|
||||
|
||||
@MODEL_WRAPPERS.register_module()
|
||||
class CollosalAIModelWrapper:
|
||||
@MODEL_WRAPPERS.register_module(
|
||||
name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
|
||||
class ColossalAIModelWrapper:
|
||||
|
||||
def __init__(self, model_wrapper: ModelWrapper, model: nn.Module):
|
||||
self.model_wrapper = model_wrapper
|
||||
|
@ -238,7 +239,7 @@ class ColossalAIStrategy(BaseStrategy):
|
|||
OPTIMIZER_DIR = 'optimizer' # directory to save optimizer state.
|
||||
MODEL_DIR = 'model' # directory to save model
|
||||
SCHEDULER_DIR = 'scheduler' # directory to save scheduelrs
|
||||
model: CollosalAIModelWrapper # type: ignore
|
||||
model: ColossalAIModelWrapper # type: ignore
|
||||
optim_wrapper: ColossalAIOptimWrapper # type: ignore
|
||||
|
||||
def __init__(
|
||||
|
@ -360,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
|
|||
map_location: Union[str, Callable] = 'default',
|
||||
callback: Optional[Callable] = None,
|
||||
) -> dict:
|
||||
"""override this method since colossalai resume optimizer from filename
|
||||
"""Override this method since colossalai resume optimizer from filename
|
||||
directly."""
|
||||
self.logger.info(f'Resume checkpoint from {filename}')
|
||||
|
||||
|
@ -468,8 +469,14 @@ class ColossalAIStrategy(BaseStrategy):
|
|||
def _build_plugin(self, plugin: Union[str, dict]):
|
||||
if isinstance(plugin, str):
|
||||
if plugin == 'gemini':
|
||||
plugin = colo_plugin.GeminiPlugin(
|
||||
precision='bf16', placement_policy='cuda')
|
||||
try:
|
||||
plugin = colo_plugin.GeminiPlugin(
|
||||
precision='bf16', placement_policy='auto')
|
||||
except AssertionError:
|
||||
from colossalai.zero.gemini.placement_policy import \
|
||||
PlacementPolicyFactory as colo_placement
|
||||
raise ValueError('placement policy must be one of ' +
|
||||
f'{list(colo_placement.policies.keys())}')
|
||||
elif plugin == 'lowlevel-zero':
|
||||
plugin = colo_plugin.LowLevelZeroPlugin()
|
||||
else:
|
||||
|
@ -508,11 +515,11 @@ class ColossalAIStrategy(BaseStrategy):
|
|||
self,
|
||||
model: nn.Module,
|
||||
optim_wrapper: Optional[OptimWrapper] = None,
|
||||
) -> Union[Tuple[CollosalAIModelWrapper, ColossalAIOptimWrapper],
|
||||
CollosalAIModelWrapper]: # type: ignore
|
||||
) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper],
|
||||
ColossalAIModelWrapper]: # type: ignore
|
||||
"""Wrap model with :class:`ModelWrapper`."""
|
||||
if self.model_wrapper is None:
|
||||
self.model_wrapper = {'type': 'CollosalAIModelWrapper'}
|
||||
self.model_wrapper = {'type': 'ColossalAIModelWrapper'}
|
||||
|
||||
# For zero series parallel, move `data_preprocessor` to current device
|
||||
# is reasonable. We need to `BaseDataPreprocessor.to` manually since
|
||||
|
|
|
@ -6,18 +6,23 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||
|
||||
import torch
|
||||
|
||||
from mmengine.logging import print_log
|
||||
|
||||
try:
|
||||
import deepspeed
|
||||
except ImportError:
|
||||
deepspeed = None
|
||||
|
||||
import logging
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
import mmengine
|
||||
from mmengine.dist import init_dist
|
||||
from mmengine.dist import init_dist, is_main_process
|
||||
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
|
||||
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
|
||||
STRATEGIES)
|
||||
from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
|
||||
from mmengine.utils import apply_to, digit_version, get_git_hash
|
||||
from .base import BaseStrategy
|
||||
|
||||
|
@ -306,8 +311,8 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
self.config['steps_per_print'] = steps_per_print
|
||||
self._inputs_to_half = inputs_to_half
|
||||
assert (exclude_frozen_parameters is None or
|
||||
digit_version(deepspeed.__version__) >= digit_version('0.10.1')
|
||||
), ('DeepSpeed >= 0.10.1 is required to enable '
|
||||
digit_version(deepspeed.__version__) >= digit_version('0.13.2')
|
||||
), ('DeepSpeed >= 0.13.2 is required to enable '
|
||||
'exclude_frozen_parameters')
|
||||
self.exclude_frozen_parameters = exclude_frozen_parameters
|
||||
|
||||
|
@ -425,7 +430,7 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
self.logger.info(f'Load checkpoint from {filename}')
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
|
||||
_, extra_ckpt = self.model.load_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
|
@ -463,7 +468,7 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
self.logger.info(f'Resume checkpoint from {filename}')
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
|
||||
_, extra_ckpt = self.model.load_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
|
@ -506,7 +511,7 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
"""Save checkpoint to given ``filename``.
|
||||
|
||||
Warning:
|
||||
`save_optimizer` and `callback` parameters are not supported yet.
|
||||
`callback` parameter is not supported yet.
|
||||
|
||||
Args:
|
||||
filename (str): Filename to save checkpoint.
|
||||
|
@ -527,25 +532,50 @@ class DeepSpeedStrategy(BaseStrategy):
|
|||
mmengine=mmengine.__version__ + get_git_hash(),
|
||||
)
|
||||
|
||||
if save_optimizer and hasattr(self, 'optim_wrapper'):
|
||||
# The key can not be 'optimizer', otherwise error will be thrown
|
||||
# when loading or resuming checkpoint.
|
||||
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
|
||||
|
||||
if save_param_scheduler and hasattr(self, 'param_schedulers'):
|
||||
extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
|
||||
if (not save_optimizer
|
||||
and self.model.zero_optimization_partition_weights()
|
||||
and not self.model.zero_gather_16bit_weights_on_model_save()):
|
||||
print_log(
|
||||
'Configured to `save_optimizer=False`, but currently using '
|
||||
"DeepSpeed's ZeRO stage 3 with "
|
||||
'`gather_16bit_weights_on_model_save=False`. In '
|
||||
'this configuration, the model cannot be saved properly '
|
||||
'and will be saved with the optimizer state. '
|
||||
'To support `save_optimizer=False`, please set '
|
||||
'`gather_16bit_weights_on_model_save=True` in your '
|
||||
'DeepSpeed config.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
save_optimizer = True
|
||||
|
||||
state_dict_kwargs = {}
|
||||
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
|
||||
state_dict_kwargs[
|
||||
'exclude_frozen_parameters'] = self.exclude_frozen_parameters
|
||||
|
||||
if save_optimizer:
|
||||
if hasattr(self, 'optim_wrapper'):
|
||||
# The key can not be 'optimizer', otherwise error will be
|
||||
# thrown when loading or resuming checkpoint.
|
||||
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
|
||||
|
||||
dirname, basename = osp.split(filename)
|
||||
self.model.save_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
client_state=extra_ckpt,
|
||||
save_latest=False,
|
||||
exclude_frozen_parameters=self.exclude_frozen_parameters)
|
||||
**state_dict_kwargs)
|
||||
else:
|
||||
self.model.save_checkpoint(
|
||||
dirname,
|
||||
tag=basename,
|
||||
client_state=extra_ckpt,
|
||||
save_latest=False)
|
||||
if self.model.zero_optimization_partition_weights():
|
||||
state_dict = self.model._zero3_consolidated_16bit_state_dict(
|
||||
**state_dict_kwargs)
|
||||
else:
|
||||
state_dict = self.model.module_state_dict(**state_dict_kwargs)
|
||||
|
||||
if is_main_process():
|
||||
ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
|
||||
save_checkpoint(ckpt, filename)
|
||||
|
|
|
@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
|
|||
init_dist(launcher, backend, **kwargs)
|
||||
|
||||
def convert_model(self, model: nn.Module) -> nn.Module:
|
||||
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
|
||||
"""Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
|
||||
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -48,9 +48,11 @@ else:
|
|||
def _lazy2string(cfg_dict, dict_type=None):
|
||||
if isinstance(cfg_dict, dict):
|
||||
dict_type = dict_type or type(cfg_dict)
|
||||
return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)})
|
||||
return dict_type(
|
||||
{k: _lazy2string(v, dict_type)
|
||||
for k, v in dict.items(cfg_dict)})
|
||||
elif isinstance(cfg_dict, (tuple, list)):
|
||||
return type(cfg_dict)(_lazy2string(v) for v in cfg_dict)
|
||||
return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict)
|
||||
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
|
||||
return f'{cfg_dict.module}.{str(cfg_dict)}'
|
||||
else:
|
||||
|
@ -391,7 +393,7 @@ class Config:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_dict: dict = None,
|
||||
cfg_dict: Optional[dict] = None,
|
||||
cfg_text: Optional[str] = None,
|
||||
filename: Optional[Union[str, Path]] = None,
|
||||
env_variables: Optional[dict] = None,
|
||||
|
@ -442,6 +444,8 @@ class Config:
|
|||
predefined variables. Defaults to True.
|
||||
import_custom_modules (bool, optional): Whether to support
|
||||
importing custom modules in config. Defaults to None.
|
||||
use_environment_variables (bool, optional): Whether to use
|
||||
environment variables. Defaults to True.
|
||||
lazy_import (bool): Whether to load config in `lazy_import` mode.
|
||||
If it is `None`, it will be deduced by the content of the
|
||||
config file. Defaults to None.
|
||||
|
@ -829,6 +833,8 @@ class Config:
|
|||
filename (str): Name of config file.
|
||||
use_predefined_variables (bool, optional): Whether to use
|
||||
predefined variables. Defaults to True.
|
||||
use_environment_variables (bool, optional): Whether to use
|
||||
environment variables. Defaults to True.
|
||||
lazy_import (bool): Whether to load config in `lazy_import` mode.
|
||||
If it is `None`, it will be deduced by the content of the
|
||||
config file. Defaults to None.
|
||||
|
@ -899,7 +905,7 @@ class Config:
|
|||
# 2. Set `_scope_` for the outer dict variable for the base
|
||||
# config.
|
||||
# 3. Set `scope` attribute for each base variable.
|
||||
# Different from `_scope_`, `scope` is not a key of base
|
||||
# Different from `_scope_`, `scope` is not a key of base
|
||||
# dict, `scope` attribute will be parsed to key `_scope_`
|
||||
# by function `_parse_scope` only if the base variable is
|
||||
# accessed by the current config.
|
||||
|
@ -1221,7 +1227,8 @@ class Config:
|
|||
if base_code is not None:
|
||||
base_code = ast.Expression( # type: ignore
|
||||
body=base_code.value) # type: ignore
|
||||
base_files = eval(compile(base_code, '', mode='eval'))
|
||||
base_files = eval(compile(base_code, '',
|
||||
mode='eval')) # type: ignore
|
||||
else:
|
||||
base_files = []
|
||||
elif file_format in ('.yml', '.yaml', '.json'):
|
||||
|
@ -1282,7 +1289,7 @@ class Config:
|
|||
def _merge_a_into_b(a: dict,
|
||||
b: dict,
|
||||
allow_list_keys: bool = False) -> dict:
|
||||
"""merge dict ``a`` into dict ``b`` (non-inplace).
|
||||
"""Merge dict ``a`` into dict ``b`` (non-inplace).
|
||||
|
||||
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
|
||||
in-place modifications.
|
||||
|
@ -1352,22 +1359,22 @@ class Config:
|
|||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
"""get file name of config."""
|
||||
"""Get file name of config."""
|
||||
return self._filename
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""get config text."""
|
||||
"""Get config text."""
|
||||
return self._text
|
||||
|
||||
@property
|
||||
def env_variables(self) -> dict:
|
||||
"""get used environment variables."""
|
||||
"""Get used environment variables."""
|
||||
return self._env_variables
|
||||
|
||||
@property
|
||||
def pretty_text(self) -> str:
|
||||
"""get formatted python config text."""
|
||||
"""Get formatted python config text."""
|
||||
|
||||
indent = 4
|
||||
|
||||
|
@ -1721,17 +1728,17 @@ class Config:
|
|||
|
||||
|
||||
class DictAction(Action):
|
||||
"""
|
||||
argparse action to split an argument into KEY=VALUE form
|
||||
on the first = and append to a dictionary. List options can
|
||||
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
|
||||
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
|
||||
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
|
||||
"""Argparse action to split an argument into KEY=VALUE form on the first =
|
||||
and append to a dictionary.
|
||||
|
||||
List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
|
||||
or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
|
||||
brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
|
||||
"""parse int/float/bool value in the string."""
|
||||
"""Parse int/float/bool value in the string."""
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
|
@ -1816,7 +1823,7 @@ class DictAction(Action):
|
|||
parser: ArgumentParser,
|
||||
namespace: Namespace,
|
||||
values: Union[str, Sequence[Any], None],
|
||||
option_string: str = None):
|
||||
option_string: str = None): # type: ignore
|
||||
"""Parse Variables in string and add them into argparser.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ from mmengine.fileio import load
|
|||
from mmengine.utils import check_file_exist
|
||||
|
||||
PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable))
|
||||
SYSTEM_PYTHON_PREFIX = '/usr/lib/python'
|
||||
|
||||
MODULE2PACKAGE = {
|
||||
'mmcls': 'mmcls',
|
||||
|
@ -176,7 +177,8 @@ def _is_builtin_module(module_name: str) -> bool:
|
|||
return False
|
||||
origin_path = osp.abspath(origin_path)
|
||||
if ('site-package' in origin_path or 'dist-package' in origin_path
|
||||
or not origin_path.startswith(PYTHON_ROOT_DIR)):
|
||||
or not origin_path.startswith(
|
||||
(PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
|
||||
is_dipu_available, is_mlu_available, is_mps_available,
|
||||
is_npu_available, is_npu_support_full_precision)
|
||||
from .utils import (get_device, get_max_cuda_memory, get_max_musa_memory,
|
||||
is_cuda_available, is_dipu_available, is_mlu_available,
|
||||
is_mps_available, is_musa_available, is_npu_available,
|
||||
is_npu_support_full_precision)
|
||||
|
||||
__all__ = [
|
||||
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
|
||||
'is_mlu_available', 'is_mps_available', 'is_npu_available',
|
||||
'is_dipu_available', 'is_npu_support_full_precision'
|
||||
'is_dipu_available', 'get_max_musa_memory', 'is_musa_available',
|
||||
'is_npu_support_full_precision'
|
||||
]
|
||||
|
|
|
@ -16,12 +16,24 @@ try:
|
|||
except Exception:
|
||||
IS_NPU_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import torch_mlu # noqa: F401
|
||||
IS_MLU_AVAILABLE = hasattr(torch, 'mlu') and torch.mlu.is_available()
|
||||
except Exception:
|
||||
IS_MLU_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import torch_dipu # noqa: F401
|
||||
IS_DIPU_AVAILABLE = True
|
||||
except Exception:
|
||||
IS_DIPU_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import torch_musa # noqa: F401
|
||||
IS_MUSA_AVAILABLE = True
|
||||
except Exception:
|
||||
IS_MUSA_AVAILABLE = False
|
||||
|
||||
|
||||
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
|
||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
|
||||
|
@ -58,7 +70,7 @@ def is_npu_available() -> bool:
|
|||
|
||||
def is_mlu_available() -> bool:
|
||||
"""Returns True if Cambricon PyTorch and mlu devices exist."""
|
||||
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
|
||||
return IS_MLU_AVAILABLE
|
||||
|
||||
|
||||
def is_mps_available() -> bool:
|
||||
|
@ -73,6 +85,34 @@ def is_dipu_available() -> bool:
|
|||
return IS_DIPU_AVAILABLE
|
||||
|
||||
|
||||
def get_max_musa_memory(device: Optional[torch.device] = None) -> int:
|
||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
|
||||
a given device. By default, this returns the peak allocated memory since
|
||||
the beginning of this program.
|
||||
|
||||
Args:
|
||||
device (torch.device, optional): selected device. Returns
|
||||
statistic for the current device, given by
|
||||
:func:`~torch.musa.current_device`, if ``device`` is None.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
int: The maximum GPU memory occupied by tensors in megabytes
|
||||
for a given device.
|
||||
"""
|
||||
mem = torch.musa.max_memory_allocated(device=device)
|
||||
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
# TODO:haowen.han@mthreads.com: This function is not supported by musa yet.
|
||||
# torch.musa.reset_peak_memory_stats()
|
||||
return int(mem_mb.item())
|
||||
|
||||
|
||||
def is_musa_available() -> bool:
|
||||
return IS_MUSA_AVAILABLE
|
||||
|
||||
|
||||
def is_npu_support_full_precision() -> bool:
|
||||
"""Returns True if npu devices support full precision training."""
|
||||
version_of_support_full_precision = 220
|
||||
|
@ -91,12 +131,14 @@ elif is_mps_available():
|
|||
DEVICE = 'mps'
|
||||
elif is_dipu_available():
|
||||
DEVICE = 'dipu'
|
||||
elif is_musa_available():
|
||||
DEVICE = 'musa'
|
||||
|
||||
|
||||
def get_device() -> str:
|
||||
"""Returns the currently existing device type.
|
||||
|
||||
Returns:
|
||||
str: cuda | npu | mlu | mps | cpu.
|
||||
str: cuda | npu | mlu | mps | musa | cpu.
|
||||
"""
|
||||
return DEVICE
|
||||
|
|
|
@ -13,7 +13,7 @@ from torch import distributed as torch_dist
|
|||
from torch._utils import (_flatten_dense_tensors, _take_tensors,
|
||||
_unflatten_dense_tensors)
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from itertools import zip_longest, chain
|
||||
import mmengine
|
||||
from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
|
||||
get_default_group, barrier, get_data_device,
|
||||
|
@ -415,12 +415,16 @@ def _broadcast_object_list(object_list: List[Any],
|
|||
current_device = torch.device('cpu')
|
||||
is_hccl_backend = group_backend == 'hccl'
|
||||
is_cncl_backend = group_backend == 'cncl'
|
||||
is_mccl_backend = group_backend == 'mccl'
|
||||
if is_hccl_backend:
|
||||
current_device = torch.device('npu', torch.npu.current_device())
|
||||
object_sizes_tensor = object_sizes_tensor.to(current_device)
|
||||
elif is_cncl_backend:
|
||||
current_device = torch.device('mlu', torch.mlu.current_device())
|
||||
object_sizes_tensor = object_sizes_tensor.to(current_device)
|
||||
elif is_mccl_backend:
|
||||
current_device = torch.device('musa', torch.musa.current_device())
|
||||
object_sizes_tensor = object_sizes_tensor.to(current_device)
|
||||
elif is_nccl_backend:
|
||||
# See note about using torch.cuda.current_device() here in
|
||||
# docstring. We cannot simply use my_rank since rank == device is
|
||||
|
@ -624,6 +628,7 @@ def _all_gather_object(object_list: List[Any],
|
|||
group_backend = get_backend(group)
|
||||
current_device = torch.device('cpu')
|
||||
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
|
||||
is_mccl_backend = group_backend == 'mccl'
|
||||
if is_nccl_backend:
|
||||
# See note about using torch.cuda.current_device() here in docstring.
|
||||
# We cannot simply use my_rank since rank == device is not necessarily
|
||||
|
@ -631,6 +636,13 @@ def _all_gather_object(object_list: List[Any],
|
|||
current_device = torch.device('cuda', torch.cuda.current_device())
|
||||
input_tensor = input_tensor.to(current_device)
|
||||
local_size = local_size.to(current_device)
|
||||
elif is_mccl_backend:
|
||||
# See note about using torch.musa.current_device() here in docstring.
|
||||
# We cannot simply use my_rank since rank == device is not necessarily
|
||||
# true.
|
||||
current_device = torch.device('musa', torch.musa.current_device())
|
||||
input_tensor = input_tensor.to(current_device)
|
||||
local_size = local_size.to(current_device)
|
||||
# Gather all local sizes. This is so that we can find the max size, and
|
||||
# index until the correct size when deserializing the tensors.
|
||||
group_size = get_world_size(group=group)
|
||||
|
@ -776,10 +788,15 @@ def _gather_object(obj: Any,
|
|||
group_backend = get_backend(group)
|
||||
current_device = torch.device('cpu')
|
||||
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
|
||||
is_mccl_backend = group_backend == 'mccl'
|
||||
if is_nccl_backend:
|
||||
current_device = torch.device('cuda', torch.cuda.current_device())
|
||||
input_tensor = input_tensor.to(current_device)
|
||||
local_size = local_size.to(current_device)
|
||||
elif is_mccl_backend:
|
||||
current_device = torch.device('musa', torch.musa.current_device())
|
||||
input_tensor = input_tensor.to(current_device)
|
||||
local_size = local_size.to(current_device)
|
||||
# Gather all local sizes. This is so that we can find the max size, and
|
||||
# index until the correct size when deserializing the tensors.
|
||||
group_size = get_world_size(group=group)
|
||||
|
@ -1010,8 +1027,10 @@ def collect_results_cpu(result_part: list,
|
|||
part_list.append(pickle.load(f))
|
||||
# sort the results
|
||||
ordered_results = []
|
||||
for res in zip(*part_list):
|
||||
ordered_results.extend(list(res))
|
||||
zipped_results = zip_longest(*part_list)
|
||||
ordered_results = [
|
||||
i for i in chain.from_iterable(zipped_results) if i is not None
|
||||
]
|
||||
# the dataloader may pad some samples
|
||||
ordered_results = ordered_results[:size]
|
||||
# remove tmp dir
|
||||
|
@ -1032,8 +1051,10 @@ def _collect_results_device(result_part: list, size: int) -> Optional[list]:
|
|||
if rank == 0:
|
||||
# sort the results
|
||||
ordered_results = []
|
||||
for res in zip(*part_list):
|
||||
ordered_results.extend(list(res))
|
||||
zipped_results = zip_longest(*part_list)
|
||||
ordered_results = [
|
||||
i for i in chain.from_iterable(zipped_results) if i is not None
|
||||
]
|
||||
# the dataloader may pad some samples
|
||||
ordered_results = ordered_results[:size]
|
||||
return ordered_results
|
||||
|
|
|
@ -11,7 +11,8 @@ import torch.multiprocessing as mp
|
|||
from torch import Tensor
|
||||
from torch import distributed as torch_dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from mmengine.device import is_mlu_available, is_npu_available
|
||||
from mmengine.device import (is_mlu_available, is_npu_available,
|
||||
is_musa_available)
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
|
||||
|
@ -99,9 +100,10 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
|
|||
**kwargs: keyword arguments are passed to ``init_process_group``.
|
||||
"""
|
||||
rank = int(os.environ['RANK'])
|
||||
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
if is_mlu_available():
|
||||
import torch_mlu # noqa: F401
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.mlu.set_device(local_rank)
|
||||
torch_dist.init_process_group(
|
||||
backend='cncl',
|
||||
|
@ -110,15 +112,21 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
|
|||
**kwargs)
|
||||
elif is_npu_available():
|
||||
import torch_npu # noqa: F401
|
||||
torch.npu.set_device(rank)
|
||||
torch.npu.set_device(local_rank)
|
||||
torch_dist.init_process_group(
|
||||
backend='hccl',
|
||||
rank=rank,
|
||||
world_size=int(os.environ['WORLD_SIZE']),
|
||||
**kwargs)
|
||||
elif is_musa_available():
|
||||
import torch_musa # noqa: F401
|
||||
torch.musa.set_device(rank)
|
||||
torch_dist.init_process_group(
|
||||
backend='mccl',
|
||||
rank=rank,
|
||||
world_size=int(os.environ['WORLD_SIZE']),
|
||||
**kwargs)
|
||||
else:
|
||||
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
if init_backend == 'torch':
|
||||
|
@ -528,6 +536,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
|
|||
return torch.device('mlu', torch.mlu.current_device())
|
||||
elif backend == 'smddp':
|
||||
return torch.device('cuda', torch.cuda.current_device())
|
||||
elif backend == 'mccl':
|
||||
import torch_musa
|
||||
return torch.device('musa', torch_musa.current_device())
|
||||
else:
|
||||
# GLOO and MPI backends use cpu device by default
|
||||
return torch.device('cpu')
|
||||
|
@ -552,7 +563,7 @@ def cast_data_device(
|
|||
Tensor or list or dict: ``data`` was casted to ``device``.
|
||||
"""
|
||||
if out is not None:
|
||||
if type(data) != type(out):
|
||||
if type(data) is not type(out):
|
||||
raise TypeError(
|
||||
'out should be the same type with data, but got data is '
|
||||
f'{type(data)} and out is {type(data)}')
|
||||
|
|
|
@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
|
|||
self.out_file_path = out_file_path
|
||||
|
||||
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
|
||||
"""transfer tensors in predictions to CPU."""
|
||||
"""Transfer tensors in predictions to CPU."""
|
||||
self.results.extend(_to_cpu(predictions))
|
||||
|
||||
def compute_metrics(self, results: list) -> dict:
|
||||
"""dump the prediction results to a pickle file."""
|
||||
"""Dump the prediction results to a pickle file."""
|
||||
dump(results, self.out_file_path)
|
||||
print_log(
|
||||
f'Results has been saved to {self.out_file_path}.',
|
||||
|
@ -188,7 +188,7 @@ class DumpResults(BaseMetric):
|
|||
|
||||
|
||||
def _to_cpu(data: Any) -> Any:
|
||||
"""transfer all tensors and BaseDataElement to cpu."""
|
||||
"""Transfer all tensors and BaseDataElement to cpu."""
|
||||
if isinstance(data, (Tensor, BaseDataElement)):
|
||||
return data.to('cpu')
|
||||
elif isinstance(data, list):
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Optional, Sequence, Union
|
|||
import torch
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from ..device import is_cuda_available, is_musa_available
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
@ -49,7 +50,10 @@ class EmptyCacheHook(Hook):
|
|||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if self._do_after_iter:
|
||||
torch.cuda.empty_cache()
|
||||
if is_cuda_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
|
||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Empty cache before an epoch.
|
||||
|
@ -59,7 +63,10 @@ class EmptyCacheHook(Hook):
|
|||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if self._do_before_epoch:
|
||||
torch.cuda.empty_cache()
|
||||
if is_cuda_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
|
||||
def _after_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Empty cache after an epoch.
|
||||
|
@ -69,4 +76,7 @@ class EmptyCacheHook(Hook):
|
|||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if self._do_after_epoch:
|
||||
torch.cuda.empty_cache()
|
||||
if is_cuda_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache()
|
||||
|
|
|
@ -233,7 +233,7 @@ class ProfilerHook(Hook):
|
|||
self._export_chrome_trace(runner)
|
||||
|
||||
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
|
||||
"""profiler will call `step` method if it is not closed."""
|
||||
"""Profiler will call `step` method if it is not closed."""
|
||||
if not self._closed:
|
||||
self.profiler.step()
|
||||
if runner.iter == self.profile_times - 1 and not self.by_epoch:
|
||||
|
|
|
@ -58,7 +58,7 @@ class HistoryBuffer:
|
|||
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
|
||||
|
||||
def update(self, log_val: Union[int, float], count: int = 1) -> None:
|
||||
"""update the log history.
|
||||
"""Update the log history.
|
||||
|
||||
If the length of the buffer exceeds ``self._max_length``, the oldest
|
||||
element will be removed from the buffer.
|
||||
|
|
|
@ -398,22 +398,38 @@ def _get_device_id():
|
|||
except ImportError:
|
||||
return 0
|
||||
else:
|
||||
local_rank = int(os.getenv('LOCAL_RANK', '0'))
|
||||
# TODO: return device id of npu and mlu.
|
||||
if not torch.cuda.is_available():
|
||||
return local_rank
|
||||
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
|
||||
if cuda_visible_devices is None:
|
||||
num_device = torch.cuda.device_count()
|
||||
cuda_visible_devices = list(range(num_device))
|
||||
else:
|
||||
cuda_visible_devices = cuda_visible_devices.split(',')
|
||||
MUSA_AVAILABLE = False
|
||||
try:
|
||||
return int(cuda_visible_devices[local_rank])
|
||||
except ValueError:
|
||||
# handle case for Multi-Instance GPUs
|
||||
# see #1148 for details
|
||||
return cuda_visible_devices[local_rank]
|
||||
import torch_musa
|
||||
MUSA_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
if MUSA_AVAILABLE:
|
||||
local_rank = int(os.getenv('LOCAL_RANK', '0'))
|
||||
musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
|
||||
if musa_visible_devices is None:
|
||||
num_device = torch_musa.device_count()
|
||||
musa_visible_devices = list(range(num_device))
|
||||
else:
|
||||
musa_visible_devices = musa_visible_devices.split(',')
|
||||
return int(musa_visible_devices[local_rank])
|
||||
else:
|
||||
local_rank = int(os.getenv('LOCAL_RANK', '0'))
|
||||
# TODO: return device id of npu and mlu.
|
||||
if not torch.cuda.is_available():
|
||||
return local_rank
|
||||
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
|
||||
if cuda_visible_devices is None:
|
||||
num_device = torch.cuda.device_count()
|
||||
cuda_visible_devices = list(range(num_device))
|
||||
else:
|
||||
cuda_visible_devices = cuda_visible_devices.split(',')
|
||||
try:
|
||||
return int(cuda_visible_devices[local_rank])
|
||||
except ValueError:
|
||||
# handle case for Multi-Instance GPUs
|
||||
# see #1148 for details
|
||||
return cuda_visible_devices[local_rank]
|
||||
|
||||
|
||||
def _get_host_info() -> str:
|
||||
|
@ -427,8 +443,7 @@ def _get_host_info() -> str:
|
|||
host = f'{getuser()}@{gethostname()}'
|
||||
except Exception as e:
|
||||
warnings.warn(f'Host or user not found: {str(e)}')
|
||||
finally:
|
||||
return host
|
||||
return host
|
||||
|
||||
|
||||
def _get_logging_file_handlers() -> Dict:
|
||||
|
|
|
@ -317,7 +317,7 @@ class MessageHub(ManagerMixin):
|
|||
if key not in self.runtime_info:
|
||||
return default
|
||||
else:
|
||||
# TODO: There are restrictions on objects that can be saved
|
||||
# TODO: There are restrictions on objects that can be saved
|
||||
# return copy.deepcopy(self._runtime_info[key])
|
||||
return self._runtime_info[key]
|
||||
|
||||
|
|
|
@ -222,6 +222,21 @@ class BaseModel(BaseModule):
|
|||
self._set_device(torch.device(device))
|
||||
return super().cuda(device)
|
||||
|
||||
def musa(
|
||||
self,
|
||||
device: Optional[Union[int, str, torch.device]] = None,
|
||||
) -> nn.Module:
|
||||
"""Overrides this method to call :meth:`BaseDataPreprocessor.musa`
|
||||
additionally.
|
||||
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
if device is None or isinstance(device, int):
|
||||
device = torch.device('musa', index=device)
|
||||
self._set_device(torch.device(device))
|
||||
return super().musa(device)
|
||||
|
||||
def mlu(
|
||||
self,
|
||||
device: Union[int, str, torch.device, None] = None,
|
||||
|
|
|
@ -113,6 +113,15 @@ class BaseDataPreprocessor(nn.Module):
|
|||
self._device = torch.device(torch.cuda.current_device())
|
||||
return super().cuda()
|
||||
|
||||
def musa(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the :attr:`device`
|
||||
|
||||
Returns:
|
||||
nn.Module: The model itself.
|
||||
"""
|
||||
self._device = torch.device(torch.musa.current_device())
|
||||
return super().musa()
|
||||
|
||||
def npu(self, *args, **kwargs) -> nn.Module:
|
||||
"""Overrides this method to set the :attr:`device`
|
||||
|
||||
|
@ -226,7 +235,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||
self.pad_value = pad_value
|
||||
|
||||
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
|
||||
"""Performs normalization、padding and bgr2rgb conversion based on
|
||||
"""Performs normalization, padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
|
@ -244,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||
dict or list: Data in the same format as the model input.
|
||||
"""
|
||||
data = self.cast_data(data) # type: ignore
|
||||
_batch_inputs = data['inputs']
|
||||
_batch_inputs = data['inputs'] # type: ignore
|
||||
# Process data with `pseudo_collate`.
|
||||
if is_seq_of(_batch_inputs, torch.Tensor):
|
||||
batch_inputs = []
|
||||
for _batch_input in _batch_inputs:
|
||||
# channel transform
|
||||
if self._channel_conversion:
|
||||
_batch_input = _batch_input[[2, 1, 0], ...]
|
||||
_batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
|
||||
# Convert to float after channel conversion to ensure
|
||||
# efficiency
|
||||
_batch_input = _batch_input.float()
|
||||
_batch_input = _batch_input.float() # type: ignore
|
||||
# Normalization.
|
||||
if self._enable_normalize:
|
||||
if self.mean.shape[0] == 3:
|
||||
|
@ -293,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
|||
else:
|
||||
raise TypeError('Output of `cast_data` should be a dict of '
|
||||
'list/tuple with inputs and data_samples, '
|
||||
f'but got {type(data)}: {data}')
|
||||
data['inputs'] = batch_inputs
|
||||
data.setdefault('data_samples', None)
|
||||
return data
|
||||
f'but got {type(data)}: {data}') # type: ignore
|
||||
data['inputs'] = batch_inputs # type: ignore
|
||||
data.setdefault('data_samples', None) # type: ignore
|
||||
return data # type: ignore
|
||||
|
|
|
@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):
|
|||
|
||||
|
||||
def bias_init_with_prob(prior_prob):
|
||||
"""initialize conv/fc bias value according to a given probability value."""
|
||||
"""Initialize conv/fc bias value according to a given probability value."""
|
||||
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
|
||||
return bias_init
|
||||
|
||||
|
@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
|
|||
std: float = 1.,
|
||||
a: float = -2.,
|
||||
b: float = 2.) -> Tensor:
|
||||
r"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \leq \text{mean} \leq b`.
|
||||
r"""Fills the input Tensor with values drawn from a truncated normal
|
||||
distribution. The values are effectively drawn from the normal distribution
|
||||
:math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
|
||||
:math:`[a, b]` redrawn until they are within the bounds. The method used
|
||||
for generating the random values works best when :math:`a \leq \text{mean}
|
||||
\leq b`.
|
||||
|
||||
Modified from
|
||||
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
|
||||
|
|
|
@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
auto_wrap_policy: Union[str, Callable, None] = None,
|
||||
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
|
||||
mixed_precision: Union[dict, MixedPrecision, None] = None,
|
||||
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
|
||||
param_init_fn: Union[str, Callable[
|
||||
[nn.Module], None]] = None, # type: ignore # noqa: E501
|
||||
use_orig_params: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
optim: torch.optim.Optimizer,
|
||||
group: Optional[dist.ProcessGroup] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
|
||||
model)
|
||||
return FullyShardedDataParallel._optim_state_dict_impl(
|
||||
|
@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
|||
state_dict_config: Optional[StateDictConfig] = None,
|
||||
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
|
||||
) -> StateDictSettings:
|
||||
"""copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
|
||||
import torch.distributed.fsdp._traversal_utils as traversal_utils
|
||||
_state_dict_type_to_config = {
|
||||
StateDictType.FULL_STATE_DICT: FullStateDictConfig,
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from mmengine.device import (is_cuda_available, is_mlu_available,
|
||||
is_npu_available)
|
||||
is_musa_available, is_npu_available)
|
||||
from mmengine.registry import OPTIM_WRAPPERS
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
@ -74,8 +74,9 @@ class AmpOptimWrapper(OptimWrapper):
|
|||
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
|
||||
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
|
||||
assert is_cuda_available() or is_npu_available() or is_mlu_available(
|
||||
), ('``AmpOptimizerWrapper`` is only available training '
|
||||
'on gpu, npu or mlu')
|
||||
) or is_musa_available(), (
|
||||
'``AmpOptimizerWrapper`` is only available training '
|
||||
'on gpu, npu, mlu or musa')
|
||||
super().__init__(**kwargs)
|
||||
self._scale_update_param = None
|
||||
|
||||
|
|
|
@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
|
|||
self._inner_count += 1
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Get the state dictionary of :attr:`optimizer` and
|
||||
:attr:`apex_amp`.
|
||||
"""Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.
|
||||
|
||||
Based on the state dictionary of the optimizer, the returned state
|
||||
dictionary will add a key named "apex_amp".
|
||||
|
|
|
@ -25,7 +25,11 @@ def register_torch_optimizers() -> List[str]:
|
|||
_optim = getattr(torch.optim, module_name)
|
||||
if inspect.isclass(_optim) and issubclass(_optim,
|
||||
torch.optim.Optimizer):
|
||||
OPTIMIZERS.register_module(module=_optim)
|
||||
if module_name == 'Adafactor':
|
||||
OPTIMIZERS.register_module(
|
||||
name='TorchAdafactor', module=_optim)
|
||||
else:
|
||||
OPTIMIZERS.register_module(module=_optim)
|
||||
torch_optimizers.append(module_name)
|
||||
return torch_optimizers
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
|
|||
self._validate_cfg()
|
||||
|
||||
def _validate_cfg(self) -> None:
|
||||
"""verify the correctness of the config."""
|
||||
"""Verify the correctness of the config."""
|
||||
if not isinstance(self.paramwise_cfg, dict):
|
||||
raise TypeError('paramwise_cfg should be None or a dict, '
|
||||
f'but got {type(self.paramwise_cfg)}')
|
||||
|
@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
|
|||
raise ValueError('base_wd should not be None')
|
||||
|
||||
def _is_in(self, param_group: dict, param_group_list: list) -> bool:
|
||||
"""check whether the `param_group` is in the`param_group_list`"""
|
||||
"""Check whether the `param_group` is in the`param_group_list`"""
|
||||
assert is_list_of(param_group_list, dict)
|
||||
param = set(param_group['params'])
|
||||
param_set = set()
|
||||
|
@ -213,7 +213,10 @@ class DefaultOptimWrapperConstructor:
|
|||
level=logging.WARNING)
|
||||
continue
|
||||
if not param.requires_grad:
|
||||
params.append(param_group)
|
||||
print_log((f'{prefix}.{name} is skipped since its '
|
||||
f'requires_grad={param.requires_grad}'),
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
continue
|
||||
|
||||
# if the parameter match one of the custom keys, ignore other rules
|
||||
|
|
|
@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
|
|||
self.optim_wrappers[name].load_state_dict(_state_dict)
|
||||
|
||||
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
|
||||
"""A generator to get the name and corresponding
|
||||
:obj:`OptimWrapper`"""
|
||||
"""A generator to get the name and corresponding :obj:`OptimWrapper`"""
|
||||
yield from self.optim_wrappers.items()
|
||||
|
||||
def values(self) -> Iterator[OptimWrapper]:
|
||||
|
|
|
@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
|
|||
|
||||
@PARAM_SCHEDULERS.register_module()
|
||||
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
|
||||
r"""Sets the learning rate of each parameter group according to the
|
||||
1cycle learning rate policy. The 1cycle policy anneals the learning
|
||||
rate from an initial learning rate to some maximum learning rate and then
|
||||
from that maximum learning rate to some minimum learning rate much lower
|
||||
than the initial learning rate.
|
||||
This policy was initially described in the paper `Super-Convergence:
|
||||
Very Fast Training of Neural Networks Using Large Learning Rates`_.
|
||||
r"""Sets the learning rate of each parameter group according to the 1cycle
|
||||
learning rate policy. The 1cycle policy anneals the learning rate from an
|
||||
initial learning rate to some maximum learning rate and then from that
|
||||
maximum learning rate to some minimum learning rate much lower than the
|
||||
initial learning rate. This policy was initially described in the paper
|
||||
`Super-Convergence: Very Fast Training of Neural Networks Using Large
|
||||
Learning Rates`_.
|
||||
|
||||
The 1cycle learning rate policy changes the learning rate after every
|
||||
batch. `step` should be called after a batch has been used for training.
|
||||
|
|
|
@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):
|
|||
|
||||
@PARAM_SCHEDULERS.register_module()
|
||||
class CosineAnnealingParamScheduler(_ParamScheduler):
|
||||
r"""Set the parameter value of each parameter group using a cosine
|
||||
annealing schedule, where :math:`\eta_{max}` is set to the initial value
|
||||
and :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
|
||||
r"""Set the parameter value of each parameter group using a cosine annealing
|
||||
schedule, where :math:`\eta_{max}` is set to the initial value and
|
||||
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
|
||||
|
||||
.. math::
|
||||
\begin{aligned}
|
||||
|
@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
|
|||
|
||||
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
|
||||
https://arxiv.org/abs/1608.03983
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
optimizer: Union[Optimizer, BaseOptimWrapper],
|
||||
|
@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):
|
|||
|
||||
@PARAM_SCHEDULERS.register_module()
|
||||
class OneCycleParamScheduler(_ParamScheduler):
|
||||
r"""Sets the parameters of each parameter group according to the
|
||||
1cycle learning rate policy. The 1cycle policy anneals the learning
|
||||
rate from an initial learning rate to some maximum learning rate and then
|
||||
from that maximum learning rate to some minimum learning rate much lower
|
||||
than the initial learning rate.
|
||||
This policy was initially described in the paper `Super-Convergence:
|
||||
Very Fast Training of Neural Networks Using Large Learning Rates`_.
|
||||
r"""Sets the parameters of each parameter group according to the 1cycle
|
||||
learning rate policy. The 1cycle policy anneals the learning rate from an
|
||||
initial learning rate to some maximum learning rate and then from that
|
||||
maximum learning rate to some minimum learning rate much lower than the
|
||||
initial learning rate. This policy was initially described in the paper
|
||||
`Super-Convergence: Very Fast Training of Neural Networks Using Large
|
||||
Learning Rates`_.
|
||||
|
||||
The 1cycle learning rate policy changes the learning rate after every
|
||||
batch. `step` should be called after a batch has been used for training.
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.utils import ManagerMixin
|
||||
from mmengine.utils import ManagerMixin, digit_version
|
||||
from .registry import Registry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -232,6 +232,21 @@ def build_model_from_cfg(
|
|||
return build_from_cfg(cfg, registry, default_args)
|
||||
|
||||
|
||||
def build_optimizer_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
|
||||
import torch
|
||||
|
||||
from ..logging import print_log
|
||||
if 'type' in cfg \
|
||||
and 'Adafactor' == cfg['type'] \
|
||||
and digit_version(torch.__version__) >= digit_version('2.5.0'):
|
||||
print_log(
|
||||
'the torch version of Adafactor is registered as TorchAdafactor')
|
||||
return build_from_cfg(cfg, registry, default_args)
|
||||
|
||||
|
||||
def build_scheduler_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: Registry,
|
||||
|
|
|
@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
|
|||
@classmethod
|
||||
@contextmanager
|
||||
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
|
||||
"""overwrite the current default scope with `scope_name`"""
|
||||
"""Overwrite the current default scope with `scope_name`"""
|
||||
if scope_name is None:
|
||||
yield
|
||||
else:
|
||||
|
|
|
@ -332,7 +332,7 @@ class Registry:
|
|||
return root
|
||||
|
||||
def import_from_location(self) -> None:
|
||||
"""import modules from the pre-defined locations in self._location."""
|
||||
"""Import modules from the pre-defined locations in self._location."""
|
||||
if not self._imported:
|
||||
# Avoid circular import
|
||||
from ..logging import print_log
|
||||
|
|
|
@ -6,8 +6,8 @@ More datails can be found at
|
|||
https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
|
||||
"""
|
||||
|
||||
from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
|
||||
build_scheduler_from_cfg)
|
||||
from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
|
||||
build_runner_from_cfg, build_scheduler_from_cfg)
|
||||
from .registry import Registry
|
||||
|
||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
|
@ -35,7 +35,7 @@ MODEL_WRAPPERS = Registry('model_wrapper')
|
|||
WEIGHT_INITIALIZERS = Registry('weight initializer')
|
||||
|
||||
# mangage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer')
|
||||
OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
|
||||
# manage optimizer wrapper
|
||||
OPTIM_WRAPPERS = Registry('optim_wrapper')
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
|
|
|
@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
|
|||
if current_scope.scope_name != scope: # type: ignore
|
||||
print_log(
|
||||
'The current default scope ' # type: ignore
|
||||
f'"{current_scope.scope_name}" is not "{scope}", '
|
||||
f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore
|
||||
'`init_default_scope` will force set the current'
|
||||
f'default scope to "{scope}".',
|
||||
logger='current',
|
||||
|
|
|
@ -540,7 +540,7 @@ class FlexibleRunner:
|
|||
|
||||
@property
|
||||
def hooks(self):
|
||||
"""list[:obj:`Hook`]: A list of registered hooks."""
|
||||
"""List[:obj:`Hook`]: A list of registered hooks."""
|
||||
return self._hooks
|
||||
|
||||
@property
|
||||
|
@ -1117,7 +1117,7 @@ class FlexibleRunner:
|
|||
return '\n'.join(stage_hook_infos)
|
||||
|
||||
def load_or_resume(self):
|
||||
"""load or resume checkpoint."""
|
||||
"""Load or resume checkpoint."""
|
||||
if self._has_loaded:
|
||||
return None
|
||||
|
||||
|
@ -1539,7 +1539,7 @@ class FlexibleRunner:
|
|||
file_client_args: Optional[dict] = None,
|
||||
save_optimizer: bool = True,
|
||||
save_param_scheduler: bool = True,
|
||||
meta: dict = None,
|
||||
meta: Optional[dict] = None,
|
||||
by_epoch: bool = True,
|
||||
backend_args: Optional[dict] = None,
|
||||
):
|
||||
|
|
|
@ -135,7 +135,13 @@ def autocast(device_type: Optional[str] = None,
|
|||
|
||||
elif device_type == 'npu':
|
||||
pass
|
||||
|
||||
elif device_type == 'musa':
|
||||
if dtype is None:
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
with torch.musa.amp.autocast(
|
||||
enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
|
||||
yield
|
||||
return
|
||||
else:
|
||||
# Device like MPS does not support fp16 training or testing.
|
||||
# If an inappropriate device is set and fp16 is enabled, an error
|
||||
|
|
|
@ -309,7 +309,7 @@ class CheckpointLoader:
|
|||
|
||||
@classmethod
|
||||
def load_checkpoint(cls, filename, map_location=None, logger='current'):
|
||||
"""load checkpoint through URL scheme path.
|
||||
"""Load checkpoint through URL scheme path.
|
||||
|
||||
Args:
|
||||
filename (str): checkpoint file name with given prefix
|
||||
|
@ -332,7 +332,7 @@ class CheckpointLoader:
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes='')
|
||||
def load_from_local(filename, map_location):
|
||||
"""load checkpoint by local file path.
|
||||
"""Load checkpoint by local file path.
|
||||
|
||||
Args:
|
||||
filename (str): local checkpoint file path
|
||||
|
@ -353,7 +353,7 @@ def load_from_http(filename,
|
|||
map_location=None,
|
||||
model_dir=None,
|
||||
progress=os.isatty(0)):
|
||||
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
|
||||
"""Load checkpoint through HTTP or HTTPS scheme path. In distributed
|
||||
setting, this function only download checkpoint at local rank 0.
|
||||
|
||||
Args:
|
||||
|
@ -386,7 +386,7 @@ def load_from_http(filename,
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes='pavi://')
|
||||
def load_from_pavi(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with pavi. In distributed
|
||||
"""Load checkpoint through the file path prefixed with pavi. In distributed
|
||||
setting, this function download ckpt at all ranks to different temporary
|
||||
directories.
|
||||
|
||||
|
@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):
|
|||
@CheckpointLoader.register_scheme(
|
||||
prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
|
||||
def load_from_ceph(filename, map_location=None, backend='petrel'):
|
||||
"""load checkpoint through the file path prefixed with s3. In distributed
|
||||
"""Load checkpoint through the file path prefixed with s3. In distributed
|
||||
setting, this function download ckpt at all ranks to different temporary
|
||||
directories.
|
||||
|
||||
|
@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
|
||||
def load_from_torchvision(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with modelzoo or
|
||||
"""Load checkpoint through the file path prefixed with modelzoo or
|
||||
torchvision.
|
||||
|
||||
Args:
|
||||
|
@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
|
||||
def load_from_openmmlab(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with open-mmlab or
|
||||
"""Load checkpoint through the file path prefixed with open-mmlab or
|
||||
openmmlab.
|
||||
|
||||
Args:
|
||||
|
@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):
|
|||
|
||||
@CheckpointLoader.register_scheme(prefixes='mmcls://')
|
||||
def load_from_mmcls(filename, map_location=None):
|
||||
"""load checkpoint through the file path prefixed with mmcls.
|
||||
"""Load checkpoint through the file path prefixed with mmcls.
|
||||
|
||||
Args:
|
||||
filename (str): checkpoint file path with mmcls prefix
|
||||
|
|
|
@ -9,7 +9,8 @@ from typing import List, Optional, Tuple
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.device import get_max_cuda_memory, is_cuda_available
|
||||
from mmengine.device import (get_max_cuda_memory, get_max_musa_memory,
|
||||
is_cuda_available, is_musa_available)
|
||||
from mmengine.registry import LOG_PROCESSORS
|
||||
|
||||
|
||||
|
@ -226,11 +227,13 @@ class LogProcessor:
|
|||
log_tag.pop('time')
|
||||
log_tag.pop('data_time')
|
||||
|
||||
# If cuda is available, the max memory occupied should be calculated.
|
||||
if is_cuda_available():
|
||||
# If cuda/musa is available,
|
||||
# the max memory occupied should be calculated.
|
||||
if is_cuda_available() or is_musa_available():
|
||||
max_memory = self._get_max_memory(runner)
|
||||
log_str += f'memory: {max_memory} '
|
||||
tag['memory'] = max_memory
|
||||
|
||||
# Loop left keys to fill `log_str`.
|
||||
if mode in ('train', 'val'):
|
||||
log_items = []
|
||||
|
@ -498,6 +501,9 @@ class LogProcessor:
|
|||
"""
|
||||
|
||||
device = getattr(runner.model, 'output_device', None)
|
||||
|
||||
if is_musa_available():
|
||||
return get_max_musa_memory(device)
|
||||
return get_max_cuda_memory(device)
|
||||
|
||||
def _get_iter(self, runner, batch_idx: int) -> int:
|
||||
|
|
|
@ -8,8 +8,10 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.logging import HistoryBuffer, print_log
|
||||
from mmengine.registry import LOOPS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import is_list_of
|
||||
from .amp import autocast
|
||||
from .base_loop import BaseLoop
|
||||
from .utils import calc_dynamic_intervals
|
||||
|
@ -98,7 +100,8 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
self._decide_current_val_interval()
|
||||
if (self.runner.val_loop is not None
|
||||
and self._epoch >= self.val_begin
|
||||
and self._epoch % self.val_interval == 0):
|
||||
and (self._epoch % self.val_interval == 0
|
||||
or self._epoch == self._max_epochs)):
|
||||
self.runner.val_loop.run()
|
||||
|
||||
self.runner.call_hook('after_train')
|
||||
|
@ -271,6 +274,14 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
# In iteration-based training loop, we treat the whole training process
|
||||
# as a big epoch and execute the corresponding hook.
|
||||
self.runner.call_hook('before_train_epoch')
|
||||
if self._iter > 0:
|
||||
print_log(
|
||||
f'Advance dataloader {self._iter} steps to skip data '
|
||||
'that has already been trained',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
for _ in range(self._iter):
|
||||
next(self.dataloader_iterator)
|
||||
while self._iter < self._max_iters and not self.stop_training:
|
||||
self.runner.model.train()
|
||||
|
||||
|
@ -280,7 +291,8 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
self._decide_current_val_interval()
|
||||
if (self.runner.val_loop is not None
|
||||
and self._iter >= self.val_begin
|
||||
and self._iter % self.val_interval == 0):
|
||||
and (self._iter % self.val_interval == 0
|
||||
or self._iter == self._max_iters)):
|
||||
self.runner.val_loop.run()
|
||||
|
||||
self.runner.call_hook('after_train_epoch')
|
||||
|
@ -353,17 +365,26 @@ class ValLoop(BaseLoop):
|
|||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.fp16 = fp16
|
||||
self.val_loss: Dict[str, HistoryBuffer] = dict()
|
||||
|
||||
def run(self) -> dict:
|
||||
"""Launch validation."""
|
||||
self.runner.call_hook('before_val')
|
||||
self.runner.call_hook('before_val_epoch')
|
||||
self.runner.model.eval()
|
||||
|
||||
# clear val loss
|
||||
self.val_loss.clear()
|
||||
for idx, data_batch in enumerate(self.dataloader):
|
||||
self.run_iter(idx, data_batch)
|
||||
|
||||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
|
||||
if self.val_loss:
|
||||
loss_dict = _parse_losses(self.val_loss, 'val')
|
||||
metrics.update(loss_dict)
|
||||
|
||||
self.runner.call_hook('after_val_epoch', metrics=metrics)
|
||||
self.runner.call_hook('after_val')
|
||||
return metrics
|
||||
|
@ -381,6 +402,9 @@ class ValLoop(BaseLoop):
|
|||
# outputs should be sequence of BaseDataElement
|
||||
with autocast(enabled=self.fp16):
|
||||
outputs = self.runner.model.val_step(data_batch)
|
||||
|
||||
outputs, self.val_loss = _update_losses(outputs, self.val_loss)
|
||||
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'after_val_iter',
|
||||
|
@ -425,17 +449,26 @@ class TestLoop(BaseLoop):
|
|||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self.fp16 = fp16
|
||||
self.test_loss: Dict[str, HistoryBuffer] = dict()
|
||||
|
||||
def run(self) -> dict:
|
||||
"""Launch test."""
|
||||
self.runner.call_hook('before_test')
|
||||
self.runner.call_hook('before_test_epoch')
|
||||
self.runner.model.eval()
|
||||
|
||||
# clear test loss
|
||||
self.test_loss.clear()
|
||||
for idx, data_batch in enumerate(self.dataloader):
|
||||
self.run_iter(idx, data_batch)
|
||||
|
||||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
|
||||
if self.test_loss:
|
||||
loss_dict = _parse_losses(self.test_loss, 'test')
|
||||
metrics.update(loss_dict)
|
||||
|
||||
self.runner.call_hook('after_test_epoch', metrics=metrics)
|
||||
self.runner.call_hook('after_test')
|
||||
return metrics
|
||||
|
@ -452,9 +485,66 @@ class TestLoop(BaseLoop):
|
|||
# predictions should be sequence of BaseDataElement
|
||||
with autocast(enabled=self.fp16):
|
||||
outputs = self.runner.model.test_step(data_batch)
|
||||
|
||||
outputs, self.test_loss = _update_losses(outputs, self.test_loss)
|
||||
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'after_test_iter',
|
||||
batch_idx=idx,
|
||||
data_batch=data_batch,
|
||||
outputs=outputs)
|
||||
|
||||
|
||||
def _parse_losses(losses: Dict[str, HistoryBuffer],
|
||||
stage: str) -> Dict[str, float]:
|
||||
"""Parses the raw losses of the network.
|
||||
|
||||
Args:
|
||||
losses (dict): raw losses of the network.
|
||||
stage (str): The stage of loss, e.g., 'val' or 'test'.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The key is the loss name, and the value is the
|
||||
average loss.
|
||||
"""
|
||||
all_loss = 0
|
||||
loss_dict: Dict[str, float] = dict()
|
||||
|
||||
for loss_name, loss_value in losses.items():
|
||||
avg_loss = loss_value.mean()
|
||||
loss_dict[loss_name] = avg_loss
|
||||
if 'loss' in loss_name:
|
||||
all_loss += avg_loss
|
||||
|
||||
loss_dict[f'{stage}_loss'] = all_loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
|
||||
"""Update and record the losses of the network.
|
||||
|
||||
Args:
|
||||
outputs (list): The outputs of the network.
|
||||
losses (dict): The losses of the network.
|
||||
|
||||
Returns:
|
||||
list: The updated outputs of the network.
|
||||
dict: The updated losses of the network.
|
||||
"""
|
||||
if isinstance(outputs[-1],
|
||||
BaseDataElement) and outputs[-1].keys() == ['loss']:
|
||||
loss = outputs[-1].loss # type: ignore
|
||||
outputs = outputs[:-1]
|
||||
else:
|
||||
loss = dict()
|
||||
|
||||
for loss_name, loss_value in loss.items():
|
||||
if loss_name not in losses:
|
||||
losses[loss_name] = HistoryBuffer()
|
||||
if isinstance(loss_value, torch.Tensor):
|
||||
losses[loss_name].update(loss_value.item())
|
||||
elif is_list_of(loss_value, torch.Tensor):
|
||||
for loss_value_i in loss_value:
|
||||
losses[loss_name].update(loss_value_i.item())
|
||||
return outputs, losses
|
||||
|
|
|
@ -579,7 +579,7 @@ class Runner:
|
|||
|
||||
@property
|
||||
def hooks(self):
|
||||
"""list[:obj:`Hook`]: A list of registered hooks."""
|
||||
"""List[:obj:`Hook`]: A list of registered hooks."""
|
||||
return self._hooks
|
||||
|
||||
@property
|
||||
|
@ -720,7 +720,7 @@ class Runner:
|
|||
|
||||
def build_logger(self,
|
||||
log_level: Union[int, str] = 'INFO',
|
||||
log_file: str = None,
|
||||
log_file: Optional[str] = None,
|
||||
**kwargs) -> MMLogger:
|
||||
"""Build a global asscessable MMLogger.
|
||||
|
||||
|
@ -1677,7 +1677,7 @@ class Runner:
|
|||
return '\n'.join(stage_hook_infos)
|
||||
|
||||
def load_or_resume(self) -> None:
|
||||
"""load or resume checkpoint."""
|
||||
"""Load or resume checkpoint."""
|
||||
if self._has_loaded:
|
||||
return None
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import numpy as np
|
|||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmengine.device import is_cuda_available, is_musa_available
|
||||
from mmengine.dist import get_rank, sync_random_seed
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.utils import digit_version, is_list_of
|
||||
|
@ -69,7 +70,10 @@ def set_random_seed(seed: Optional[int] = None,
|
|||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
# torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
if is_cuda_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
elif is_musa_available():
|
||||
torch.musa.manual_seed_all(seed)
|
||||
# os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
if deterministic:
|
||||
if torch.backends.cudnn.benchmark:
|
||||
|
|
|
@ -387,7 +387,7 @@ class BaseDataElement:
|
|||
return dict(self.metainfo_items())
|
||||
|
||||
def __setattr__(self, name: str, value: Any):
|
||||
"""setattr is only used to set data."""
|
||||
"""Setattr is only used to set data."""
|
||||
if name in ('_metainfo_fields', '_data_fields'):
|
||||
if not hasattr(self, name):
|
||||
super().__setattr__(name, value)
|
||||
|
@ -510,6 +510,17 @@ class BaseDataElement:
|
|||
new_data.set_data(data)
|
||||
return new_data
|
||||
|
||||
# Tensor-like methods
|
||||
def musa(self) -> 'BaseDataElement':
|
||||
"""Convert all tensors to musa in data."""
|
||||
new_data = self.new()
|
||||
for k, v in self.items():
|
||||
if isinstance(v, (torch.Tensor, BaseDataElement)):
|
||||
v = v.musa()
|
||||
data = {k: v}
|
||||
new_data.set_data(data)
|
||||
return new_data
|
||||
|
||||
# Tensor-like methods
|
||||
def npu(self) -> 'BaseDataElement':
|
||||
"""Convert all tensors to NPU in data."""
|
||||
|
|
|
@ -18,6 +18,9 @@ if get_device() == 'npu':
|
|||
elif get_device() == 'mlu':
|
||||
BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
|
||||
LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
|
||||
elif get_device() == 'musa':
|
||||
BoolTypeTensor = Union[torch.BoolTensor, torch.musa.BoolTensor]
|
||||
LongTypeTensor = Union[torch.LongTensor, torch.musa.LongTensor]
|
||||
else:
|
||||
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
|
||||
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
|
||||
|
@ -132,7 +135,7 @@ class InstanceData(BaseDataElement):
|
|||
"""
|
||||
|
||||
def __setattr__(self, name: str, value: Sized):
|
||||
"""setattr is only used to set data.
|
||||
"""Setattr is only used to set data.
|
||||
|
||||
The value must have the attribute of `__len__` and have the same length
|
||||
of `InstanceData`.
|
||||
|
|
|
@ -92,10 +92,25 @@ class MultiProcessTestCase(TestCase):
|
|||
# Constructor patches current instance test method to
|
||||
# assume the role of the main process and join its subprocesses,
|
||||
# or run the underlying test function.
|
||||
def __init__(self, method_name: str = 'runTest') -> None:
|
||||
def __init__(self,
|
||||
method_name: str = 'runTest',
|
||||
methodName: str = 'runTest') -> None:
|
||||
# methodName is the correct naming in unittest
|
||||
# and testslide uses keyword arguments.
|
||||
# So we need to use both to 1) not break BC and, 2) support testslide.
|
||||
if methodName != 'runTest':
|
||||
method_name = methodName
|
||||
super().__init__(method_name)
|
||||
fn = getattr(self, method_name)
|
||||
setattr(self, method_name, self.join_or_run(fn))
|
||||
try:
|
||||
fn = getattr(self, method_name)
|
||||
setattr(self, method_name, self.join_or_run(fn))
|
||||
except AttributeError as e:
|
||||
if methodName != 'runTest':
|
||||
# we allow instantiation with no explicit method name
|
||||
# but not an *incorrect* or missing method name
|
||||
raise ValueError(
|
||||
f'no such test method in {self.__class__}: {methodName}'
|
||||
) from e
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""This file holding some environment constant for sharing by other files."""
|
||||
import os
|
||||
import os.path as osp
|
||||
import subprocess
|
||||
import sys
|
||||
|
@ -9,6 +10,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
import mmengine
|
||||
from mmengine.device import is_cuda_available, is_musa_available
|
||||
from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch
|
||||
|
||||
|
||||
|
@ -24,6 +26,10 @@ def _get_cuda_home():
|
|||
return CUDA_HOME
|
||||
|
||||
|
||||
def _get_musa_home():
|
||||
return os.environ.get('MUSA_HOME')
|
||||
|
||||
|
||||
def collect_env():
|
||||
"""Collect the information of the running environments.
|
||||
|
||||
|
@ -51,9 +57,10 @@ def collect_env():
|
|||
env_info['sys.platform'] = sys.platform
|
||||
env_info['Python'] = sys.version.replace('\n', '')
|
||||
|
||||
cuda_available = torch.cuda.is_available()
|
||||
cuda_available = is_cuda_available()
|
||||
musa_available = is_musa_available()
|
||||
env_info['CUDA available'] = cuda_available
|
||||
|
||||
env_info['MUSA available'] = musa_available
|
||||
env_info['numpy_random_seed'] = np.random.get_state()[1][0]
|
||||
|
||||
if cuda_available:
|
||||
|
@ -89,7 +96,23 @@ def collect_env():
|
|||
except subprocess.SubprocessError:
|
||||
nvcc = 'Not Available'
|
||||
env_info['NVCC'] = nvcc
|
||||
elif musa_available:
|
||||
devices = defaultdict(list)
|
||||
for k in range(torch.musa.device_count()):
|
||||
devices[torch.musa.get_device_name(k)].append(str(k))
|
||||
for name, device_ids in devices.items():
|
||||
env_info['GPU ' + ','.join(device_ids)] = name
|
||||
|
||||
MUSA_HOME = _get_musa_home()
|
||||
env_info['MUSA_HOME'] = MUSA_HOME
|
||||
|
||||
if MUSA_HOME is not None and osp.isdir(MUSA_HOME):
|
||||
try:
|
||||
mcc = osp.join(MUSA_HOME, 'bin/mcc')
|
||||
subprocess.check_output(f'"{mcc}" -v', shell=True)
|
||||
except subprocess.SubprocessError:
|
||||
mcc = 'Not Available'
|
||||
env_info['mcc'] = mcc
|
||||
try:
|
||||
# Check C++ Compiler.
|
||||
# For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...',
|
||||
|
@ -138,7 +161,7 @@ def collect_env():
|
|||
try:
|
||||
import cv2
|
||||
env_info['OpenCV'] = cv2.__version__
|
||||
except ModuleNotFoundError:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
env_info['MMEngine'] = mmengine.__version__
|
||||
|
|
|
@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
|
|||
check_hash=False,
|
||||
file_name=None):
|
||||
r"""Loads the Torch serialized object at the given URL.
|
||||
|
||||
If downloaded file is a zip file, it will be automatically decompressed
|
||||
If the object is already present in `model_dir`, it's deserialized and
|
||||
returned.
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Optional, Union
|
|||
|
||||
import torch
|
||||
|
||||
from mmengine.device import is_cuda_available, is_musa_available
|
||||
from mmengine.dist.utils import master_only
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
|
||||
|
@ -66,7 +67,7 @@ class TimeCounter:
|
|||
|
||||
instance.log_interval = log_interval
|
||||
instance.warmup_interval = warmup_interval
|
||||
instance.with_sync = with_sync
|
||||
instance.with_sync = with_sync # type: ignore
|
||||
instance.tag = tag
|
||||
instance.logger = logger
|
||||
|
||||
|
@ -84,15 +85,20 @@ class TimeCounter:
|
|||
def wrapper(*args, **kwargs):
|
||||
self.__count += 1
|
||||
|
||||
if self.with_sync and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
if self.with_sync:
|
||||
if is_cuda_available():
|
||||
torch.cuda.synchronize()
|
||||
elif is_musa_available():
|
||||
torch.musa.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
if self.with_sync and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if self.with_sync:
|
||||
if is_cuda_available():
|
||||
torch.cuda.synchronize()
|
||||
elif is_musa_available():
|
||||
torch.musa.synchronize()
|
||||
elapsed = time.perf_counter() - start_time
|
||||
self.print_time(elapsed)
|
||||
|
||||
|
@ -121,7 +127,7 @@ class TimeCounter:
|
|||
self.print_time(elapsed)
|
||||
|
||||
def print_time(self, elapsed: Union[int, float]) -> None:
|
||||
"""print times per count."""
|
||||
"""Print times per count."""
|
||||
if self.__count >= self.warmup_interval:
|
||||
self.__pure_inf_time += elapsed
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):
|
|||
|
||||
def is_seq_of(seq: Any,
|
||||
expected_type: Union[Type, tuple],
|
||||
seq_type: Type = None) -> bool:
|
||||
seq_type: Optional[Type] = None) -> bool:
|
||||
"""Check whether it is a sequence of some type.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
|
|||
else:
|
||||
raise e
|
||||
|
||||
possible_path = osp.join(pkg.location, package)
|
||||
possible_path = osp.join(pkg.location, package) # type: ignore
|
||||
if osp.exists(possible_path):
|
||||
return possible_path
|
||||
else:
|
||||
return osp.join(pkg.location, package2module(package))
|
||||
return osp.join(pkg.location, package2module(package)) # type: ignore
|
||||
|
||||
|
||||
def package2module(package: str):
|
||||
|
|
|
@ -3,7 +3,7 @@ import sys
|
|||
from collections.abc import Iterable
|
||||
from multiprocessing import Pool
|
||||
from shutil import get_terminal_size
|
||||
from typing import Callable, Sequence
|
||||
from typing import Callable, Optional, Sequence
|
||||
|
||||
from .timer import Timer
|
||||
|
||||
|
@ -54,7 +54,7 @@ class ProgressBar:
|
|||
self.timer = Timer()
|
||||
|
||||
def update(self, num_tasks: int = 1):
|
||||
"""update progressbar.
|
||||
"""Update progressbar.
|
||||
|
||||
Args:
|
||||
num_tasks (int): Update step size.
|
||||
|
@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
|
|||
def track_parallel_progress(func: Callable,
|
||||
tasks: Sequence,
|
||||
nproc: int,
|
||||
initializer: Callable = None,
|
||||
initargs: tuple = None,
|
||||
initializer: Optional[Callable] = None,
|
||||
initargs: Optional[tuple] = None,
|
||||
bar_width: int = 50,
|
||||
chunksize: int = 1,
|
||||
skip_first: bool = False,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from multiprocessing import Pool
|
||||
from typing import Callable, Iterable, Sized
|
||||
from typing import Callable, Iterable, Optional, Sized
|
||||
|
||||
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
|
||||
TaskProgressColumn, TextColumn, TimeRemainingColumn)
|
||||
|
@ -47,7 +47,7 @@ def _tasks_with_index(tasks):
|
|||
|
||||
def track_progress_rich(func: Callable,
|
||||
tasks: Iterable = tuple(),
|
||||
task_num: int = None,
|
||||
task_num: Optional[int] = None,
|
||||
nproc: int = 1,
|
||||
chunksize: int = 1,
|
||||
description: str = 'Processing',
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
__version__ = '0.10.1'
|
||||
__version__ = '0.10.7'
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
|
|
|
@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
|
|||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
"""Close an opened object."""
|
||||
pass
|
||||
|
||||
|
||||
|
@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):
|
|||
|
||||
def _dump(self, value_dict: dict, file_path: str,
|
||||
file_format: str) -> None:
|
||||
"""dump dict to file.
|
||||
"""Dump dict to file.
|
||||
|
||||
Args:
|
||||
value_dict (dict) : The dict data to saved.
|
||||
|
@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
|
|||
self._wandb.log(scalar_dict, commit=self._commit)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened wandb object."""
|
||||
"""Close an opened wandb object."""
|
||||
if hasattr(self, '_wandb'):
|
||||
self._wandb.join()
|
||||
|
||||
|
@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
|
|||
self.add_scalar(key, value, step)
|
||||
|
||||
def close(self):
|
||||
"""close an opened tensorboard object."""
|
||||
"""Close an opened tensorboard object."""
|
||||
if hasattr(self, '_tensorboard'):
|
||||
self._tensorboard.close()
|
||||
|
||||
|
@ -669,6 +669,10 @@ class MLflowVisBackend(BaseVisBackend):
|
|||
will be added to the experiment. If it is None, which means all
|
||||
the config will be added. Defaults to None.
|
||||
`New in version 0.7.4.`
|
||||
artifact_location (str, optional): The location to store run artifacts.
|
||||
If None, the server picks an appropriate default.
|
||||
Defaults to None.
|
||||
`New in version 0.10.4.`
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -680,7 +684,8 @@ class MLflowVisBackend(BaseVisBackend):
|
|||
tracking_uri: Optional[str] = None,
|
||||
artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py',
|
||||
'yaml'),
|
||||
tracked_config_keys: Optional[dict] = None):
|
||||
tracked_config_keys: Optional[dict] = None,
|
||||
artifact_location: Optional[str] = None):
|
||||
super().__init__(save_dir)
|
||||
self._exp_name = exp_name
|
||||
self._run_name = run_name
|
||||
|
@ -689,6 +694,7 @@ class MLflowVisBackend(BaseVisBackend):
|
|||
self._tracking_uri = tracking_uri
|
||||
self._artifact_suffix = artifact_suffix
|
||||
self._tracked_config_keys = tracked_config_keys
|
||||
self._artifact_location = artifact_location
|
||||
|
||||
def _init_env(self):
|
||||
"""Setup env for MLflow."""
|
||||
|
@ -726,7 +732,8 @@ class MLflowVisBackend(BaseVisBackend):
|
|||
self._exp_name = self._exp_name or 'Default'
|
||||
|
||||
if self._mlflow.get_experiment_by_name(self._exp_name) is None:
|
||||
self._mlflow.create_experiment(self._exp_name)
|
||||
self._mlflow.create_experiment(
|
||||
self._exp_name, artifact_location=self._artifact_location)
|
||||
|
||||
self._mlflow.set_experiment(self._exp_name)
|
||||
|
||||
|
@ -1128,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
|
|||
self._neptune[k].append(v, step=step)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
"""Close an opened object."""
|
||||
if hasattr(self, '_neptune'):
|
||||
self._neptune.stop()
|
||||
|
||||
|
@ -1275,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
|
|||
self.add_scalar(key, value, step, **kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened dvclive object."""
|
||||
"""Close an opened dvclive object."""
|
||||
if not hasattr(self, '_dvclive'):
|
||||
return
|
||||
|
||||
|
|
|
@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):
|
|||
|
||||
@master_only
|
||||
def get_backend(self, name) -> 'BaseVisBackend':
|
||||
"""get vis backend by name.
|
||||
"""Get vis backend by name.
|
||||
|
||||
Args:
|
||||
name (str): The name of vis backend
|
||||
|
@ -879,7 +879,7 @@ class Visualizer(ManagerMixin):
|
|||
if binary_masks.ndim == 2:
|
||||
binary_masks = binary_masks[None]
|
||||
assert img.shape[:2] == binary_masks.shape[
|
||||
1:], '`binary_marks` must have ' \
|
||||
1:], '`binary_masks` must have ' \
|
||||
'the same shape with image'
|
||||
binary_mask_len = binary_masks.shape[0]
|
||||
|
||||
|
@ -961,7 +961,7 @@ class Visualizer(ManagerMixin):
|
|||
if topk <= 0, tensor_chw is assert to be one or three.
|
||||
Defaults to 20.
|
||||
arrangement (Tuple[int, int]): The arrangement of featmap when
|
||||
channel_reduction is not None and topk > 0. Defaults to (4, 5).
|
||||
channel_reduction is None and topk > 0. Defaults to (4, 5).
|
||||
resize_shape (tuple, optional): The shape to scale the feature map.
|
||||
Defaults to None.
|
||||
alpha (Union[int, List[int]]): The transparency of featmap.
|
||||
|
@ -989,7 +989,7 @@ class Visualizer(ManagerMixin):
|
|||
f'overlaid_image: {overlaid_image.shape[:2]} and '
|
||||
f'featmap: {featmap.shape[1:]} are not same, '
|
||||
f'the feature map will be interpolated. '
|
||||
f'This may cause mismatch problems !')
|
||||
f'This may cause mismatch problems !')
|
||||
if resize_shape is None:
|
||||
featmap = F.interpolate(
|
||||
featmap[None],
|
||||
|
@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
|
|||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""close an opened object."""
|
||||
"""Close an opened object."""
|
||||
for vis_backend in self._vis_backends.values():
|
||||
vis_backend.close()
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
docutils==0.17.1
|
||||
docutils==0.18.1
|
||||
myst-parser
|
||||
opencv-python
|
||||
-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
|
||||
sphinx==4.5.0
|
||||
sphinx==6.2.1
|
||||
sphinx-copybutton
|
||||
sphinx-tabs
|
||||
sphinx_markdown_tables
|
||||
|
|
|
@ -843,8 +843,8 @@ class TestConfig:
|
|||
assert cfg_dict['item4'] == 'test'
|
||||
assert '_delete_' not in cfg_dict['item1']
|
||||
|
||||
assert type(cfg_dict['item1']) == ConfigDict
|
||||
assert type(cfg_dict['item2']) == ConfigDict
|
||||
assert type(cfg_dict['item1']) is ConfigDict
|
||||
assert type(cfg_dict['item2']) is ConfigDict
|
||||
|
||||
def _merge_intermediate_variable(self):
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
|
||||
is_mps_available, is_npu_available)
|
||||
is_mps_available, is_musa_available,
|
||||
is_npu_available)
|
||||
|
||||
|
||||
def test_get_device():
|
||||
|
@ -13,5 +14,7 @@ def test_get_device():
|
|||
assert device == 'mlu'
|
||||
elif is_mps_available():
|
||||
assert device == 'mps'
|
||||
elif is_musa_available():
|
||||
assert device == 'musa'
|
||||
else:
|
||||
assert device == 'cpu'
|
||||
|
|
|
@ -11,6 +11,7 @@ import torch
|
|||
import torch.distributed as torch_dist
|
||||
|
||||
import mmengine.dist as dist
|
||||
from mmengine.device import is_musa_available
|
||||
from mmengine.dist.dist import sync_random_seed
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
from mmengine.utils import digit_version
|
||||
|
@ -117,6 +118,7 @@ class TestDist(TestCase):
|
|||
self.assertTrue(torch.allclose(item1, item2))
|
||||
|
||||
|
||||
@unittest.skipIf(is_musa_available(), reason='musa do not support gloo yet')
|
||||
class TestDistWithGLOOBackend(MultiProcessTestCase):
|
||||
|
||||
def _init_dist_env(self, rank, world_size):
|
||||
|
|
|
@ -300,8 +300,8 @@ except ImportError:
|
|||
get_inputs.append(filepath)
|
||||
|
||||
with build_temporary_directory() as tmp_dir, \
|
||||
patch.object(backend, 'put', side_effect=put),\
|
||||
patch.object(backend, 'get', side_effect=get),\
|
||||
patch.object(backend, 'put', side_effect=put), \
|
||||
patch.object(backend, 'get', side_effect=get), \
|
||||
patch.object(backend, 'exists', return_value=False):
|
||||
tmp_dir = tmp_dir.replace('\\', '/')
|
||||
dst = f'{tmp_dir}/dir'
|
||||
|
@ -351,7 +351,7 @@ except ImportError:
|
|||
|
||||
with build_temporary_directory() as tmp_dir, \
|
||||
patch.object(backend, 'copyfile_from_local',
|
||||
side_effect=copyfile_from_local),\
|
||||
side_effect=copyfile_from_local), \
|
||||
patch.object(backend, 'exists', return_value=False):
|
||||
backend.copytree_from_local(tmp_dir, self.petrel_dir)
|
||||
|
||||
|
@ -427,7 +427,7 @@ except ImportError:
|
|||
def remove(filepath):
|
||||
inputs.append(filepath)
|
||||
|
||||
with build_temporary_directory() as tmp_dir,\
|
||||
with build_temporary_directory() as tmp_dir, \
|
||||
patch.object(backend, 'remove', side_effect=remove):
|
||||
backend.rmtree(tmp_dir)
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.config import ConfigDict
|
||||
from mmengine.device import is_musa_available
|
||||
from mmengine.hooks import EMAHook
|
||||
from mmengine.model import BaseModel, ExponentialMovingAverage
|
||||
from mmengine.registry import MODELS
|
||||
|
@ -45,6 +47,9 @@ class ToyModel3(ToyModel):
|
|||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
# TODO:haowen.han@mtheads.com
|
||||
@unittest.skipIf(is_musa_available(),
|
||||
"musa backend do not support 'aten::lerp.Scalar_out'")
|
||||
class TestEMAHook(RunnerTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mmengine.device import is_cuda_available
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
||||
|
||||
class TestEmptyCacheHook(RunnerTestCase):
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_cuda_available(), reason='cuda should be available')
|
||||
def test_with_runner(self):
|
||||
with patch('torch.cuda.empty_cache') as mock_empty_cache:
|
||||
cfg = self.epoch_based_cfg
|
||||
|
|
|
@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
|
|||
class TestAveragedModel(TestCase):
|
||||
"""Test the AveragedModel class.
|
||||
|
||||
Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
|
||||
Some test cases are referenced from
|
||||
https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
|
||||
""" # noqa: E501
|
||||
|
||||
def _test_swa_model(self, net_device, avg_device):
|
||||
|
|
|
@ -549,7 +549,8 @@ class TestBuilder(TestCase):
|
|||
weight_decay=self.base_wd,
|
||||
momentum=self.momentum))
|
||||
paramwise_cfg = dict()
|
||||
optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_wrapper = optim_constructor(model)
|
||||
self._check_default_optimizer(optim_wrapper.optimizer, model)
|
||||
|
||||
|
@ -591,23 +592,16 @@ class TestBuilder(TestCase):
|
|||
dwconv_decay_mult=0.1,
|
||||
dcn_offset_lr_mult=0.1)
|
||||
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
self.model.conv1.requires_grad_(False)
|
||||
optim_constructor = DefaultOptimWrapperConstructor(
|
||||
optim_wrapper_cfg, paramwise_cfg)
|
||||
optim_wrapper = optim_constructor(self.model)
|
||||
optimizer = optim_wrapper.optimizer
|
||||
param_groups = optimizer.param_groups
|
||||
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
|
||||
assert optimizer.defaults['lr'] == self.base_lr
|
||||
assert optimizer.defaults['momentum'] == self.momentum
|
||||
assert optimizer.defaults['weight_decay'] == self.base_wd
|
||||
for i, (name, param) in enumerate(self.model.named_parameters()):
|
||||
param_group = param_groups[i]
|
||||
assert torch.equal(param_group['params'][0], param)
|
||||
assert param_group['momentum'] == self.momentum
|
||||
assert param_group['lr'] == self.base_lr
|
||||
assert param_group['weight_decay'] == self.base_wd
|
||||
|
||||
all_params = []
|
||||
for pg in optim_wrapper.param_groups:
|
||||
all_params.extend(map(id, pg['params']))
|
||||
self.assertNotIn(id(self.model.conv1.weight), all_params)
|
||||
self.assertIn(id(self.model.conv2.weight), all_params)
|
||||
|
||||
def test_default_optimizer_constructor_bypass_duplicate(self):
|
||||
# paramwise_cfg with bypass_duplicate option
|
||||
|
@ -663,10 +657,8 @@ class TestBuilder(TestCase):
|
|||
optim_wrapper = optim_constructor(model)
|
||||
model_parameters = list(model.parameters())
|
||||
num_params = 14 if MMCV_FULL_AVAILABLE else 11
|
||||
assert len(optim_wrapper.optimizer.param_groups) == len(
|
||||
model_parameters) == num_params
|
||||
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
|
||||
**paramwise_cfg)
|
||||
assert len(optim_wrapper.optimizer.param_groups
|
||||
) == len(model_parameters) - 1 == num_params - 1
|
||||
|
||||
def test_default_optimizer_constructor_custom_key(self):
|
||||
# test DefaultOptimWrapperConstructor with custom_keys and
|
||||
|
|
|
@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
|
|||
rtol=0)
|
||||
|
||||
def test_scheduler_before_optim_warning(self):
|
||||
"""warns if scheduler is used before optimizer."""
|
||||
"""Warns if scheduler is used before optimizer."""
|
||||
|
||||
def call_sch_before_optim():
|
||||
scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)
|
||||
|
|
|
@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
|
|||
rtol=0)
|
||||
|
||||
def test_scheduler_before_optim_warning(self):
|
||||
"""warns if scheduler is used before optimizer."""
|
||||
"""Warns if scheduler is used before optimizer."""
|
||||
|
||||
def call_sch_before_optim():
|
||||
scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)
|
||||
|
|
|
@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
|
|||
rtol=0)
|
||||
|
||||
def test_scheduler_before_optim_warning(self):
|
||||
"""warns if scheduler is used before optimizer."""
|
||||
"""Warns if scheduler is used before optimizer."""
|
||||
|
||||
def call_sch_before_optim():
|
||||
scheduler = StepParamScheduler(
|
||||
|
|
|
@ -5,7 +5,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
import mmengine
|
||||
from mmengine.device import get_device, is_mlu_available, is_npu_available
|
||||
from mmengine.device import (get_device, is_mlu_available, is_musa_available,
|
||||
is_npu_available)
|
||||
from mmengine.runner import autocast
|
||||
from mmengine.utils import digit_version
|
||||
from mmengine.utils.dl_utils import TORCH_VERSION
|
||||
|
@ -44,6 +45,21 @@ class TestAmp(unittest.TestCase):
|
|||
layer = nn.Conv2d(1, 1, 1).to(device)
|
||||
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||
self.assertEqual(res.dtype, torch.float32)
|
||||
elif is_musa_available():
|
||||
device = 'musa'
|
||||
with autocast(device_type=device):
|
||||
# torch.autocast support mlu mode.
|
||||
layer = nn.Conv2d(1, 1, 1).to(device)
|
||||
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
|
||||
with autocast(enabled=False, device_type=device):
|
||||
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||
self.assertEqual(res.dtype, torch.float32)
|
||||
# Test with fp32_enabled
|
||||
with autocast(enabled=False, device_type=device):
|
||||
layer = nn.Conv2d(1, 1, 1).to(device)
|
||||
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||
self.assertEqual(res.dtype, torch.float32)
|
||||
elif not torch.cuda.is_available():
|
||||
if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
|
||||
# `torch.cuda.amp.autocast` is only support in gpu mode, if
|
||||
|
|
|
@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():
|
|||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
*args, **kwargs):
|
||||
"""load checkpoints."""
|
||||
"""Load checkpoints."""
|
||||
|
||||
# Names of some parameters in has been changed.
|
||||
version = local_metadata.get('version', None)
|
||||
|
|
|
@ -7,6 +7,7 @@ import pytest
|
|||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from mmengine.device import is_cuda_available, is_musa_available
|
||||
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
|
||||
from mmengine.runner import LogProcessor
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
@ -113,7 +114,7 @@ class TestLogProcessor(RunnerTestCase):
|
|||
f"time: {train_logs['time']:.4f} "
|
||||
f"data_time: {train_logs['data_time']:.4f} ")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if is_cuda_available() or is_musa_available():
|
||||
log_str += 'memory: 100 '
|
||||
if mode == 'train':
|
||||
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
|
||||
|
@ -141,7 +142,7 @@ class TestLogProcessor(RunnerTestCase):
|
|||
f"time: {train_logs['time']:.4f} "
|
||||
f"data_time: {train_logs['data_time']:.4f} ")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if is_cuda_available() or is_musa_available():
|
||||
log_str += 'memory: 100 '
|
||||
|
||||
if mode == 'train':
|
||||
|
@ -249,6 +250,7 @@ class TestLogProcessor(RunnerTestCase):
|
|||
assert tag['metric1'] is metric1
|
||||
assert tag['metric2'] is metric2
|
||||
|
||||
# TODO:haowen.han@mtheads.com MUSA does not support it yet!
|
||||
@patch('torch.cuda.max_memory_allocated', MagicMock())
|
||||
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
|
||||
def test_get_max_memory(self):
|
||||
|
|
|
@ -2226,7 +2226,7 @@ class TestRunner(TestCase):
|
|||
|
||||
@HOOKS.register_module(force=True)
|
||||
class TestWarmupHook(Hook):
|
||||
"""test custom train loop."""
|
||||
"""Test custom train loop."""
|
||||
|
||||
def before_warmup_iter(self, runner, data_batch=None):
|
||||
before_warmup_iter_results.append('before')
|
||||
|
|
|
@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
|
|||
return metainfo, data
|
||||
|
||||
def is_equal(self, x, y):
|
||||
assert type(x) == type(y)
|
||||
assert type(x) is type(y)
|
||||
if isinstance(
|
||||
x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
|
||||
return x == y
|
||||
|
@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):
|
|||
|
||||
# test new() with no arguments
|
||||
new_instances = instances.new()
|
||||
assert type(new_instances) == type(instances)
|
||||
assert type(new_instances) is type(instances)
|
||||
# After deepcopy, the address of new data'element will be same as
|
||||
# origin, but when change new data' element will not effect the origin
|
||||
# element and will have new address
|
||||
|
@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
|
|||
# test new() with arguments
|
||||
metainfo, data = self.setup_data()
|
||||
new_instances = instances.new(metainfo=metainfo, **data)
|
||||
assert type(new_instances) == type(instances)
|
||||
assert type(new_instances) is type(instances)
|
||||
assert id(new_instances.gt_instances) != id(instances.gt_instances)
|
||||
_, new_data = self.setup_data()
|
||||
new_instances.set_data(new_data)
|
||||
|
@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
|
|||
metainfo, data = self.setup_data()
|
||||
instances = BaseDataElement(metainfo=metainfo, **data)
|
||||
new_instances = instances.clone()
|
||||
assert type(new_instances) == type(instances)
|
||||
assert type(new_instances) is type(instances)
|
||||
|
||||
def test_set_metainfo(self):
|
||||
metainfo, _ = self.setup_data()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from mmengine.utils.dl_utils import torch_meshgrid
|
||||
|
@ -7,9 +8,8 @@ from mmengine.utils.dl_utils import torch_meshgrid
|
|||
|
||||
def test_torch_meshgrid():
|
||||
# torch_meshgrid should not throw warning
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('error')
|
||||
x = torch.tensor([1, 2, 3])
|
||||
y = torch.tensor([4, 5, 6])
|
||||
grid_x, grid_y = torch_meshgrid(x, y)
|
||||
|
||||
assert len(record) == 0
|
||||
|
|
|
@ -282,6 +282,14 @@ class TestMLflowVisBackend:
|
|||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
assert mlflow_vis_backend.experiment == mlflow_vis_backend._mlflow
|
||||
|
||||
def test_create_experiment(self):
|
||||
with patch('mlflow.create_experiment') as mock_create_experiment:
|
||||
MLflowVisBackend(
|
||||
'temp_dir', exp_name='test',
|
||||
artifact_location='foo')._init_env()
|
||||
mock_create_experiment.assert_any_call(
|
||||
'test', artifact_location='foo')
|
||||
|
||||
def test_add_config(self):
|
||||
cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
|
||||
mlflow_vis_backend = MLflowVisBackend('temp_dir')
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue