Compare commits

...

45 Commits

Author SHA1 Message Date
Mashiro 390ba2fbb2
Bump version to 0.10.7 2025-03-04 20:20:27 +08:00
Epiphany d620552c2c
[fix] fix ci (#1636) 2025-02-27 15:13:38 +08:00
Mashiro 41fa84a9a9
[Fix] remove torch dependencies in `build_function.py` (#1632) 2025-02-18 16:38:51 +08:00
Mashiro 698782f920
[Fix] Fix deploy ci (#1628)
* [Enhance] Support trigger ci manually

* [Fix] Fix deploy CI
2025-01-15 18:11:05 +08:00
Mashiro e60ab1dde3
[Enhance] Support trigger ci manually (#1627) 2025-01-15 18:07:34 +08:00
Mashiro 8ec837814e
[Enhance] Support trigger ci manually (#1626) 2025-01-15 17:58:23 +08:00
Qian Zhao a4475f5eea
Update deploy.yml (#1625) 2025-01-15 17:34:07 +08:00
Mashiro a8c74c346d
Bump version to v0.10.6 (#1623) 2025-01-13 19:20:26 +08:00
Epiphany 9124ebf7a2
[Enhance] ensure type in cfg (#1602)
* ensure type in cfg

* change import level
2024-11-06 20:26:47 +08:00
Epiphany 2e0ab7a922
[Fix] fix Adafactor optim on torch2.5 and fix compatibility (#1600)
* fix Adafactor opptim on torch2.5 and fix compatibility

* fix runtest error
2024-11-05 21:22:46 +08:00
Epiphany fc59364d64
[Fix] fix error when pytest>=8.2 (#1601) 2024-11-05 20:43:17 +08:00
Tibor Reiss 4183cf0829
Fix return in finally (#1596) 2024-11-04 14:39:25 +08:00
Mashiro cc3b74b5e8
[Fix] Fix lint (#1598)
* [Fix] Fix lint

* [Fix] Fix lint

* Update mmengine/dist/utils.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2024-11-02 22:23:51 +08:00
Mashiro c9b59962d6
Yehc/bump version to v0.10.5 (#1572)
* fix check builtin module

* bump version to v0.10.5

* bump version to v0.10.5
2024-09-20 15:42:06 +08:00
Mashiro 5e736b143b
fix check builtin module (#1571) 2024-09-11 18:45:24 +08:00
Chris Jiang 85c83ba616
Update is_mlu_available (#1537)
* Update is_mlu_available

to adapt torch_mlu main, the torch.is_mlu_available method is removed

* Update utils.py

* Update utils.py
2024-05-30 10:07:45 +08:00
fanqiNO1 d1f1aabf81
[Feature] Support calculating loss during validation (#1503) 2024-05-17 15:27:53 +08:00
fanqiNO1 66fb81f7b3
Bump version to 0.10.4 (#1534) 2024-04-23 11:23:12 +08:00
Zhihao Lin acbc5e46dc
[Fix] Delete frozen parameters when using `paramwise_cfg` (#1441) 2024-04-22 19:54:48 +08:00
Hiram Foster 9ecced821b
Fix a typo (#1532) 2024-04-22 19:51:59 +08:00
Zhihao Lin 39ed23fae8
[Enhance] Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` (#1517) 2024-04-12 14:25:54 +08:00
Zhihao Lin e258c84824
Perform evaluation upon training completion (#1529) 2024-04-08 13:05:36 +08:00
Zaida Zhou 2c4516c622
Add the supported pytorch versions in README (#1512) 2024-03-06 11:00:42 +08:00
Zaida Zhou 447d3bba2c
Fix config of readthedocs (#1511) 2024-03-06 10:30:00 +08:00
David de la Iglesia Castro 2fe0ecec3d
[Feature] Support custom `artifact_location` in MLflowVisBackend (#1505) 2024-02-26 18:37:25 +08:00
jason_w c423d0c1da
Fix docstring of Config (#1506) 2024-02-24 09:46:01 +08:00
Zaida Zhou 9b98405672
Remove codeowners file (#1496) 2024-02-18 17:19:28 +08:00
Evan 4df682ba2d
Fix typos and remove fullwidth unicode chars (#1488) 2024-02-18 15:33:52 +08:00
fanqiNO1 ba5eed8409
[Fix] Fix warning capture (#1494) 2024-02-18 14:17:35 +08:00
Zaida Zhou f79111ecc0
fix typo (#1481) 2024-01-24 19:31:08 +08:00
Zaida Zhou b5f2d5860d
Refine mmengine introduction (#1479) 2024-01-24 19:27:02 +08:00
Zaida Zhou 02f80e8bdd
Bump version to 0.10.3 (#1478) 2024-01-24 12:45:00 +08:00
Zhihao Lin cd298e3086
[Feature] Support save_optimizer=False for DeepSpeed (#1474) 2024-01-24 11:12:54 +08:00
Anm半夏 396cac19cd
Fix a typo in visualizer.py (#1476) 2024-01-23 11:09:05 +08:00
hanhaowen-mt 3d8a611eec
[Feature] Add the support for musa device support (#1453) 2024-01-11 16:25:01 +08:00
Zhihao Lin 109cd44c7e
[Fix] Fix dist.collect_results to keep all ranks' elements (#1469) 2024-01-11 10:50:36 +08:00
Zhihao Lin b51bf60964
[Fix] Fix the resume of iteration (#1471) 2024-01-11 10:47:05 +08:00
Mashiro 4a50213c69
[Fix] Fix Config.to_dict (#1465) 2024-01-02 16:07:54 +08:00
Zaida Zhou e4600a6993
[Docs] Add the usage of ProfilerHook (#1466) 2024-01-02 15:59:37 +08:00
XiwuChen 369f15e27a
[Docs] Fix nnodes in the doc of ddp training (#1462) 2024-01-02 10:42:58 +08:00
fanqiNO1 1398e4200e
bump version to v0.10.2 (#1460) 2023-12-26 16:30:01 +08:00
lanzeshun 8e6fb12b1f
[Fix] Support multi-node distributed training with NPU backend (#1459) 2023-12-26 16:14:45 +08:00
fanqiNO1 671f3bcdf4
[Fix] Fix placement policy in ColossalAIStrategy (#1440) 2023-12-23 16:24:39 +08:00
SCZwangxiao efcd364124
[Fix] Fix load_model_state_dict in BaseStrategy (#1447) 2023-12-23 11:17:46 +08:00
del-zhenwu 504fa4f5cb
[Fix] Use ImportError to cover ModuleNotFoundError raised by opencv-python (#1438) 2023-12-23 11:15:20 +08:00
101 changed files with 1059 additions and 656 deletions

View File

@ -1,6 +1,8 @@
name: deploy
on: push
on:
- push
- workflow_dispatch
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@ -9,13 +11,14 @@ concurrency:
jobs:
build-n-publish:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
if: |
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
- uses: actions/checkout@v4
- name: Set up Python 3.10.13
uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.10.13
- name: Install wheel
run: pip install wheel
- name: Build MMEngine
@ -27,13 +30,14 @@ jobs:
build-n-publish-lite:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
if: |
startsWith(github.event.ref, 'refs/tags') || github.event_name == 'workflow_dispatch'
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
- uses: actions/checkout@v4
- name: Set up Python 3.10.13
uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.10.13
- name: Install wheel
run: pip install wheel
- name: Build MMEngine-lite

View File

@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.10.15
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: '3.10.15'
- name: Install pre-commit hook
run: |
pip install pre-commit

View File

@ -1,5 +1,9 @@
name: pr_stage_test
env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true
on:
pull_request:
paths-ignore:
@ -21,152 +25,114 @@ concurrency:
jobs:
build_cpu:
runs-on: ubuntu-22.04
defaults:
run:
shell: bash -l {0}
strategy:
matrix:
python-version: [3.7]
include:
- torch: 1.8.1
torchvision: 0.9.1
python-version: ['3.9']
torch: ['2.0.0']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
- name: Check out repo
uses: actions/checkout@v3
- name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
miniconda-version: "latest"
use-only-tar-bz2: true
activate-environment: test
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: python -m pip install pip --upgrade
- name: Upgrade wheel
run: python -m pip install wheel --upgrade
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- name: Build MMEngine from source
run: pip install -e . -v
- name: Install unit tests dependencies
- name: Update pip
run: |
pip install -r requirements/tests.txt
pip install openmim
mim install mmcv
- name: Run unittests and generate coverage report
python -m pip install --upgrade pip wheel
- name: Install dependencies
run: |
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
coverage xml
coverage report -m
# Upload coverage report for python3.7 && pytorch1.8.1 cpu
- name: Upload coverage to Codecov
apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
python -m pip install torch==${{matrix.torch}}
python -m pip install -e . -v
python -m pip install -r requirements/tests.txt
python -m pip install openmim
mim install mmcv coverage
- name: Run unit tests with coverage
run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false
build_cu102:
build_gpu:
runs-on: ubuntu-22.04
container:
image: pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel
defaults:
run:
shell: bash -l {0}
env:
MKL_THREADING_LAYER: GNU
strategy:
matrix:
python-version: [3.7]
python-version: ['3.9','3.10']
torch: ['2.0.0','2.3.1','2.5.1']
cuda: ['cu118']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
- name: Check out repo
uses: actions/checkout@v3
- name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
miniconda-version: "latest"
use-only-tar-bz2: true
activate-environment: test
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
- name: Fetch GPG keys
- name: Update pip
run: |
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
- name: Install system dependencies
run: apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
- name: Build MMEngine from source
run: pip install -e . -v
- name: Install unit tests dependencies
python -m pip install --upgrade pip wheel
- name: Install dependencies
run: |
pip install -r requirements/tests.txt
pip install openmim
mim install mmcv
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
coverage xml
coverage report -m
apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6
python -m pip install torch==${{matrix.torch}} --index-url https://download.pytorch.org/whl/${{matrix.cuda}}
python -m pip install -e . -v
python -m pip install -r requirements/tests.txt
python -m pip install openmim
mim install mmcv coverage
- name: Run unit tests with coverage
run: coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3
build_cu117:
runs-on: ubuntu-22.04
container:
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
strategy:
matrix:
python-version: [3.9]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: pip install pip --upgrade
- name: Fetch GPG keys
run: |
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
- name: Install system dependencies
run: apt-get update && apt-get install -y git ffmpeg libturbojpeg
- name: Build MMEngine from source
run: pip install -e . -v
- name: Install unit tests dependencies
run: |
pip install -r requirements/tests.txt
pip install openmim
mim install mmcv
# Distributed related unit test may randomly error in PyTorch 1.13.0
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source mmengine -m pytest tests/ --ignore tests/test_dist/
coverage xml
coverage report -m
build_windows:
runs-on: windows-2022
strategy:
matrix:
python-version: [3.7]
platform: [cpu, cu111]
torch: [1.8.1]
torchvision: [0.9.1]
include:
- python-version: 3.8
platform: cu118
torch: 2.1.0
torchvision: 0.16.0
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
# Windows CI could fail If we call `pip install pip --upgrade` directly.
run: python -m pip install pip --upgrade
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
- name: Build MMEngine from source
run: pip install -e . -v
- name: Install unit tests dependencies
run: |
pip install -r requirements/tests.txt
pip install openmim
mim install mmcv
- name: Run CPU unittests
run: pytest tests/ --ignore tests/test_dist
if: ${{ matrix.platform == 'cpu' }}
- name: Run GPU unittests
# Skip testing distributed related unit tests since the memory of windows CI is limited
run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}
# build_windows:
# runs-on: windows-2022
# strategy:
# matrix:
# python-version: [3.9]
# platform: [cpu, cu111]
# torch: [1.8.1]
# torchvision: [0.9.1]
# include:
# - python-version: 3.8
# platform: cu118
# torch: 2.1.0
# torchvision: 0.16.0
# steps:
# - uses: actions/checkout@v3
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v4
# with:
# python-version: ${{ matrix.python-version }}
# - name: Upgrade pip
# # Windows CI could fail If we call `pip install pip --upgrade` directly.
# run: python -m pip install pip wheel --upgrade
# - name: Install PyTorch
# run: pip install torch==${{matrix.torch}}+${{matrix.platform}} torchvision==${{matrix.torchvision}}+${{matrix.platform}} -f https://download.pytorch.org/whl/${{matrix.platform}}/torch_stable.html
# - name: Build MMEngine from source
# run: pip install -e . -v
# - name: Install unit tests dependencies
# run: |
# pip install -r requirements/tests.txt
# pip install openmim
# mim install mmcv
# - name: Run CPU unittests
# run: pytest tests/ --ignore tests/test_dist
# if: ${{ matrix.platform == 'cpu' }}
# - name: Run GPU unittests
# # Skip testing distributed related unit tests since the memory of windows CI is limited
# run: pytest tests/ --ignore tests/test_dist --ignore tests/test_optim/test_optimizer/test_optimizer_wrapper.py --ignore tests/test_model/test_wrappers/test_model_wrapper.py --ignore tests/test_hooks/test_sync_buffers_hook.py
# if: ${{ matrix.platform == 'cu111' }} || ${{ matrix.platform == 'cu118' }}

View File

@ -1,10 +0,0 @@
assign:
strategy:
# random
daily-shift-based
scedule:
'*/1 * * * *'
assignees:
- zhouzaida
- HAOCHENYE
- C1rN09

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
- repo: https://gitee.com/openmmlab/mirrors-isort
@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
@ -55,7 +59,7 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://gitee.com/openmmlab/mirrors-mypy
rev: v0.812
rev: v1.2.0
hooks:
- id: mypy
exclude: |-
@ -63,3 +67,4 @@ repos:
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.1.1
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
@ -34,12 +38,8 @@ repos:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.3.1
rev: 06907d0
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
@ -55,7 +55,7 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
rev: v1.2.0
hooks:
- id: mypy
exclude: |-
@ -63,3 +63,4 @@ repos:
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]

View File

@ -1,84 +0,0 @@
# IMPORTANT:
# This file is ONLY used to subscribe for notifications for PRs
# related to a specific file path, and each line is a file pattern followed by
# one or more owners.
# Order is important; the last matching pattern takes the most
# precedence.
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# @global-owner1 and @global-owner2 will be requested for
# review when someone opens a pull request.
* @zhouzaida @HAOCHENYE
# Docs
/docs/ @C1rN09
*.rst @zhouzaida @HAOCHENYE
# mmengine file
# config
/mmengine/config/ @HAOCHENYE
# dataset
/mmengine/dataset/ @HAOCHENYE
# device
/mmengine/device/ @zhouzaida
# dist
/mmengine/dist/ @zhouzaida @C1rN09
# evaluator
/mmengine/evaluator/ @RangiLyu @C1rN09
# fileio
/mmengine/fileio/ @zhouzaida
# hooks
/mmengine/hooks/ @zhouzaida @HAOCHENYE
/mmengine/hooks/ema_hook.py @RangiLyu
# hub
/mmengine/hub/ @HAOCHENYE @zhouzaida
# logging
/mmengine/logging/ @HAOCHENYE
# model
/mmengine/model/ @HAOCHENYE @C1rN09
/mmengine/model/averaged_model.py @RangiLyu
/mmengine/model/wrappers/fully_sharded_distributed.py @C1rN09
# optim
/mmengine/optim/ @HAOCHENYE
/mmengine/optim/scheduler/ @RangiLyu
# registry
/mmengine/registry/ @C1rN09 @HAOCHENYE
# runner
/mmengine/runner/ @zhouzaida @RangiLyu @HAOCHENYE
/mmengine/runner/amp.py @HAOCHENYE
/mmengine/runner/log_processor.py @HAOCHENYE
/mmengine/runner/checkpoint.py @zhouzaida @C1rN09
/mmengine/runner/priority.py @zhouzaida
/mmengine/runner/utils.py @zhouzaida @HAOCHENYE
# structure
/mmengine/structures/ @Harold-lkk @HAOCHENYE
# testing
/mmengine/testing/ @zhouzaida
# utils
/mmengine/utils/ @HAOCHENYE @zhouzaida
# visualization
/mmengine/visualization/ @Harold-lkk @HAOCHENYE
# version
/mmengine/__version__.py @zhouzaida
# unit test
/tests/ @zhouzaida @HAOCHENYE

View File

@ -19,13 +19,14 @@
<div>&nbsp;</div>
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mmengine)](https://pypi.org/project/mmengine/)
[![pytorch](https://img.shields.io/badge/pytorch-1.6~2.1-yellow)](#installation)
[![PyPI](https://img.shields.io/pypi/v/mmengine)](https://pypi.org/project/mmengine)
[![license](https://img.shields.io/github/license/open-mmlab/mmengine.svg)](https://github.com/open-mmlab/mmengine/blob/main/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmengine.svg)](https://github.com/open-mmlab/mmengine/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmengine.svg)](https://github.com/open-mmlab/mmengine/issues)
[Introduction](#introduction) |
[Installation](#installation) |
[Get Started](#get-started) |
[📘Documentation](https://mmengine.readthedocs.io/en/latest/) |
[🛠Installation](https://mmengine.readthedocs.io/en/latest/get_started/installation.html) |
[🤔Reporting Issues](https://github.com/open-mmlab/mmengine/issues/new/choose)
</div>
@ -58,58 +59,53 @@ English | [简体中文](README_zh-CN.md)
## What's New
v0.10.1 was released on 2023-11-22.
v0.10.6 was released on 2025-01-13.
Highlights:
- Support installing mmengine-lite with no dependency on opencv. Refer to the [Installation](https://mmengine.readthedocs.io/en/latest/get_started/installation.html#install-mmengine) for more details.
- Support custom `artifact_location` in MLflowVisBackend [#1505](#1505)
- Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` [#1517](#1517)
- Support training with [ColossalAI](https://colossalai.org/). Refer to the [Training Large Models](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#colossalai) for more detailed usages.
- Support gradient checkpointing. Refer to the [Save Memory on GPU](https://mmengine.readthedocs.io/en/latest/common_usage/save_gpu_memory.html#gradient-checkpointing) for more details.
- Supports multiple visualization backends, including `NeptuneVisBackend`, `DVCLiveVisBackend` and `AimVisBackend`. Refer to [Visualization Backends](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html) for more details.
Read [Changelog](./docs/en/notes/changelog.md#v0101-22112023) for more details.
## Table of Contents
- [Introduction](#introduction)
- [Installation](#installation)
- [Get Started](#get-started)
- [Learn More](#learn-more)
- [Contributing](#contributing)
- [Citation](#citation)
- [License](#license)
- [Ecosystem](#ecosystem)
- [Projects in OpenMMLab](#projects-in-openmmlab)
Read [Changelog](./docs/en/notes/changelog.md#v0104-2342024) for more details.
## Introduction
MMEngine is a foundational library for training deep learning models based on PyTorch. It provides a solid engineering foundation and frees developers from writing redundant codes on workflows. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. Moreover, MMEngine is also generic to be applied to non-OpenMMLab projects.
MMEngine is a foundational library for training deep learning models based on PyTorch. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. Moreover, MMEngine is also generic to be applied to non-OpenMMLab projects. Its highlights are as follows:
Major features:
**Integrate mainstream large-scale model training frameworks**
1. **A universal and powerful runner**:
- [ColossalAI](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#colossalai)
- [DeepSpeed](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#deepspeed)
- [FSDP](https://mmengine.readthedocs.io/en/latest/common_usage/large_model_training.html#fullyshardeddataparallel-fsdp)
- Supports training different tasks with a small amount of code, e.g., ImageNet can be trained with only 80 lines of code (400 lines of the original PyTorch example).
- Easily compatible with models from popular algorithm libraries such as TIMM, TorchVision, and Detectron2.
**Supports a variety of training strategies**
2. **Open architecture with unified interfaces**:
- [Mixed Precision Training](https://mmengine.readthedocs.io/en/latest/common_usage/speed_up_training.html#mixed-precision-training)
- [Gradient Accumulation](https://mmengine.readthedocs.io/en/latest/common_usage/save_gpu_memory.html#gradient-accumulation)
- [Gradient Checkpointing](https://mmengine.readthedocs.io/en/latest/common_usage/save_gpu_memory.html#gradient-checkpointing)
- Handles different algorithm tasks with unified APIs, e.g., implement a method and apply it to all compatible models.
- Provides a unified abstraction for upper-level algorithm libraries, which supports various back-end devices such as Nvidia CUDA, Mac MPS, AMD, MLU, and more for model training.
**Provides a user-friendly configuration system**
3. **Customizable training process**:
- [Pure Python-style configuration files, easy to navigate](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta)
- [Plain-text-style configuration files, supporting JSON and YAML](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html)
- Defines the training process just like playing with Legos.
- Provides rich components and strategies.
- Complete controls on the training process with different levels of APIs.
**Covers mainstream training monitoring platforms**
![mmengine_dataflow](https://github.com/open-mmlab/mmengine/assets/58739961/267db9cb-72e4-4af2-a58b-877b30091acc)
- [TensorBoard](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#tensorboard) | [WandB](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#wandb) | [MLflow](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#mlflow-wip)
- [ClearML](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#clearml) | [Neptune](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#neptune) | [DVCLive](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#dvclive) | [Aim](https://mmengine.readthedocs.io/en/latest/common_usage/visualize_training_log.html#aim)
## Installation
<details>
<summary>Supported PyTorch Versions</summary>
| MMEngine | PyTorch | Python |
| ------------------ | ------------ | -------------- |
| main | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
| >=0.9.0, \<=0.10.4 | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
</details>
Before installing MMEngine, please ensure that PyTorch has been successfully installed following the [official guide](https://pytorch.org/get-started/locally/).
Install MMEngine

View File

@ -19,13 +19,14 @@
<div>&nbsp;</div>
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mmengine)](https://pypi.org/project/mmengine/)
[![pytorch](https://img.shields.io/badge/pytorch-1.6~2.1-yellow)](#安装)
[![PyPI](https://img.shields.io/pypi/v/mmengine)](https://pypi.org/project/mmengine)
[![license](https://img.shields.io/github/license/open-mmlab/mmengine.svg)](https://github.com/open-mmlab/mmengine/blob/main/LICENSE)
[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmengine.svg)](https://github.com/open-mmlab/mmengine/issues)
[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmengine.svg)](https://github.com/open-mmlab/mmengine/issues)
[📘使用文档](https://mmengine.readthedocs.io/zh_CN/latest/) |
[🛠️安装教程](https://mmengine.readthedocs.io/zh_CN/latest/get_started/installation.html) |
[简介](#简介) |
[安装](#安装) |
[快速上手](#快速上手) |
[📘用户文档](https://mmengine.readthedocs.io/zh_CN/latest/) |
[🤔报告问题](https://github.com/open-mmlab/mmengine/issues/new/choose)
</div>
@ -58,59 +59,58 @@
## 最近进展
最新版本 v0.10.1 在 2023.11.22 发布。
最新版本 v0.10.5 在 2024.9.11 发布。
亮点:
版本亮点:
- 支持安装不依赖于 opencv 的 mmengine-lite 版本。可阅读[安装文档](https://mmengine.readthedocs.io/zh-cn/latest/get_started/installation.html#mmengine)了解用法。
- 支持在 MLFlowVisBackend 中自定义 `artifact_location` [#1505](#1505)
- 支持在 `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` 使用 `exclude_frozen_parameters` [#1517](#1517)
- 支持使用 [ColossalAI](https://colossalai.org/) 进行训练。可阅读[大模型训练](https://mmengine.readthedocs.io/zh_CN/latest/common_usage/large_model_training.html#colossalai)了解用法。
- 支持梯度检查点。详见[用法](https://mmengine.readthedocs.io/zh_CN/latest/common_usage/save_gpu_memory.html#id3)。
- 支持多种可视化后端,包括`NeptuneVisBackend`、`DVCLiveVisBackend` 和 `AimVisBackend`。可阅读[可视化后端](https://mmengine.readthedocs.io/zh_CN/latest/common_usage/visualize_training_log.html)了解用法。
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](./docs/en/notes/changelog.md#v0101-22112023)
## 目录
- [简介](#简介)
- [安装](#安装)
- [快速上手](#快速上手)
- [了解更多](#了解更多)
- [贡献指南](#贡献指南)
- [引用](#引用)
- [开源许可证](#开源许可证)
- [生态项目](#生态项目)
- [OpenMMLab 的其他项目](#openmmlab-的其他项目)
- [欢迎加入 OpenMMLab 社区](#欢迎加入-openmmlab-社区)
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](./docs/en/notes/changelog.md#v0104-2342024)。
## 简介
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库。它为开发人员提供了坚实的工程基础,以此避免在工作流上编写冗余代码。作为 OpenMMLab 所有代码库的训练引擎其在不同研究领域支持了上百个算法。此外MMEngine 也可以用于非 OpenMMLab 项目中。
MMEngine 是一个基于 PyTorch 实现的,用于训练深度学习模型的基础库。它作为 OpenMMLab 所有代码库的训练引擎其在不同研究领域支持了上百个算法。此外MMEngine 也可以用于非 OpenMMLab 项目中。它的亮点如下:
主要特性:
**集成主流的大模型训练框架**
1. **通用且强大的执行器**
- [ColossalAI](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/large_model_training.html#colossalai)
- [DeepSpeed](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/large_model_training.html#deepspeed)
- [FSDP](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/large_model_training.html#fullyshardeddataparallel-fsdp)
- 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 ImageNet原始 PyTorch 示例需要 400 行)。
- 轻松兼容流行的算法库(如 TIMM、TorchVision 和 Detectron2中的模型。
**支持丰富的训练策略**
2. **接口统一的开放架构**
- [混合精度训练Mixed Precision Training](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/speed_up_training.html#id3)
- [梯度累积Gradient Accumulation](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/save_gpu_memory.html#id2)
- [梯度检查点Gradient Checkpointing](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/save_gpu_memory.html#id3)
- 使用统一的接口处理不同的算法任务,例如,实现一个方法并应用于所有的兼容性模型。
- 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。
**提供易用的配置系统**
3. **可定制的训练流程**
- [纯 Python 风格的配置文件,易于跳转](https://mmengine.readthedocs.io/zh-cn/latest/advanced_tutorials/config.html#python-beta)
- [纯文本风格的配置文件,支持 JSON 和 YAML](https://mmengine.readthedocs.io/zh-cn/latest/advanced_tutorials/config.html#id1)
- 定义了“乐高”式的训练流程。
- 提供了丰富的组件和策略。
- 使用不同等级的 API 控制训练过程。
**覆盖主流的训练监测平台**
![mmengine_dataflow](https://github.com/open-mmlab/mmengine/assets/58739961/267db9cb-72e4-4af2-a58b-877b30091acc)
- [TensorBoard](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#tensorboard) | [WandB](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#wandb) | [MLflow](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#mlflow-wip)
- [ClearML](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#clearml) | [Neptune](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#neptune) | [DVCLive](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#dvclive) | [Aim](https://mmengine.readthedocs.io/zh-cn/latest/common_usage/visualize_training_log.html#aim)
**兼容主流的训练芯片**
- 英伟达 CUDA | 苹果 MPS
- 华为 Ascend | 寒武纪 MLU | 摩尔线程 MUSA
## 安装
<details>
<summary>支持的 PyTorch 版本</summary>
| MMEngine | PyTorch | Python |
| ------------------ | ------------ | -------------- |
| main | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
| >=0.9.0, \<=0.10.4 | >=1.6 \<=2.1 | >=3.8, \<=3.11 |
</details>
在安装 MMEngine 之前,请确保 PyTorch 已成功安装在环境中,可以参考 [PyTorch 官方安装文档](https://pytorch.org/get-started/locally/)。
安装 MMEngine

View File

@ -3,6 +3,9 @@ version: 2
formats:
- epub
sphinx:
configuration: docs/en/conf.py
# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04

View File

@ -1008,7 +1008,7 @@ In this section, we use MMDetection to demonstrate how to migrate the abstract d
### 1. Simplify the module interface
Detector's external interfaces can be significantly simplified and unified. In the training process of a single-stage detection and segmentation algorithm in MMDet 2.X, `SingleStageDetector` requires `img`, `img_metas`, `gt_bboxes` `gt_labels` and `gt_bboxes_ignore` as the inputs, but `SingleStageInstanceSegmentor` requires `gt_masks` as well. This causes inconsistency in the training interface and affects flexibility.
Detector's external interfaces can be significantly simplified and unified. In the training process of a single-stage detection and segmentation algorithm in MMDet 2.X, `SingleStageDetector` requires `img`, `img_metas`, `gt_bboxes`, `gt_labels` and `gt_bboxes_ignore` as the inputs, but `SingleStageInstanceSegmentor` requires `gt_masks` as well. This causes inconsistency in the training interface and affects flexibility.
```python
class SingleStageDetector(BaseDetector):

View File

@ -4,7 +4,7 @@ This document provides some third-party optimizers supported by MMEngine, which
## D-Adaptation
[D-Adaptation](https://github.com/facebookresearch/dadaptation) provides `DAdaptAdaGrad`, `DAdaptAdam` and `DAdaptSGD` optimziers。
[D-Adaptation](https://github.com/facebookresearch/dadaptation) provides `DAdaptAdaGrad`, `DAdaptAdam` and `DAdaptSGD` optimizers.
```{note}
If you use the optimizer provided by D-Adaptation, you need to upgrade mmengine to `0.6.0`.
@ -35,7 +35,7 @@ runner.train()
## Lion-Pytorch
[lion-pytorch](https://github.com/lucidrains/lion-pytorch) provides the `Lion` optimizer
[lion-pytorch](https://github.com/lucidrains/lion-pytorch) provides the `Lion` optimizer.
```{note}
If you use the optimizer provided by Lion-Pytorch, you need to upgrade mmengine to `0.6.0`.
@ -93,7 +93,7 @@ runner.train()
## bitsandbytes
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) provides `AdamW8bit`, `Adam8bit`, `Adagrad8bit`, `PagedAdam8bit`, `PagedAdamW8bit`, `LAMB8bit`, `LARS8bit`, `RMSprop8bit`, `Lion8bit`, `PagedLion8bit` and `SGD8bit` optimziers。
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) provides `AdamW8bit`, `Adam8bit`, `Adagrad8bit`, `PagedAdam8bit`, `PagedAdamW8bit`, `LAMB8bit`, `LARS8bit`, `RMSprop8bit`, `Lion8bit`, `PagedLion8bit` and `SGD8bit` optimizers.
```{note}
If you use the optimizer provided by bitsandbytes, you need to upgrade mmengine to `0.9.0`.
@ -124,7 +124,7 @@ runner.train()
## transformers
[transformers](https://github.com/huggingface/transformers) provides `Adafactor` optimzier
[transformers](https://github.com/huggingface/transformers) provides `Adafactor` optimzier.
```{note}
If you use the optimizer provided by transformers, you need to upgrade mmengine to `0.9.0`.

View File

@ -30,7 +30,7 @@ train_dataloader = dict(
type=dataset_type,
data_prefix='data/cifar10',
test_mode=False,
indices=5000, # set indices=5000，represent every epoch only iterator 5000 samples
indices=5000, # set indices=5000, represent every epoch only iterator 5000 samples
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)

View File

@ -26,7 +26,7 @@ On the first machine:
```bash
python -m torch.distributed.launch \
--nnodes 8 \
--nnodes 2 \
--node_rank 0 \
--master_addr 127.0.0.1 \
--master_port 29500 \
@ -38,9 +38,9 @@ On the second machine:
```bash
python -m torch.distributed.launch \
--nnodes 8 \
--nnodes 2 \
--node_rank 1 \
--master_addr 127.0.0.1 \
--master_addr "ip_of_the_first_machine" \
--master_port 29500 \
--nproc_per_node=8 \
examples/distributed_training.py --launcher pytorch

View File

@ -87,10 +87,10 @@ OpenMMLab requires the `inferencer(img)` to output a `dict` containing two field
When performing inference, the following steps are typically executed:
1. preprocess：Input data preprocessing, including data reading, data preprocessing, data format conversion, etc.
1. preprocess: Input data preprocessing, including data reading, data preprocessing, data format conversion, etc.
2. forward: Execute `model.forwward`
3. visualize：Visualization of predicted results.
4. postprocess：Post-processing of predicted results, including result format conversion, exporting predicted results, etc.
3. visualize: Visualization of predicted results.
4. postprocess: Post-processing of predicted results, including result format conversion, exporting predicted results, etc.
To improve the user experience of the inferencer, we do not want users to have to configure parameters for each step when performing inference. In other words, we hope that users can simply configure parameters for the `__call__` interface without being aware of the above process and complete the inference.
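To make the four-step flow above concrete, here is a schematic, self-contained sketch — this is not mmengine's actual `BaseInferencer`, the stand-in bodies are purely illustrative; only the step names and the way `__call__` chains them follow the description above:

```python
class SketchInferencer:
    """Toy stand-in showing how __call__ chains the four steps."""

    def preprocess(self, inputs, batch_size=1):
        # Step 1: yield (trivially "preprocessed") batches of the input data.
        for i in range(0, len(inputs), batch_size):
            yield inputs[i:i + batch_size]

    def forward(self, batch):
        # Step 2: stand-in for model.forward, one prediction per input.
        return [{'score': len(str(item))} for item in batch]

    def visualize(self, inputs, preds, show=False):
        # Step 3: stand-in visualization, one rendered result per input.
        return [f'vis({inp})' for inp in inputs] if show else None

    def postprocess(self, preds, visualization):
        # Step 4: OpenMMLab convention, a dict with `predictions` and `visualization`.
        return {'predictions': preds, 'visualization': visualization}

    def __call__(self, inputs, batch_size=1, show=False):
        if not isinstance(inputs, list):
            inputs = [inputs]  # a single input is wrapped into a list
        preds = []
        for batch in self.preprocess(inputs, batch_size=batch_size):
            preds.extend(self.forward(batch))
        visualization = self.visualize(inputs, preds, show=show)
        return self.postprocess(preds, visualization)


print(SketchInferencer()(['demo.jpg'], show=True))
```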
@ -173,8 +173,8 @@ Initializes and returns the `visualizer` required by the inferencer, which is eq
Input arguments:
- inputs：Input data, passed into `__call__`, usually a list of image paths or image data.
- batch_size：batch size, passed in by the user when calling `__call__`.
- inputs: Input data, passed into `__call__`, usually a list of image paths or image data.
- batch_size: batch size, passed in by the user when calling `__call__`.
- Other parameters: Passed in by the user and specified in `preprocess_kwargs`.
Return:
@ -187,7 +187,7 @@ The `preprocess` function is a generator function by default, which applies the
Input arguments:
- inputs：The batch data processed by `preprocess` function.
- inputs: The batch data processed by `preprocess` function.
- Other parameters: Passed in by the user and specified in `forward_kwargs`.
Return:
@ -204,9 +204,9 @@ This is an abstract method that must be implemented by the subclass.
Input arguments:
- inputs：The input data, which is the raw data without preprocessing.
- preds：Predicted results of the model.
- show：Whether to visualize.
- inputs: The input data, which is the raw data without preprocessing.
- preds: Predicted results of the model.
- show: Whether to visualize.
- Other parameters: Passed in by the user and specified in `visualize_kwargs`.
Return:
@ -221,12 +221,12 @@ This is an abstract method that must be implemented by the subclass.
Input arguments:
- preds：The predicted results of the model, which is a `list` type. Each element in the list represents the prediction result for a single data item. In the OpenMMLab series of algorithm libraries, the type of each element in the prediction result is `BaseDataElement`.
- visualization：Visualization results
- preds: The predicted results of the model, which is a `list` type. Each element in the list represents the prediction result for a single data item. In the OpenMMLab series of algorithm libraries, the type of each element in the prediction result is `BaseDataElement`.
- visualization: Visualization results
- return_datasample: Whether to maintain datasample for return. When set to `False`, the returned result is converted to a `dict`.
- Other parameters: Passed in by the user and specified in `postprocess_kwargs`.
Return
Return:
- The type of the returned value is a dictionary containing both the visualization and prediction results. OpenMMLab requires the returned dictionary to have two keys: `predictions` and `visualization`.
@ -234,9 +234,9 @@ Return
Input arguments:
- inputs：The input data, usually a list of image paths or image data. Each element in `inputs` can also be other types of data as long as it can be processed by the `pipeline` returned by [init_pipeline](#_init_pipeline). When there is only one inference data in `inputs`, it does not have to be a `list`, `__call__` will internally wrap it into a list for further processing.
- inputs: The input data, usually a list of image paths or image data. Each element in `inputs` can also be other types of data as long as it can be processed by the `pipeline` returned by [init_pipeline](#_init_pipeline). When there is only one inference data in `inputs`, it does not have to be a `list`, `__call__` will internally wrap it into a list for further processing.
- return_datasample: Whether to convert datasample to dict for return.
- batch_size：Batch size for inference, which will be further passed to the `preprocess` function.
- batch_size: Batch size for inference, which will be further passed to the `preprocess` function.
- Other parameters: Additional parameters assigned to `preprocess`, `forward`, `visualize`, and `postprocess` methods.
Return:

View File

@ -74,11 +74,11 @@ history_buffer.min()
# 1, the global minimum
history_buffer.max(2)
# 3，the maximum in [2, 3]
# 3, the maximum in [2, 3]
history_buffer.min()
# 3, the global maximum
history_buffer.mean(2)
# 2.5，the mean value in [2, 3], (2 + 3) / (1 + 1)
# 2.5, the mean value in [2, 3], (2 + 3) / (1 + 1)
history_buffer.mean()
# 2, the global mean, (1 + 2 + 3) / (1 + 1 + 1)
history_buffer = HistoryBuffer([1, 2, 3], [2, 2, 2]) # Cases when counts are not 1
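In case the incremental side of the API is clearer with a runnable snippet, here is a small sketch assuming `mmengine.logging.HistoryBuffer`'s `update(value, count)` interface; the statistics shown above take an optional window size counted from the most recent entries:

```python
from mmengine.logging import HistoryBuffer

buf = HistoryBuffer()
for loss in (1.0, 2.0, 3.0):
    buf.update(loss, count=1)  # append one logged value at a time

print(buf.current())  # 3.0, the most recently logged value
print(buf.mean(2))    # 2.5, mean over the two most recent values
print(buf.min())      # 1.0, the global minimum
```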
@ -431,7 +431,7 @@ In the case of multiple processes in multiple nodes without storage, logs are or
```text
# without shared storage
# node 0
# node 0:
work_dir/20230228_141908
├── 20230306_183634_${hostname}_device0_rank0.log
├── 20230306_183634_${hostname}_device1_rank1.log
@ -442,7 +442,7 @@ work_dir/20230228_141908
├── 20230306_183634_${hostname}_device6_rank6.log
├── 20230306_183634_${hostname}_device7_rank7.log
# node 7
# node 7:
work_dir/20230228_141908
├── 20230306_183634_${hostname}_device0_rank56.log
├── 20230306_183634_${hostname}_device1_rank57.log

View File

@ -1,6 +1,6 @@
# 15 minutes to get started with MMEngine
In this tutorial, we'll take training a ResNet-50 model on CIFAR-10 dataset as an example. We will build a complete and configurable pipeline for both training and validation in only 80 lines of code with `MMEgnine`.
In this tutorial, we'll take training a ResNet-50 model on CIFAR-10 dataset as an example. We will build a complete and configurable pipeline for both training and validation in only 80 lines of code with `MMEngine`.
The whole process includes the following steps:
- [15 minutes to get started with MMEngine](#15-minutes-to-get-started-with-mmengine)

View File

@ -1,30 +1,29 @@
# Introduction
MMEngine is a foundational library for training deep learning models based on
PyTorch. It supports running on Linux, Windows, and macOS. It has the
following three features:
PyTorch. It supports running on Linux, Windows, and macOS. Its highlights are as follows:
1. **Universal and powerful executor**:
**Integrate mainstream large-scale model training frameworks**
- Supports training different tasks with minimal code, such as training
ImageNet with just 80 lines of code (original PyTorch examples require
400 lines).
- Easily compatible with models from popular algorithm libraries like TIMM,
TorchVision, and Detectron2.
- [ColossalAI](../common_usage/large_model_training.md#colossalai)
- [DeepSpeed](../common_usage/large_model_training.md#deepspeed)
- [FSDP](../common_usage/large_model_training.md#fullyshardeddataparallel-fsdp)
2. **Open architecture with unified interfaces**:
**Supports a variety of training strategies**
- Handles different tasks with a unified API: you can implement a method
once and apply it to all compatible models.
- Supports various backend devices through a simple, high-level
abstraction. Currently, MMEngine supports model training on Nvidia CUDA,
Mac MPS, AMD, MLU, and other devices.
- [Mixed Precision Training](../common_usage/speed_up_training.md#mixed-precision-training)
- [Gradient Accumulation](../common_usage/save_gpu_memory.md#gradient-accumulation)
- [Gradient Checkpointing](../common_usage/save_gpu_memory.md#gradient-checkpointing)
3. **Customizable training process**:
**Provides a user-friendly configuration system**
- Defines a highly modular training engine with "Lego"-like composability.
- Offers a rich set of components and strategies.
- Total control over the training process with different levels of APIs.
- [Pure Python-style configuration files, easy to navigate](../advanced_tutorials/config.md#a-pure-python-style-configuration-file-beta)
- [Plain-text-style configuration files, supporting JSON and YAML](../advanced_tutorials/config.html)
**Covers mainstream training monitoring platforms**
- [TensorBoard](../common_usage/visualize_training_log.md#tensorboard) | [WandB](../common_usage/visualize_training_log.md#wandb) | [MLflow](../common_usage/visualize_training_log.md#mlflow-wip)
- [ClearML](../common_usage/visualize_training_log.md#clearml) | [Neptune](../common_usage/visualize_training_log.md#neptune) | [DVCLive](../common_usage/visualize_training_log.md#dvclive) | [Aim](../common_usage/visualize_training_log.md#aim)
## Architecture

View File

@ -156,7 +156,7 @@ This tutorial compares the difference in function, mount point, usage and implem
<tr>
<td>after each iteration</td>
<td>after_train_iter</td>
<td>after_train_iter, with additional args: batch_idx、data_batch, and outputs</td>
<td>after_train_iter, with additional args: batch_idx, data_batch, and outputs</td>
</tr>
<tr>
<td rowspan="6">Validation related</td>
@ -187,7 +187,7 @@ This tutorial compares the difference in function, mount point, usage and implem
<tr>
<td>after each iteration</td>
<td>after_val_iter</td>
<td>after_val_iter, with additional args: batch_idx、data_batch and outputs</td>
<td>after_val_iter, with additional args: batch_idx, data_batch and outputs</td>
</tr>
<tr>
<td rowspan="6">Test related</td>
@ -218,7 +218,7 @@ This tutorial compares the difference in function, mount point, usage and implem
<tr>
<td>after each iteration</td>
<td>None</td>
<td>after_test_iter, with additional args: batch_idx、data_batch and outputs</td>
<td>after_test_iter, with additional args: batch_idx, data_batch and outputs</td>
</tr>
</tbody>
</table>

View File

@ -393,7 +393,7 @@ MMCV will wrap the model with distributed wrapper before building the runner, wh
cfg = dict(model_wrapper_cfg='MMSeparateDistributedDataParallel')
runner = Runner(
model=model,
..., # 其他配置
...,
launcher='pytorch',
cfg=cfg)
```

View File

@ -1,5 +1,68 @@
# Changelog of v0.x
## v0.10.5 (11/9/2024)
- Fix `_is_builtin_module`. by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1571
## v0.10.4 (23/4/2024)
### New Features & Enhancements
- Support custom `artifact_location` in MLflowVisBackend. by [@daavoo](https://github.com/daavoo) in https://github.com/open-mmlab/mmengine/pull/1505
- Add the supported pytorch versions in README by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1512
- Perform evaluation upon training completion by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1529
- Enable `exclude_frozen_parameters` for `DeepSpeedEngine._zero3_consolidated_16bit_state_dict` by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1517
### Bug Fixes
- Fix warning capture by [@fanqiNO1](https://github.com/fanqiNO1) in https://github.com/open-mmlab/mmengine/pull/1494
- Remove codeowners file by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1496
- Fix config of readthedocs by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1511
- Delete frozen parameters when using `paramwise_cfg` by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1441
### Docs
- Refine mmengine intro by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1479
- Fix typo by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1481
- Fix typos and remove fullwidth unicode chars by [@evdcush](https://github.com/evdcush) in https://github.com/open-mmlab/mmengine/pull/1488
- Fix docstring of Config by [@MambaWong](https://github.com/MambaWong) in https://github.com/open-mmlab/mmengine/pull/1506
- Fix typo by [@hiramf](https://github.com/hiramf) in https://github.com/open-mmlab/mmengine/pull/1532
## v0.10.3 (24/1/2024)
### New Features & Enhancements
- Add the support for musa device support by [@hanhaowen-mt](https://github.com/hanhaowen-mt) in https://github.com/open-mmlab/mmengine/pull/1453
- Support `save_optimizer=False` for DeepSpeed by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1474
- Update visualizer.py by [@Anm-pinellia](https://github.com/Anm-pinellia) in https://github.com/open-mmlab/mmengine/pull/1476
### Bug Fixes
- Fix `Config.to_dict` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/1465
- Fix the resume of iteration by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1471
- Fix `dist.collect_results` to keep all ranks' elements by [@LZHgrla](https://github.com/LZHgrla) in https://github.com/open-mmlab/mmengine/pull/1469
### Docs
- Add the usage of ProfilerHook by [@zhouzaida](https://github.com/zhouzaida) in https://github.com/open-mmlab/mmengine/pull/1466
- Fix the nnodes in the doc of ddp training by [@XiwuChen](https://github.com/XiwuChen) in https://github.com/open-mmlab/mmengine/pull/1462
## v0.10.2 (26/12/2023)
### New Features & Enhancements
- Support multi-node distributed training with NPU backend by [@shun001](https://github.com/shun001) in https://github.com/open-mmlab/mmengine/pull/1459
- Use `ImportError` to cover `ModuleNotFoundError` by [@del-zhenwu](https://github.com/del-zhenwu) in https://github.com/open-mmlab/mmengine/pull/1438
### Bug Fixes
- Fix bug in `load_model_state_dict` of `BaseStrategy` by [@SCZwangxiao](https://github.com/SCZwangxiao) in https://github.com/open-mmlab/mmengine/pull/1447
- Fix placement policy in ColossalAIStrategy by [@fanqiNO1](https://github.com/fanqiNO1) in https://github.com/open-mmlab/mmengine/pull/1440
### Contributors
A total of 4 developers contributed to this release. Thanks [@shun001](https://github.com/shun001), [@del-zhenwu](https://github.com/del-zhenwu), [@SCZwangxiao](https://github.com/SCZwangxiao), [@fanqiNO1](https://github.com/fanqiNO1)
## v0.10.1 (22/11/2023)
### Bug Fixes
@ -652,7 +715,7 @@ A total of 16 developers contributed to this release. Thanks [@BayMaxBHL](https:
### Bug Fixes
- Fix error calculation of `eta_min` in `CosineRestartParamScheduler` by [@Z-Fran](https://github.com/Z-Fran) in https://github.com/open-mmlab/mmengine/pull/639
- Fix `BaseDataPreprocessor.cast_data` could not handle string data by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/602
- Fix `BaseDataPreprocessor.cast_data` could not handle string data by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/602
- Make `autocast` compatible with mps by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/587
- Fix error format of log message by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/508
- Fix error implementation of `is_model_wrapper` by [@HAOCHENYE](https://github.com/HAOCHENYE) in https://github.com/open-mmlab/mmengine/pull/640

View File

@ -31,11 +31,12 @@ Each hook has a corresponding priority. At each mount point, hooks with higher p
**custom hooks**
| Name | Function | Priority |
| :---------------------------------: | :----------------------------------------------------------------------: | :---------: |
| [EMAHook](#emahook) | apply Exponential Moving Average (EMA) on the model during training | NORMAL (50) |
| [EmptyCacheHook](#emptycachehook) | Releases all unoccupied cached GPU memory during the process of training | NORMAL (50) |
| [SyncBuffersHook](#syncbuffershook) | Synchronize model buffers at the end of each epoch | NORMAL (50) |
| Name | Function | Priority |
| :---------------------------------: | :----------------------------------------------------------------------: | :-----------: |
| [EMAHook](#emahook) | Apply Exponential Moving Average (EMA) on the model during training | NORMAL (50) |
| [EmptyCacheHook](#emptycachehook) | Releases all unoccupied cached GPU memory during the process of training | NORMAL (50) |
| [SyncBuffersHook](#syncbuffershook) | Synchronize model buffers at the end of each epoch | NORMAL (50) |
| [ProfilerHook](#profilerhook) | Analyze the execution time and GPU memory usage of model operators | VERY_LOW (90) |
```{note}
It is not recommended to modify the priority of the default hooks, as hooks with lower priority may depend on hooks with higher priority. For example, `CheckpointHook` needs to have a lower priority than ParamSchedulerHook so that the saved optimizer state is correct. Also, the priority of custom hooks defaults to `NORMAL (50)`.
@ -211,6 +212,20 @@ runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```
### ProfilerHook
The [ProfilerHook](mmengine.hooks.ProfilerHook) is used to analyze the execution time and GPU memory occupancy of model operators.
```python
custom_hooks = [dict(type='ProfilerHook', on_trace_ready=dict(type='tb_trace'))]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```
The profiling results will be saved in the tf_tracing_logs directory under `work_dirs/{timestamp}`, and can be visualized using TensorBoard with the command `tensorboard --logdir work_dirs/{timestamp}/tf_tracing_logs`.
For more information on the usage of the ProfilerHook, please refer to the [ProfilerHook](mmengine.hooks.ProfilerHook) documentation.
## Customize Your Hooks
If the built-in hooks provided by MMEngine do not cover your demands, you are encouraged to customize your own hooks by simply inheriting the base [hook](mmengine.hooks.Hook) class and overriding the corresponding mount point methods.
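For instance, a minimal custom hook might look like the sketch below; the class name and logging logic are illustrative, and the hook is registered so it can be referenced by `type` in `custom_hooks`:

```python
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class EchoOutputsHook(Hook):
    """Toy hook that logs the model outputs after every training iteration."""

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        if outputs is not None:
            runner.logger.info(f'iter {batch_idx}: {outputs}')
```

It would then be enabled like the built-in custom hooks, e.g. `custom_hooks = [dict(type='EchoOutputsHook')]`.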

View File

@ -43,7 +43,7 @@ Usually, we should define a model to implement the body of the algorithm. In MME
Benefits from the `BaseModel`, we only need to make the model inherit from `BaseModel`, and implement the `forward` function to perform the training, testing, and validation process.
```{note}
BaseModel inherits from [BaseModule](../advanced_tutorials/initialize.md)，which can be used to initialize the model parameters dynamically.
BaseModel inherits from [BaseModule](../advanced_tutorials/initialize.md), which can be used to initialize the model parameters dynamically.
```
[**forward**](mmengine.model.BaseModel.forward): The arguments of `forward` need to match with the data given by [DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). If the DataLoader samples a tuple `data`, `forward` needs to accept the value of unpacked `*data`. If DataLoader returns a dict `data`, `forward` needs to accept the key-value of unpacked `**data`. `forward` also accepts `mode` parameter, which is used to control the running branch:
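A minimal sketch of the three branches follows, adapted from the ResNet-50/CIFAR-10 example used elsewhere in these docs (names and values are illustrative); the forward arguments match the keys produced by the DataLoader, and `mode` selects the branch:

```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode='tensor'):
        x = self.resnet(imgs)
        if mode == 'loss':
            # Training branch: return a dict of losses to be parsed and backpropagated.
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            # Validation/test branch: return predictions consumed by the evaluator.
            return x, labels
        # 'tensor' branch: return the raw network outputs.
        return x
```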

View File

@ -40,7 +40,7 @@ for epoch in range(10):
`mmengine.optim.scheduler` supports most of PyTorch's learning rate schedulers such as `ExponentialLR`, `LinearLR`, `StepLR`, `MultiStepLR`, etc. Please refer to [parameter scheduler API documentation](https://mmengine.readthedocs.io/en/latest/api/optim.html#scheduler) for all of the supported schedulers.
MMEngine also supports adjusting momentum with parameter schedulers. To use momentum schedulers, replace `LR` in the class name to `Momentum`, such as `ExponentialMomentum`、`LinearMomentum`. Further, we implement the general parameter scheduler ParamScheduler, which is used to adjust the specified hyperparameters in the optimizer, such as weight_decay, etc. This feature makes it easier to apply some complex hyperparameter tuning strategies.
MMEngine also supports adjusting momentum with parameter schedulers. To use momentum schedulers, replace `LR` in the class name to `Momentum`, such as `ExponentialMomentum`, `LinearMomentum`. Further, we implement the general parameter scheduler ParamScheduler, which is used to adjust the specified hyperparameters in the optimizer, such as weight_decay, etc. This feature makes it easier to apply some complex hyperparameter tuning strategies.
Different from the above example, MMEngine usually does not need to manually implement the training loop and call `optimizer.step()`. The runner will automatically manage the training progress and control the execution of the parameter scheduler through `ParamSchedulerHook`.
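A short sketch of that config-driven style follows, assuming the usual Runner configuration format (values are illustrative): the schedulers are declared as dicts, built by the runner, and stepped automatically through `ParamSchedulerHook`, so no manual `scheduler.step()` call is needed.

```python
param_scheduler = [
    # Linear warm-up of the learning rate over the first 500 iterations.
    dict(type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    # Then decay the learning rate by 10x at epochs 8 and 11.
    dict(type='MultiStepLR', by_epoch=True, milestones=[8, 11], gamma=0.1),
]
# Passed to the runner as Runner(param_scheduler=param_scheduler, ...).
```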

View File

@ -20,7 +20,7 @@ Pros and cons lie in both approaches. For the former one, beginners may be lost
We argue that the key to learning runner is using it as a memo. You should remember its most commonly used arguments and only focus on those less used when in need, since default values usually work fine. In the following, we will provide a beginner-friendly example to illustrate the most commonly used arguments of the runner, along with advanced guidelines for those less used.
### A beginer-friendly example
### A beginner-friendly example
```{hint}
In this tutorial, we hope you can focus more on overall architecture instead of implementation details. This "top-down" way of thinking is exactly what we advocate. Don't worry, you will definitely have plenty of opportunities and guidance afterward to focus on modules you want to improve.

View File

@ -0,0 +1,19 @@
version: 2
formats:
- epub
sphinx:
configuration: docs/zh_cn/conf.py
# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.10"
python:
install:
- requirements: requirements/runtime.txt
- requirements: requirements/docs.txt
- requirements: requirements/docs_extra.txt

View File

@ -26,7 +26,7 @@ CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 e
```bash
python -m torch.distributed.launch \
--nnodes 8 \
--nnodes 2 \
--node_rank 0 \
--master_addr 127.0.0.1 \
--master_port 29500 \
@ -38,9 +38,9 @@ python -m torch.distributed.launch \
```bash
python -m torch.distributed.launch \
--nnodes 8 \
--nnodes 2 \
--node_rank 1 \
--master_addr 127.0.0.1 \
--master_addr "ip_of_the_first_machine" \
--master_port 29500 \
--nproc_per_node=8 \
examples/distributed_training.py --launcher pytorch

View File

@ -1,22 +1,33 @@
# Introduction
MMEngine is a foundational library for training deep learning models based on PyTorch, and it runs on Linux, Windows, and macOS. It has the following three features:
MMEngine is a foundational library for training deep learning models based on PyTorch, and it runs on Linux, Windows, and macOS. Its highlights are as follows:
1. **A universal and powerful runner**
**Integrates mainstream large-scale model training frameworks**
- Supports training different tasks with a small amount of code, e.g., ImageNet can be trained with only 80 lines of code (the original PyTorch example needs 400 lines).
- Easily compatible with models from popular algorithm libraries such as TIMM, TorchVision, and Detectron2.
- [ColossalAI](../common_usage/large_model_training.md#colossalai)
- [DeepSpeed](../common_usage/large_model_training.md#deepspeed)
- [FSDP](../common_usage/large_model_training.md#fullyshardeddataparallel-fsdp)
2. **An open architecture with unified interfaces**
**Supports a rich set of training strategies**
- Handles different algorithm tasks with a unified interface, e.g., implement a method once and apply it to all compatible models.
- Makes upstream and downstream integration more unified and convenient; while providing a unified abstraction for upper-level algorithm libraries, it supports multiple backend devices. MMEngine currently supports model training on Nvidia CUDA, Mac MPS, AMD, MLU, and other devices.
- [Mixed Precision Training](../common_usage/speed_up_training.md#混合精度训练)
- [Gradient Accumulation](../common_usage/save_gpu_memory.md#梯度累加)
- [Gradient Checkpointing](../common_usage/save_gpu_memory.md#梯度检查点)
3. **A customizable training process**
**Provides an easy-to-use configuration system**
- Defines a "LEGO"-style training process.
- Provides rich components and strategies.
- Controls the training process with different levels of APIs.
- [Pure-Python-style configuration files, easy to navigate](../advanced_tutorials/config.md#纯-python-风格的配置文件beta)
- [Plain-text-style configuration files, supporting JSON and YAML](../advanced_tutorials/config.md)
**Covers mainstream training monitoring platforms**
- [TensorBoard](../common_usage/visualize_training_log.md#tensorboard) | [WandB](../common_usage/visualize_training_log.md#wandb) | [MLflow](../common_usage/visualize_training_log.md#mlflow-wip)
- [ClearML](../common_usage/visualize_training_log.md#clearml) | [Neptune](../common_usage/visualize_training_log.md#neptune) | [DVCLive](../common_usage/visualize_training_log.md#dvclive) | [Aim](../common_usage/visualize_training_log.md#aim)
**Compatible with mainstream training accelerators**
- NVIDIA CUDA | Apple MPS
- Huawei Ascend | Cambricon MLU | Moore Threads MUSA
## Architecture

View File

@ -31,11 +31,12 @@ MMEngine provides many built-in hooks, which are divided into two categories: default
**Custom hooks**
| Name | Purpose | Priority |
| :---------------------------------: | :-------------------: | :---------: |
| [EMAHook](#emahook) | Exponential moving average of model parameters | NORMAL (50) |
| [EmptyCacheHook](#emptycachehook) | Releases the PyTorch CUDA cache | NORMAL (50) |
| [SyncBuffersHook](#syncbuffershook) | Synchronizes the model's buffers | NORMAL (50) |
| Name | Purpose | Priority |
| :---------------------------------: | :--------------------------------: | :-----------: |
| [EMAHook](#emahook) | Exponential moving average of model parameters | NORMAL (50) |
| [EmptyCacheHook](#emptycachehook) | Releases the PyTorch CUDA cache | NORMAL (50) |
| [SyncBuffersHook](#syncbuffershook) | Synchronizes the model's buffers | NORMAL (50) |
| [ProfilerHook](#profilerhook) | Profiles operator execution time and GPU memory usage | VERY_LOW (90) |
```{note}
Changing the priority of the default hooks is not recommended, because hooks with lower priority may depend on hooks with higher priority. For example, CheckpointHook needs a lower priority than ParamSchedulerHook so that the saved optimizer state is correct. In addition, the default priority of custom hooks is `NORMAL (50)`.
```
@ -206,6 +207,20 @@ runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```
### ProfilerHook
[ProfilerHook](mmengine.hooks.ProfilerHook) is used to analyze the execution time of model operators and the GPU memory usage.
```python
custom_hooks = [dict(type='ProfilerHook', on_trace_ready=dict(type='tb_trace'))]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```
The profiling results are saved in the `tf_tracing_logs` directory under `work_dirs/{timestamp}` and can be visualized with `tensorboard --logdir work_dirs/{timestamp}/tf_tracing_logs`.
For more details on ProfilerHook, please read the [ProfilerHook](mmengine.hooks.ProfilerHook) documentation.
## Customize hooks
If the default hooks provided by MMEngine do not meet your needs, you can define a custom hook by simply inheriting from the hook base class and overriding the corresponding trigger-point methods.
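A hedged sketch of such a custom hook follows; the hook name, the `interval` argument, and the logged message are illustrative, not prescribed by the documentation above:

```python
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class LogLossEveryNItersHook(Hook):
    """Illustrative hook that logs the training outputs every `interval` iterations."""

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        # `after_train_iter` is one of the overridable trigger-point methods;
        # `outputs` holds the losses returned by `train_step`.
        if self.every_n_train_iters(runner, self.interval):
            runner.logger.info(f'iter {runner.iter}: outputs={outputs}')
```

It would then be enabled the same way as the built-in hooks above, e.g. `custom_hooks = [dict(type='LogLossEveryNItersHook', interval=100)]`.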

View File

@ -499,7 +499,7 @@ class BaseStrategy(metaclass=ABCMeta):
'"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers)
return OptimWrapperDict(**optim_wrappers) # type: ignore
else:
raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}')
@ -799,7 +799,8 @@ class BaseStrategy(metaclass=ABCMeta):
else:
model = self.model
_load_checkpoint_to_model(model, state_dict, strict, revise_keys)
_load_checkpoint_to_model(
model, state_dict, strict=strict, revise_keys=revise_keys)
def load_optim_state_dict(self, state_dict: dict) -> None:
"""Load optimizer state from dict."""

View File

@ -120,8 +120,9 @@ class ColossalAIOptimWrapper(OptimWrapper):
self.optimizer.backward(loss, **kwargs)
@MODEL_WRAPPERS.register_module()
class CollosalAIModelWrapper:
@MODEL_WRAPPERS.register_module(
name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
class ColossalAIModelWrapper:
def __init__(self, model_wrapper: ModelWrapper, model: nn.Module):
self.model_wrapper = model_wrapper
@ -238,7 +239,7 @@ class ColossalAIStrategy(BaseStrategy):
OPTIMIZER_DIR = 'optimizer' # directory to save optimizer state.
MODEL_DIR = 'model' # directory to save model
SCHEDULER_DIR = 'scheduler' # directory to save scheduelrs
model: CollosalAIModelWrapper # type: ignore
model: ColossalAIModelWrapper # type: ignore
optim_wrapper: ColossalAIOptimWrapper # type: ignore
def __init__(
@ -360,7 +361,7 @@ class ColossalAIStrategy(BaseStrategy):
map_location: Union[str, Callable] = 'default',
callback: Optional[Callable] = None,
) -> dict:
"""override this method since colossalai resume optimizer from filename
"""Override this method since colossalai resume optimizer from filename
directly."""
self.logger.info(f'Resume checkpoint from {filename}')
@ -468,8 +469,14 @@ class ColossalAIStrategy(BaseStrategy):
def _build_plugin(self, plugin: Union[str, dict]):
if isinstance(plugin, str):
if plugin == 'gemini':
plugin = colo_plugin.GeminiPlugin(
precision='bf16', placement_policy='cuda')
try:
plugin = colo_plugin.GeminiPlugin(
precision='bf16', placement_policy='auto')
except AssertionError:
from colossalai.zero.gemini.placement_policy import \
PlacementPolicyFactory as colo_placement
raise ValueError('placement policy must be one of ' +
f'{list(colo_placement.policies.keys())}')
elif plugin == 'lowlevel-zero':
plugin = colo_plugin.LowLevelZeroPlugin()
else:
@ -508,11 +515,11 @@ class ColossalAIStrategy(BaseStrategy):
self,
model: nn.Module,
optim_wrapper: Optional[OptimWrapper] = None,
) -> Union[Tuple[CollosalAIModelWrapper, ColossalAIOptimWrapper],
CollosalAIModelWrapper]: # type: ignore
) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper],
ColossalAIModelWrapper]: # type: ignore
"""Wrap model with :class:`ModelWrapper`."""
if self.model_wrapper is None:
self.model_wrapper = {'type': 'CollosalAIModelWrapper'}
self.model_wrapper = {'type': 'ColossalAIModelWrapper'}
# For zero series parallel, move `data_preprocessor` to current device
# is reasonable. We need to `BaseDataPreprocessor.to` manually since

View File

@ -6,18 +6,23 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from mmengine.logging import print_log
try:
import deepspeed
except ImportError:
deepspeed = None
import logging
import torch.nn as nn
import mmengine
from mmengine.dist import init_dist
from mmengine.dist import init_dist, is_main_process
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS, OPTIMIZERS,
STRATEGIES)
from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu
from mmengine.utils import apply_to, digit_version, get_git_hash
from .base import BaseStrategy
@ -306,8 +311,8 @@ class DeepSpeedStrategy(BaseStrategy):
self.config['steps_per_print'] = steps_per_print
self._inputs_to_half = inputs_to_half
assert (exclude_frozen_parameters is None or
digit_version(deepspeed.__version__) >= digit_version('0.10.1')
), ('DeepSpeed >= 0.10.1 is required to enable '
digit_version(deepspeed.__version__) >= digit_version('0.13.2')
), ('DeepSpeed >= 0.13.2 is required to enable '
'exclude_frozen_parameters')
self.exclude_frozen_parameters = exclude_frozen_parameters
@ -425,7 +430,7 @@ class DeepSpeedStrategy(BaseStrategy):
self.logger.info(f'Load checkpoint from {filename}')
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
_, extra_ckpt = self.model.load_checkpoint(
dirname,
tag=basename,
@ -463,7 +468,7 @@ class DeepSpeedStrategy(BaseStrategy):
self.logger.info(f'Resume checkpoint from {filename}')
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
_, extra_ckpt = self.model.load_checkpoint(
dirname,
tag=basename,
@ -506,7 +511,7 @@ class DeepSpeedStrategy(BaseStrategy):
"""Save checkpoint to given ``filename``.
Warning:
`save_optimizer` and `callback` parameters are not supported yet.
`callback` parameter is not supported yet.
Args:
filename (str): Filename to save checkpoint.
@ -527,25 +532,50 @@ class DeepSpeedStrategy(BaseStrategy):
mmengine=mmengine.__version__ + get_git_hash(),
)
if save_optimizer and hasattr(self, 'optim_wrapper'):
# The key can not be 'optimizer', otherwise error will be thrown
# when loading or resuming checkpoint.
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
if save_param_scheduler and hasattr(self, 'param_schedulers'):
extra_ckpt['param_schedulers'] = self.scheduler_state_dict()
dirname, basename = osp.split(filename)
if digit_version(deepspeed.__version__) >= digit_version('0.10.1'):
if (not save_optimizer
and self.model.zero_optimization_partition_weights()
and not self.model.zero_gather_16bit_weights_on_model_save()):
print_log(
'Configured to `save_optimizer=False`, but currently using '
"DeepSpeed's ZeRO stage 3 with "
'`gather_16bit_weights_on_model_save=False`. In '
'this configuration, the model cannot be saved properly '
'and will be saved with the optimizer state. '
'To support `save_optimizer=False`, please set '
'`gather_16bit_weights_on_model_save=True` in your '
'DeepSpeed config.',
logger='current',
level=logging.WARNING)
save_optimizer = True
state_dict_kwargs = {}
if digit_version(deepspeed.__version__) >= digit_version('0.13.2'):
state_dict_kwargs[
'exclude_frozen_parameters'] = self.exclude_frozen_parameters
if save_optimizer:
if hasattr(self, 'optim_wrapper'):
# The key can not be 'optimizer', otherwise error will be
# thrown when loading or resuming checkpoint.
extra_ckpt['optim_wrapper'] = self.optim_state_dict()
dirname, basename = osp.split(filename)
self.model.save_checkpoint(
dirname,
tag=basename,
client_state=extra_ckpt,
save_latest=False,
exclude_frozen_parameters=self.exclude_frozen_parameters)
**state_dict_kwargs)
else:
self.model.save_checkpoint(
dirname,
tag=basename,
client_state=extra_ckpt,
save_latest=False)
if self.model.zero_optimization_partition_weights():
state_dict = self.model._zero3_consolidated_16bit_state_dict(
**state_dict_kwargs)
else:
state_dict = self.model.module_state_dict(**state_dict_kwargs)
if is_main_process():
ckpt = {'state_dict': weights_to_cpu(state_dict), **extra_ckpt}
save_checkpoint(ckpt, filename)
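The changes above gate `exclude_frozen_parameters` on DeepSpeed >= 0.13.2 and rework how checkpoints are written. As a hedged sketch (the ZeRO stage and fp16 settings are illustrative, not prescribed by the diff), enabling the flag from a strategy config might look like:

```python
# Typically passed to FlexibleRunner(strategy=...); all values are illustrative.
strategy = dict(
    type='DeepSpeedStrategy',
    fp16=dict(enabled=True, initial_scale_power=16),
    zero_optimization=dict(stage=3),
    exclude_frozen_parameters=True,  # requires deepspeed >= 0.13.2 (see the assert above)
)
```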

View File

@ -53,7 +53,7 @@ class DDPStrategy(SingleDeviceStrategy):
init_dist(launcher, backend, **kwargs)
def convert_model(self, model: nn.Module) -> nn.Module:
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
"""Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.
Args:

View File

@ -48,9 +48,11 @@ else:
def _lazy2string(cfg_dict, dict_type=None):
if isinstance(cfg_dict, dict):
dict_type = dict_type or type(cfg_dict)
return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)})
return dict_type(
{k: _lazy2string(v, dict_type)
for k, v in dict.items(cfg_dict)})
elif isinstance(cfg_dict, (tuple, list)):
return type(cfg_dict)(_lazy2string(v) for v in cfg_dict)
return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict)
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
return f'{cfg_dict.module}.{str(cfg_dict)}'
else:
@ -391,7 +393,7 @@ class Config:
def __init__(
self,
cfg_dict: dict = None,
cfg_dict: Optional[dict] = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
@ -442,6 +444,8 @@ class Config:
predefined variables. Defaults to True.
import_custom_modules (bool, optional): Whether to support
importing custom modules in config. Defaults to None.
use_environment_variables (bool, optional): Whether to use
environment variables. Defaults to True.
lazy_import (bool): Whether to load config in `lazy_import` mode.
If it is `None`, it will be deduced by the content of the
config file. Defaults to None.
@ -829,6 +833,8 @@ class Config:
filename (str): Name of config file.
use_predefined_variables (bool, optional): Whether to use
predefined variables. Defaults to True.
use_environment_variables (bool, optional): Whether to use
environment variables. Defaults to True.
lazy_import (bool): Whether to load config in `lazy_import` mode.
If it is `None`, it will be deduced by the content of the
config file. Defaults to None.
@ -899,7 +905,7 @@ class Config:
# 2. Set `_scope_` for the outer dict variable for the base
# config.
# 3. Set `scope` attribute for each base variable.
# Different from `_scope_` `scope` is not a key of base
# Different from `_scope_`, `scope` is not a key of base
# dict, `scope` attribute will be parsed to key `_scope_`
# by function `_parse_scope` only if the base variable is
# accessed by the current config.
@ -1221,7 +1227,8 @@ class Config:
if base_code is not None:
base_code = ast.Expression( # type: ignore
body=base_code.value) # type: ignore
base_files = eval(compile(base_code, '', mode='eval'))
base_files = eval(compile(base_code, '',
mode='eval')) # type: ignore
else:
base_files = []
elif file_format in ('.yml', '.yaml', '.json'):
@ -1282,7 +1289,7 @@ class Config:
def _merge_a_into_b(a: dict,
b: dict,
allow_list_keys: bool = False) -> dict:
"""merge dict ``a`` into dict ``b`` (non-inplace).
"""Merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
@ -1352,22 +1359,22 @@ class Config:
@property
def filename(self) -> str:
"""get file name of config."""
"""Get file name of config."""
return self._filename
@property
def text(self) -> str:
"""get config text."""
"""Get config text."""
return self._text
@property
def env_variables(self) -> dict:
"""get used environment variables."""
"""Get used environment variables."""
return self._env_variables
@property
def pretty_text(self) -> str:
"""get formatted python config text."""
"""Get formatted python config text."""
indent = 4
@ -1721,17 +1728,17 @@ class Config:
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""Argparse action to split an argument into KEY=VALUE form on the first =
and append to a dictionary.
List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""
@staticmethod
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
"""parse int/float/bool value in the string."""
"""Parse int/float/bool value in the string."""
try:
return int(val)
except ValueError:
@ -1816,7 +1823,7 @@ class DictAction(Action):
parser: ArgumentParser,
namespace: Namespace,
values: Union[str, Sequence[Any], None],
option_string: str = None):
option_string: str = None): # type: ignore
"""Parse Variables in string and add them into argparser.
Args:

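To show how `DictAction` from the hunk above is meant to be used together with `Config`, here is a hedged sketch; the flag name `--cfg-options` and the override keys are illustrative:

```python
from argparse import ArgumentParser

from mmengine.config import Config, DictAction

parser = ArgumentParser()
parser.add_argument(
    '--cfg-options', nargs='+', action=DictAction,
    help="override config entries, e.g. KEY=V1,V2 or KEY=[(V1,V2),(V3,V4)]")
args = parser.parse_args(
    ['--cfg-options', 'optimizer.lr=0.01', 'param_scheduler.milestones=[8,11]'])

cfg = Config(dict(optimizer=dict(type='SGD', lr=0.1),
                  param_scheduler=dict(type='MultiStepLR', milestones=[16, 22])))
cfg.merge_from_dict(args.cfg_options)  # applies the parsed KEY=VALUE overrides
print(cfg.optimizer.lr)                # -> 0.01
```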
View File

@ -12,6 +12,7 @@ from mmengine.fileio import load
from mmengine.utils import check_file_exist
PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable))
SYSTEM_PYTHON_PREFIX = '/usr/lib/python'
MODULE2PACKAGE = {
'mmcls': 'mmcls',
@ -176,7 +177,8 @@ def _is_builtin_module(module_name: str) -> bool:
return False
origin_path = osp.abspath(origin_path)
if ('site-package' in origin_path or 'dist-package' in origin_path
or not origin_path.startswith(PYTHON_ROOT_DIR)):
or not origin_path.startswith(
(PYTHON_ROOT_DIR, SYSTEM_PYTHON_PREFIX))):
return False
else:
return True

View File

@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
is_dipu_available, is_mlu_available, is_mps_available,
is_npu_available, is_npu_support_full_precision)
from .utils import (get_device, get_max_cuda_memory, get_max_musa_memory,
is_cuda_available, is_dipu_available, is_mlu_available,
is_mps_available, is_musa_available, is_npu_available,
is_npu_support_full_precision)
__all__ = [
'get_max_cuda_memory', 'get_device', 'is_cuda_available',
'is_mlu_available', 'is_mps_available', 'is_npu_available',
'is_dipu_available', 'is_npu_support_full_precision'
'is_dipu_available', 'get_max_musa_memory', 'is_musa_available',
'is_npu_support_full_precision'
]

View File

@ -16,12 +16,24 @@ try:
except Exception:
IS_NPU_AVAILABLE = False
try:
import torch_mlu # noqa: F401
IS_MLU_AVAILABLE = hasattr(torch, 'mlu') and torch.mlu.is_available()
except Exception:
IS_MLU_AVAILABLE = False
try:
import torch_dipu # noqa: F401
IS_DIPU_AVAILABLE = True
except Exception:
IS_DIPU_AVAILABLE = False
try:
import torch_musa # noqa: F401
IS_MUSA_AVAILABLE = True
except Exception:
IS_MUSA_AVAILABLE = False
def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@ -58,7 +70,7 @@ def is_npu_available() -> bool:
def is_mlu_available() -> bool:
"""Returns True if Cambricon PyTorch and mlu devices exist."""
return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
return IS_MLU_AVAILABLE
def is_mps_available() -> bool:
@ -73,6 +85,34 @@ def is_dipu_available() -> bool:
return IS_DIPU_AVAILABLE
def get_max_musa_memory(device: Optional[torch.device] = None) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
a given device. By default, this returns the peak allocated memory since
the beginning of this program.
Args:
device (torch.device, optional): selected device. Returns
statistic for the current device, given by
:func:`~torch.musa.current_device`, if ``device`` is None.
Defaults to None.
Returns:
int: The maximum GPU memory occupied by tensors in megabytes
for a given device.
"""
mem = torch.musa.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
# TODO:haowen.han@mthreads.com: This function is not supported by musa yet.
# torch.musa.reset_peak_memory_stats()
return int(mem_mb.item())
def is_musa_available() -> bool:
return IS_MUSA_AVAILABLE
def is_npu_support_full_precision() -> bool:
"""Returns True if npu devices support full precision training."""
version_of_support_full_precision = 220
@ -91,12 +131,14 @@ elif is_mps_available():
DEVICE = 'mps'
elif is_dipu_available():
DEVICE = 'dipu'
elif is_musa_available():
DEVICE = 'musa'
def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | npu | mlu | mps | cpu.
str: cuda | npu | mlu | mps | musa | cpu.
"""
return DEVICE
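A hedged sketch of how these helpers can be combined in device-agnostic code; the tensor shape and the printout are illustrative:

```python
import torch

from mmengine.device import (get_device, get_max_cuda_memory,
                             get_max_musa_memory, is_cuda_available,
                             is_musa_available)

device = get_device()  # 'cuda' | 'npu' | 'mlu' | 'mps' | 'musa' | 'cpu'
x = torch.randn(2, 3).to(device)

if is_cuda_available():
    print(f'peak CUDA memory: {get_max_cuda_memory()} MB')
elif is_musa_available():
    print(f'peak MUSA memory: {get_max_musa_memory()} MB')
```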

mmengine/dist/dist.py
View File

@ -13,7 +13,7 @@ from torch import distributed as torch_dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from torch.distributed import ProcessGroup
from itertools import zip_longest, chain
import mmengine
from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
get_default_group, barrier, get_data_device,
@ -415,12 +415,16 @@ def _broadcast_object_list(object_list: List[Any],
current_device = torch.device('cpu')
is_hccl_backend = group_backend == 'hccl'
is_cncl_backend = group_backend == 'cncl'
is_mccl_backend = group_backend == 'mccl'
if is_hccl_backend:
current_device = torch.device('npu', torch.npu.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_cncl_backend:
current_device = torch.device('mlu', torch.mlu.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_mccl_backend:
current_device = torch.device('musa', torch.musa.current_device())
object_sizes_tensor = object_sizes_tensor.to(current_device)
elif is_nccl_backend:
# See note about using torch.cuda.current_device() here in
# docstring. We cannot simply use my_rank since rank == device is
@ -624,6 +628,7 @@ def _all_gather_object(object_list: List[Any],
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
@ -631,6 +636,13 @@ def _all_gather_object(object_list: List[Any],
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
# See note about using torch.musa.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
# index until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
@ -776,10 +788,15 @@ def _gather_object(obj: Any,
group_backend = get_backend(group)
current_device = torch.device('cpu')
is_nccl_backend = group_backend == torch_dist.Backend.NCCL
is_mccl_backend = group_backend == 'mccl'
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
elif is_mccl_backend:
current_device = torch.device('musa', torch.musa.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and
# index until the correct size when deserializing the tensors.
group_size = get_world_size(group=group)
@ -1010,8 +1027,10 @@ def collect_results_cpu(result_part: list,
part_list.append(pickle.load(f))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
zipped_results = zip_longest(*part_list)
ordered_results = [
i for i in chain.from_iterable(zipped_results) if i is not None
]
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
@ -1032,8 +1051,10 @@ def _collect_results_device(result_part: list, size: int) -> Optional[list]:
if rank == 0:
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
zipped_results = zip_longest(*part_list)
ordered_results = [
i for i in chain.from_iterable(zipped_results) if i is not None
]
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results

View File

@ -11,7 +11,8 @@ import torch.multiprocessing as mp
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup
from mmengine.device import is_mlu_available, is_npu_available
from mmengine.device import (is_mlu_available, is_npu_available,
is_musa_available)
from collections.abc import Iterable, Mapping
@ -99,9 +100,10 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
rank = int(os.environ['RANK'])
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
if is_mlu_available():
import torch_mlu # noqa: F401
local_rank = int(os.environ['LOCAL_RANK'])
torch.mlu.set_device(local_rank)
torch_dist.init_process_group(
backend='cncl',
@ -110,15 +112,21 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
**kwargs)
elif is_npu_available():
import torch_npu # noqa: F401
torch.npu.set_device(rank)
torch.npu.set_device(local_rank)
torch_dist.init_process_group(
backend='hccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
elif is_musa_available():
import torch_musa # noqa: F401
torch.musa.set_device(rank)
torch_dist.init_process_group(
backend='mccl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
if init_backend == 'torch':
@ -528,6 +536,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
return torch.device('mlu', torch.mlu.current_device())
elif backend == 'smddp':
return torch.device('cuda', torch.cuda.current_device())
elif backend == 'mccl':
import torch_musa
return torch.device('musa', torch_musa.current_device())
else:
# GLOO and MPI backends use cpu device by default
return torch.device('cpu')
@ -552,7 +563,7 @@ def cast_data_device(
Tensor or list or dict: ``data`` was casted to ``device``.
"""
if out is not None:
if type(data) != type(out):
if type(data) is not type(out):
raise TypeError(
'out should be the same type with data, but got data is '
f'{type(data)} and out is {type(data)}')

View File

@ -175,11 +175,11 @@ class DumpResults(BaseMetric):
self.out_file_path = out_file_path
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU."""
"""Transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions))
def compute_metrics(self, results: list) -> dict:
"""dump the prediction results to a pickle file."""
"""Dump the prediction results to a pickle file."""
dump(results, self.out_file_path)
print_log(
f'Results has been saved to {self.out_file_path}.',
@ -188,7 +188,7 @@ class DumpResults(BaseMetric):
def _to_cpu(data: Any) -> Any:
"""transfer all tensors and BaseDataElement to cpu."""
"""Transfer all tensors and BaseDataElement to cpu."""
if isinstance(data, (Tensor, BaseDataElement)):
return data.to('cpu')
elif isinstance(data, list):

View File

@ -4,6 +4,7 @@ from typing import Optional, Sequence, Union
import torch
from mmengine.registry import HOOKS
from ..device import is_cuda_available, is_musa_available
from .hook import Hook
DATA_BATCH = Optional[Union[dict, tuple, list]]
@ -49,7 +50,10 @@ class EmptyCacheHook(Hook):
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_iter:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache before an epoch.
@ -59,7 +63,10 @@ class EmptyCacheHook(Hook):
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_before_epoch:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()
def _after_epoch(self, runner, mode: str = 'train') -> None:
"""Empty cache after an epoch.
@ -69,4 +76,7 @@ class EmptyCacheHook(Hook):
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_epoch:
torch.cuda.empty_cache()
if is_cuda_available():
torch.cuda.empty_cache()
elif is_musa_available():
torch.musa.empty_cache()

View File

@ -233,7 +233,7 @@ class ProfilerHook(Hook):
self._export_chrome_trace(runner)
def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""profiler will call `step` method if it is not closed."""
"""Profiler will call `step` method if it is not closed."""
if not self._closed:
self.profiler.step()
if runner.iter == self.profile_times - 1 and not self.by_epoch:

View File

@ -58,7 +58,7 @@ class HistoryBuffer:
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history.
"""Update the log history.
If the length of the buffer exceeds ``self._max_length``, the oldest
element will be removed from the buffer.

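Since `HistoryBuffer` reappears later in this compare view (the validation-loss bookkeeping in the loops), a hedged usage sketch may help; the values are arbitrary:

```python
from mmengine.logging import HistoryBuffer

buf = HistoryBuffer(max_length=4)      # keeps only the 4 most recent values
for value in [1.0, 2.0, 3.0, 4.0, 5.0]:
    buf.update(value)                  # the oldest value (1.0) gets dropped
print(buf.mean())                      # -> 3.5, the mean of [2, 3, 4, 5]
```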
View File

@ -398,22 +398,38 @@ def _get_device_id():
except ImportError:
return 0
else:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
# TODO: return device id of npu and mlu.
if not torch.cuda.is_available():
return local_rank
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
if cuda_visible_devices is None:
num_device = torch.cuda.device_count()
cuda_visible_devices = list(range(num_device))
else:
cuda_visible_devices = cuda_visible_devices.split(',')
MUSA_AVAILABLE = False
try:
return int(cuda_visible_devices[local_rank])
except ValueError:
# handle case for Multi-Instance GPUs
# see #1148 for details
return cuda_visible_devices[local_rank]
import torch_musa
MUSA_AVAILABLE = True
except ImportError:
pass
if MUSA_AVAILABLE:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
if musa_visible_devices is None:
num_device = torch_musa.device_count()
musa_visible_devices = list(range(num_device))
else:
musa_visible_devices = musa_visible_devices.split(',')
return int(musa_visible_devices[local_rank])
else:
local_rank = int(os.getenv('LOCAL_RANK', '0'))
# TODO: return device id of npu and mlu.
if not torch.cuda.is_available():
return local_rank
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
if cuda_visible_devices is None:
num_device = torch.cuda.device_count()
cuda_visible_devices = list(range(num_device))
else:
cuda_visible_devices = cuda_visible_devices.split(',')
try:
return int(cuda_visible_devices[local_rank])
except ValueError:
# handle case for Multi-Instance GPUs
# see #1148 for details
return cuda_visible_devices[local_rank]
def _get_host_info() -> str:
@ -427,8 +443,7 @@ def _get_host_info() -> str:
host = f'{getuser()}@{gethostname()}'
except Exception as e:
warnings.warn(f'Host or user not found: {str(e)}')
finally:
return host
return host
def _get_logging_file_handlers() -> Dict:

View File

@ -317,7 +317,7 @@ class MessageHub(ManagerMixin):
if key not in self.runtime_info:
return default
else:
# TODO There are restrictions on objects that can be saved
# TODO: There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]

View File

@ -222,6 +222,21 @@ class BaseModel(BaseModule):
self._set_device(torch.device(device))
return super().cuda(device)
def musa(
self,
device: Optional[Union[int, str, torch.device]] = None,
) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.musa`
additionally.
Returns:
nn.Module: The model itself.
"""
if device is None or isinstance(device, int):
device = torch.device('musa', index=device)
self._set_device(torch.device(device))
return super().musa(device)
def mlu(
self,
device: Union[int, str, torch.device, None] = None,

View File

@ -113,6 +113,15 @@ class BaseDataPreprocessor(nn.Module):
self._device = torch.device(torch.cuda.current_device())
return super().cuda()
def musa(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
Returns:
nn.Module: The model itself.
"""
self._device = torch.device(torch.musa.current_device())
return super().musa()
def npu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
@ -226,7 +235,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
self.pad_value = pad_value
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
"""Performs normalizationpadding and bgr2rgb conversion based on
"""Performs normalization, padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
@ -244,17 +253,17 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
dict or list: Data in the same format as the model input.
"""
data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs']
_batch_inputs = data['inputs'] # type: ignore
# Process data with `pseudo_collate`.
if is_seq_of(_batch_inputs, torch.Tensor):
batch_inputs = []
for _batch_input in _batch_inputs:
# channel transform
if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...]
_batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
# Convert to float after channel conversion to ensure
# efficiency
_batch_input = _batch_input.float()
_batch_input = _batch_input.float() # type: ignore
# Normalization.
if self._enable_normalize:
if self.mean.shape[0] == 3:
@ -293,7 +302,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
else:
raise TypeError('Output of `cast_data` should be a dict of '
'list/tuple with inputs and data_samples, '
f'but got {type(data)}: {data}')
data['inputs'] = batch_inputs
data.setdefault('data_samples', None)
return data
f'but got {type(data)}: {data}') # type: ignore
data['inputs'] = batch_inputs # type: ignore
data.setdefault('data_samples', None) # type: ignore
return data # type: ignore

View File

@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):
def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value."""
"""Initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init
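As a quick sanity check of the formula above (the 0.01 prior probability is just an example value):

```python
import numpy as np

prior_prob = 0.01
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
print(round(bias_init, 3))  # -4.595; sigmoid(-4.595) ~= 0.01, the desired prior
```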
@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
r"""Fills the input Tensor with values drawn from a truncated normal
distribution. The values are effectively drawn from the normal distribution
:math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
:math:`[a, b]` redrawn until they are within the bounds. The method used
for generating the random values works best when :math:`a \leq \text{mean}
\leq b`.
Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

View File

@ -127,7 +127,8 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
auto_wrap_policy: Union[str, Callable, None] = None,
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
mixed_precision: Union[dict, MixedPrecision, None] = None,
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
param_init_fn: Union[str, Callable[
[nn.Module], None]] = None, # type: ignore # noqa: E501
use_orig_params: bool = True,
**kwargs,
):
@ -362,7 +363,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
model)
return FullyShardedDataParallel._optim_state_dict_impl(
@ -384,7 +385,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
import torch.distributed.fsdp._traversal_utils as traversal_utils
_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig,

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn
from mmengine.device import (is_cuda_available, is_mlu_available,
is_npu_available)
is_musa_available, is_npu_available)
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@ -74,8 +74,9 @@ class AmpOptimWrapper(OptimWrapper):
assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
'`torch.cuda.amp` is only available when pytorch version >= 1.6')
assert is_cuda_available() or is_npu_available() or is_mlu_available(
), ('``AmpOptimizerWrapper`` is only available training '
'on gpu, npu or mlu')
) or is_musa_available(), (
'``AmpOptimizerWrapper`` is only available training '
'on gpu, npu, mlu or musa')
super().__init__(**kwargs)
self._scale_update_param = None
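For reference, a hedged sketch of enabling AMP through the optimizer wrapper config; the optimizer and loss-scale settings are illustrative:

```python
# With this `optim_wrapper`, the runner builds an AmpOptimWrapper, which per the
# change above may now also run on MUSA devices.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    loss_scale='dynamic',  # use a dynamic GradScaler
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
)
```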

View File

@ -123,8 +123,7 @@ class ApexOptimWrapper(OptimWrapper):
self._inner_count += 1
def state_dict(self) -> dict:
"""Get the state dictionary of :attr:`optimizer` and
:attr:`apex_amp`.
"""Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.
Based on the state dictionary of the optimizer, the returned state
dictionary will add a key named "apex_amp".

View File

@ -25,7 +25,11 @@ def register_torch_optimizers() -> List[str]:
_optim = getattr(torch.optim, module_name)
if inspect.isclass(_optim) and issubclass(_optim,
torch.optim.Optimizer):
OPTIMIZERS.register_module(module=_optim)
if module_name == 'Adafactor':
OPTIMIZERS.register_module(
name='TorchAdafactor', module=_optim)
else:
OPTIMIZERS.register_module(module=_optim)
torch_optimizers.append(module_name)
return torch_optimizers
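With this registration in place, a config can request PyTorch's implementation explicitly; a hedged sketch (the learning rate is illustrative):

```python
# 'Adafactor' keeps resolving to the previously registered optimizer, while
# torch.optim.Adafactor (torch >= 2.5) is reachable under 'TorchAdafactor'.
optim_wrapper = dict(
    optimizer=dict(type='TorchAdafactor', lr=1e-2),
)
```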

View File

@ -131,7 +131,7 @@ class DefaultOptimWrapperConstructor:
self._validate_cfg()
def _validate_cfg(self) -> None:
"""verify the correctness of the config."""
"""Verify the correctness of the config."""
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
@ -155,7 +155,7 @@ class DefaultOptimWrapperConstructor:
raise ValueError('base_wd should not be None')
def _is_in(self, param_group: dict, param_group_list: list) -> bool:
"""check whether the `param_group` is in the`param_group_list`"""
"""Check whether the `param_group` is in the`param_group_list`"""
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()
@ -213,7 +213,10 @@ class DefaultOptimWrapperConstructor:
level=logging.WARNING)
continue
if not param.requires_grad:
params.append(param_group)
print_log((f'{prefix}.{name} is skipped since its '
f'requires_grad={param.requires_grad}'),
logger='current',
level=logging.WARNING)
continue
# if the parameter match one of the custom keys, ignore other rules

View File

@ -161,8 +161,7 @@ class OptimWrapperDict(OptimWrapper):
self.optim_wrappers[name].load_state_dict(_state_dict)
def items(self) -> Iterator[Tuple[str, OptimWrapper]]:
"""A generator to get the name and corresponding
:obj:`OptimWrapper`"""
"""A generator to get the name and corresponding :obj:`OptimWrapper`"""
yield from self.optim_wrappers.items()
def values(self) -> Iterator[OptimWrapper]:

View File

@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
@PARAM_SCHEDULERS.register_module()
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
r"""Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
r"""Sets the learning rate of each parameter group according to the 1cycle
learning rate policy. The 1cycle policy anneals the learning rate from an
initial learning rate to some maximum learning rate and then from that
maximum learning rate to some minimum learning rate much lower than the
initial learning rate. This policy was initially described in the paper
`Super-Convergence: Very Fast Training of Neural Networks Using Large
Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training.

View File

@ -565,9 +565,9 @@ class ExponentialParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingParamScheduler(_ParamScheduler):
r"""Set the parameter value of each parameter group using a cosine
annealing schedule, where :math:`\eta_{max}` is set to the initial value
and :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
r"""Set the parameter value of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial value and
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
.. math::
\begin{aligned}
@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
"""
""" # noqa: E501
def __init__(self,
optimizer: Union[Optimizer, BaseOptimWrapper],
@ -890,13 +890,13 @@ class PolyParamScheduler(_ParamScheduler):
@PARAM_SCHEDULERS.register_module()
class OneCycleParamScheduler(_ParamScheduler):
r"""Sets the parameters of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
r"""Sets the parameters of each parameter group according to the 1cycle
learning rate policy. The 1cycle policy anneals the learning rate from an
initial learning rate to some maximum learning rate and then from that
maximum learning rate to some minimum learning rate much lower than the
initial learning rate. This policy was initially described in the paper
`Super-Convergence: Very Fast Training of Neural Networks Using Large
Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training.

View File

@ -4,7 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Optional, Union
from mmengine.config import Config, ConfigDict
from mmengine.utils import ManagerMixin
from mmengine.utils import ManagerMixin, digit_version
from .registry import Registry
if TYPE_CHECKING:
@ -232,6 +232,21 @@ def build_model_from_cfg(
return build_from_cfg(cfg, registry, default_args)
def build_optimizer_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,
default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
import torch
from ..logging import print_log
if 'type' in cfg \
and 'Adafactor' == cfg['type'] \
and digit_version(torch.__version__) >= digit_version('2.5.0'):
print_log(
'the torch version of Adafactor is registered as TorchAdafactor')
return build_from_cfg(cfg, registry, default_args)
def build_scheduler_from_cfg(
cfg: Union[dict, ConfigDict, Config],
registry: Registry,

View File

@ -81,7 +81,7 @@ class DefaultScope(ManagerMixin):
@classmethod
@contextmanager
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
"""overwrite the current default scope with `scope_name`"""
"""Overwrite the current default scope with `scope_name`"""
if scope_name is None:
yield
else:

View File

@ -332,7 +332,7 @@ class Registry:
return root
def import_from_location(self) -> None:
"""import modules from the pre-defined locations in self._location."""
"""Import modules from the pre-defined locations in self._location."""
if not self._imported:
# Avoid circular import
from ..logging import print_log

View File

@ -6,8 +6,8 @@ More datails can be found at
https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
"""
from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
build_scheduler_from_cfg)
from .build_functions import (build_model_from_cfg, build_optimizer_from_cfg,
build_runner_from_cfg, build_scheduler_from_cfg)
from .registry import Registry
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@ -35,7 +35,7 @@ MODEL_WRAPPERS = Registry('model_wrapper')
WEIGHT_INITIALIZERS = Registry('weight initializer')
# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer')
OPTIMIZERS = Registry('optimizer', build_func=build_optimizer_from_cfg)
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper')
# manage constructors that customize the optimization hyperparameters.

View File

@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None:
if current_scope.scope_name != scope: # type: ignore
print_log(
'The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", '
f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore
'`init_default_scope` will force set the current'
f'default scope to "{scope}".',
logger='current',

View File

@ -540,7 +540,7 @@ class FlexibleRunner:
@property
def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks."""
"""List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
@ -1117,7 +1117,7 @@ class FlexibleRunner:
return '\n'.join(stage_hook_infos)
def load_or_resume(self):
"""load or resume checkpoint."""
"""Load or resume checkpoint."""
if self._has_loaded:
return None
@ -1539,7 +1539,7 @@ class FlexibleRunner:
file_client_args: Optional[dict] = None,
save_optimizer: bool = True,
save_param_scheduler: bool = True,
meta: dict = None,
meta: Optional[dict] = None,
by_epoch: bool = True,
backend_args: Optional[dict] = None,
):

View File

@ -135,7 +135,13 @@ def autocast(device_type: Optional[str] = None,
elif device_type == 'npu':
pass
elif device_type == 'musa':
if dtype is None:
dtype = torch.get_autocast_gpu_dtype()
with torch.musa.amp.autocast(
enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
yield
return
else:
# Device like MPS does not support fp16 training or testing.
# If an inappropriate device is set and fp16 is enabled, an error

View File

@ -309,7 +309,7 @@ class CheckpointLoader:
@classmethod
def load_checkpoint(cls, filename, map_location=None, logger='current'):
"""load checkpoint through URL scheme path.
"""Load checkpoint through URL scheme path.
Args:
filename (str): checkpoint file name with given prefix
@ -332,7 +332,7 @@ class CheckpointLoader:
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
"""load checkpoint by local file path.
"""Load checkpoint by local file path.
Args:
filename (str): local checkpoint file path
@ -353,7 +353,7 @@ def load_from_http(filename,
map_location=None,
model_dir=None,
progress=os.isatty(0)):
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
"""Load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
Args:
@ -386,7 +386,7 @@ def load_from_http(filename,
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
"""load checkpoint through the file path prefixed with pavi. In distributed
"""Load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None):
@CheckpointLoader.register_scheme(
prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://'])
def load_from_ceph(filename, map_location=None, backend='petrel'):
"""load checkpoint through the file path prefixed with s3. In distributed
"""Load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
"""load checkpoint through the file path prefixed with modelzoo or
"""Load checkpoint through the file path prefixed with modelzoo or
torchvision.
Args:
@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
"""load checkpoint through the file path prefixed with open-mmlab or
"""Load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
Args:
@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
"""load checkpoint through the file path prefixed with mmcls.
"""Load checkpoint through the file path prefixed with mmcls.
Args:
filename (str): checkpoint file path with mmcls prefix

View File

@ -9,7 +9,8 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
from mmengine.device import get_max_cuda_memory, is_cuda_available
from mmengine.device import (get_max_cuda_memory, get_max_musa_memory,
is_cuda_available, is_musa_available)
from mmengine.registry import LOG_PROCESSORS
@ -226,11 +227,13 @@ class LogProcessor:
log_tag.pop('time')
log_tag.pop('data_time')
# If cuda is available, the max memory occupied should be calculated.
if is_cuda_available():
# If cuda/musa is available,
# the max memory occupied should be calculated.
if is_cuda_available() or is_musa_available():
max_memory = self._get_max_memory(runner)
log_str += f'memory: {max_memory} '
tag['memory'] = max_memory
# Loop left keys to fill `log_str`.
if mode in ('train', 'val'):
log_items = []
@ -498,6 +501,9 @@ class LogProcessor:
"""
device = getattr(runner.model, 'output_device', None)
if is_musa_available():
return get_max_musa_memory(device)
return get_max_cuda_memory(device)
def _get_iter(self, runner, batch_idx: int) -> int:

View File

@ -8,8 +8,10 @@ import torch
from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.logging import HistoryBuffer, print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@ -98,7 +100,8 @@ class EpochBasedTrainLoop(BaseLoop):
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0):
and (self._epoch % self.val_interval == 0
or self._epoch == self._max_epochs)):
self.runner.val_loop.run()
self.runner.call_hook('after_train')
@ -271,6 +274,14 @@ class IterBasedTrainLoop(BaseLoop):
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
if self._iter > 0:
print_log(
f'Advance dataloader {self._iter} steps to skip data '
'that has already been trained',
logger='current',
level=logging.WARNING)
for _ in range(self._iter):
next(self.dataloader_iterator)
while self._iter < self._max_iters and not self.stop_training:
self.runner.model.train()
@ -280,7 +291,8 @@ class IterBasedTrainLoop(BaseLoop):
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._iter >= self.val_begin
and self._iter % self.val_interval == 0):
and (self._iter % self.val_interval == 0
or self._iter == self._max_iters)):
self.runner.val_loop.run()
self.runner.call_hook('after_train_epoch')
@ -353,17 +365,26 @@ class ValLoop(BaseLoop):
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.val_loss: Dict[str, HistoryBuffer] = dict()
def run(self) -> dict:
"""Launch validation."""
self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()
# clear val loss
self.val_loss.clear()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
if self.val_loss:
loss_dict = _parse_losses(self.val_loss, 'val')
metrics.update(loss_dict)
self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val')
return metrics
@ -381,6 +402,9 @@ class ValLoop(BaseLoop):
# outputs should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)
outputs, self.val_loss = _update_losses(outputs, self.val_loss)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_val_iter',
@ -425,17 +449,26 @@ class TestLoop(BaseLoop):
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.test_loss: Dict[str, HistoryBuffer] = dict()
def run(self) -> dict:
"""Launch test."""
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
self.runner.model.eval()
# clear test loss
self.test_loss.clear()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)
# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
if self.test_loss:
loss_dict = _parse_losses(self.test_loss, 'test')
metrics.update(loss_dict)
self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
@ -452,9 +485,66 @@ class TestLoop(BaseLoop):
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch)
outputs, self.test_loss = _update_losses(outputs, self.test_loss)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)
def _parse_losses(losses: Dict[str, HistoryBuffer],
stage: str) -> Dict[str, float]:
"""Parses the raw losses of the network.
Args:
losses (dict): raw losses of the network.
stage (str): The stage of loss, e.g., 'val' or 'test'.
Returns:
dict[str, float]: The key is the loss name, and the value is the
average loss.
"""
all_loss = 0
loss_dict: Dict[str, float] = dict()
for loss_name, loss_value in losses.items():
avg_loss = loss_value.mean()
loss_dict[loss_name] = avg_loss
if 'loss' in loss_name:
all_loss += avg_loss
loss_dict[f'{stage}_loss'] = all_loss
return loss_dict
def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
"""Update and record the losses of the network.
Args:
outputs (list): The outputs of the network.
losses (dict): The losses of the network.
Returns:
list: The updated outputs of the network.
dict: The updated losses of the network.
"""
if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()
for loss_name, loss_value in loss.items():
if loss_name not in losses:
losses[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
losses[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
for loss_value_i in loss_value:
losses[loss_name].update(loss_value_i.item())
return outputs, losses
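Putting the two helpers together: `_update_losses` only picks up a loss when the last element returned by `val_step`/`test_step` is a `BaseDataElement` whose only key is `loss`. A hedged sketch of outputs that follow this convention (the loss name and value are made up):

```python
import torch

from mmengine.structures import BaseDataElement

# Normal prediction outputs, followed by one trailing loss element.
predictions = [BaseDataElement(pred_label=torch.tensor(1))]
loss_elem = BaseDataElement(loss=dict(loss_cls=torch.tensor(0.3)))
outputs = predictions + [loss_elem]

# ValLoop/TestLoop strip the trailing element, push each value into a
# HistoryBuffer, and `_parse_losses` later adds the averaged `loss_cls` and the
# summed `val_loss` / `test_loss` to the metrics dict.
```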

View File

@ -579,7 +579,7 @@ class Runner:
@property
def hooks(self):
"""list[:obj:`Hook`]: A list of registered hooks."""
"""List[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
@ -720,7 +720,7 @@ class Runner:
def build_logger(self,
log_level: Union[int, str] = 'INFO',
log_file: str = None,
log_file: Optional[str] = None,
**kwargs) -> MMLogger:
"""Build a global asscessable MMLogger.
@ -1677,7 +1677,7 @@ class Runner:
return '\n'.join(stage_hook_infos)
def load_or_resume(self) -> None:
"""load or resume checkpoint."""
"""Load or resume checkpoint."""
if self._has_loaded:
return None

View File

@ -7,6 +7,7 @@ import numpy as np
import torch
from torch.utils.data import DataLoader
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist import get_rank, sync_random_seed
from mmengine.logging import print_log
from mmengine.utils import digit_version, is_list_of
@ -69,7 +70,10 @@ def set_random_seed(seed: Optional[int] = None,
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if is_cuda_available():
torch.cuda.manual_seed_all(seed)
elif is_musa_available():
torch.musa.manual_seed_all(seed)
# os.environ['PYTHONHASHSEED'] = str(seed)
if deterministic:
if torch.backends.cudnn.benchmark:

View File

@ -387,7 +387,7 @@ class BaseDataElement:
return dict(self.metainfo_items())
def __setattr__(self, name: str, value: Any):
"""setattr is only used to set data."""
"""Setattr is only used to set data."""
if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name):
super().__setattr__(name, value)
@ -510,6 +510,17 @@ class BaseDataElement:
new_data.set_data(data)
return new_data
# Tensor-like methods
def musa(self) -> 'BaseDataElement':
"""Convert all tensors to musa in data."""
new_data = self.new()
for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)):
v = v.musa()
data = {k: v}
new_data.set_data(data)
return new_data
# Tensor-like methods
def npu(self) -> 'BaseDataElement':
"""Convert all tensors to NPU in data."""


@ -18,6 +18,9 @@ if get_device() == 'npu':
elif get_device() == 'mlu':
BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
elif get_device() == 'musa':
BoolTypeTensor = Union[torch.BoolTensor, torch.musa.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.musa.LongTensor]
else:
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
@ -132,7 +135,7 @@ class InstanceData(BaseDataElement):
"""
def __setattr__(self, name: str, value: Sized):
"""setattr is only used to set data.
"""Setattr is only used to set data.
The value must have the attribute of `__len__` and have the same length
of `InstanceData`.


@ -92,10 +92,25 @@ class MultiProcessTestCase(TestCase):
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
# or run the underlying test function.
def __init__(self, method_name: str = 'runTest') -> None:
def __init__(self,
method_name: str = 'runTest',
methodName: str = 'runTest') -> None:
# methodName is the correct naming in unittest
# and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != 'runTest':
method_name = methodName
super().__init__(method_name)
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
try:
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
except AttributeError as e:
if methodName != 'runTest':
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(
f'no such test method in {self.__class__}: {methodName}'
) from e
def setUp(self) -> None:
super().setUp()
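
The constructor now accepts both the historical snake_case `method_name` and unittest's camelCase `methodName`, so frameworks such as testslide that pass the keyword keep working, and a bogus name fails fast with a ValueError. Roughly, both spellings below resolve to the same test (the class and method names are illustrative):

from mmengine.testing._internal import MultiProcessTestCase

class MyDistTest(MultiProcessTestCase):
    def test_something(self):
        pass

case_a = MyDistTest(method_name='test_something')  # legacy spelling
case_b = MyDistTest(methodName='test_something')   # unittest/testslide spelling
# MyDistTest(methodName='does_not_exist')  # would raise ValueError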


@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file holding some environment constant for sharing by other files."""
import os
import os.path as osp
import subprocess
import sys
@ -9,6 +10,7 @@ import numpy as np
import torch
import mmengine
from mmengine.device import is_cuda_available, is_musa_available
from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch
@ -24,6 +26,10 @@ def _get_cuda_home():
return CUDA_HOME
def _get_musa_home():
return os.environ.get('MUSA_HOME')
def collect_env():
"""Collect the information of the running environments.
@ -51,9 +57,10 @@ def collect_env():
env_info['sys.platform'] = sys.platform
env_info['Python'] = sys.version.replace('\n', '')
cuda_available = torch.cuda.is_available()
cuda_available = is_cuda_available()
musa_available = is_musa_available()
env_info['CUDA available'] = cuda_available
env_info['MUSA available'] = musa_available
env_info['numpy_random_seed'] = np.random.get_state()[1][0]
if cuda_available:
@ -89,7 +96,23 @@ def collect_env():
except subprocess.SubprocessError:
nvcc = 'Not Available'
env_info['NVCC'] = nvcc
elif musa_available:
devices = defaultdict(list)
for k in range(torch.musa.device_count()):
devices[torch.musa.get_device_name(k)].append(str(k))
for name, device_ids in devices.items():
env_info['GPU ' + ','.join(device_ids)] = name
MUSA_HOME = _get_musa_home()
env_info['MUSA_HOME'] = MUSA_HOME
if MUSA_HOME is not None and osp.isdir(MUSA_HOME):
try:
mcc = osp.join(MUSA_HOME, 'bin/mcc')
subprocess.check_output(f'"{mcc}" -v', shell=True)
except subprocess.SubprocessError:
mcc = 'Not Available'
env_info['mcc'] = mcc
try:
# Check C++ Compiler.
# For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...',
@ -138,7 +161,7 @@ def collect_env():
try:
import cv2
env_info['OpenCV'] = cv2.__version__
except ModuleNotFoundError:
except ImportError:
pass
env_info['MMEngine'] = mmengine.__version__
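
With these additions `collect_env()` reports MUSA devices and the MUSA_HOME toolchain next to the existing CUDA fields. A quick way to inspect the collected info (the keys printed are a subset of those set above):

from mmengine.utils.dl_utils import collect_env

env = collect_env()
for key in ('sys.platform', 'Python', 'CUDA available', 'MUSA available', 'MMEngine'):
    print(f'{key}: {env.get(key)}')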


@ -57,6 +57,7 @@ if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
check_hash=False,
file_name=None):
r"""Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically decompressed
If the object is already present in `model_dir`, it's deserialized and
returned.


@ -4,6 +4,7 @@ from typing import Optional, Union
import torch
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log
@ -66,7 +67,7 @@ class TimeCounter:
instance.log_interval = log_interval
instance.warmup_interval = warmup_interval
instance.with_sync = with_sync
instance.with_sync = with_sync # type: ignore
instance.tag = tag
instance.logger = logger
@ -84,15 +85,20 @@ class TimeCounter:
def wrapper(*args, **kwargs):
self.__count += 1
if self.with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
if self.with_sync:
if is_cuda_available():
torch.cuda.synchronize()
elif is_musa_available():
torch.musa.synchronize()
start_time = time.perf_counter()
result = fn(*args, **kwargs)
if self.with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
if self.with_sync:
if is_cuda_available():
torch.cuda.synchronize()
elif is_musa_available():
torch.musa.synchronize()
elapsed = time.perf_counter() - start_time
self.print_time(elapsed)
@ -121,7 +127,7 @@ class TimeCounter:
self.print_time(elapsed)
def print_time(self, elapsed: Union[int, float]) -> None:
"""print times per count."""
"""Print times per count."""
if self.__count >= self.warmup_interval:
self.__pure_inf_time += elapsed
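
With the device check folded into `with_sync`, synchronization now happens on whichever backend is active before and after the timed call. A hedged sketch of the decorator usage implied by the parameters above (function name and sizes are illustrative):

import torch
from mmengine.utils.dl_utils import TimeCounter

@TimeCounter(log_interval=10, warmup_interval=2, with_sync=True, tag='matmul')
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x @ y

x = torch.randn(256, 256)
for _ in range(20):
    matmul(x, x)  # average time is printed every `log_interval` calls after warmup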


@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type):
def is_seq_of(seq: Any,
expected_type: Union[Type, tuple],
seq_type: Type = None) -> bool:
seq_type: Optional[Type] = None) -> bool:
"""Check whether it is a sequence of some type.
Args:


@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str:
else:
raise e
possible_path = osp.join(pkg.location, package)
possible_path = osp.join(pkg.location, package) # type: ignore
if osp.exists(possible_path):
return possible_path
else:
return osp.join(pkg.location, package2module(package))
return osp.join(pkg.location, package2module(package)) # type: ignore
def package2module(package: str):


@ -3,7 +3,7 @@ import sys
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size
from typing import Callable, Sequence
from typing import Callable, Optional, Sequence
from .timer import Timer
@ -54,7 +54,7 @@ class ProgressBar:
self.timer = Timer()
def update(self, num_tasks: int = 1):
"""update progressbar.
"""Update progressbar.
Args:
num_tasks (int): Update step size.
@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None):
def track_parallel_progress(func: Callable,
tasks: Sequence,
nproc: int,
initializer: Callable = None,
initargs: tuple = None,
initializer: Optional[Callable] = None,
initargs: Optional[tuple] = None,
bar_width: int = 50,
chunksize: int = 1,
skip_first: bool = False,


@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from multiprocessing import Pool
from typing import Callable, Iterable, Sized
from typing import Callable, Iterable, Optional, Sized
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
TaskProgressColumn, TextColumn, TimeRemainingColumn)
@ -47,7 +47,7 @@ def _tasks_with_index(tasks):
def track_progress_rich(func: Callable,
tasks: Iterable = tuple(),
task_num: int = None,
task_num: Optional[int] = None,
nproc: int = 1,
chunksize: int = 1,
description: str = 'Processing',


@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '0.10.1'
__version__ = '0.10.7'
def parse_version_info(version_str):


@ -161,7 +161,7 @@ class BaseVisBackend(metaclass=ABCMeta):
pass
def close(self) -> None:
"""close an opened object."""
"""Close an opened object."""
pass
@ -314,7 +314,7 @@ class LocalVisBackend(BaseVisBackend):
def _dump(self, value_dict: dict, file_path: str,
file_format: str) -> None:
"""dump dict to file.
"""Dump dict to file.
Args:
value_dict (dict) : The dict data to saved.
@ -505,7 +505,7 @@ class WandbVisBackend(BaseVisBackend):
self._wandb.log(scalar_dict, commit=self._commit)
def close(self) -> None:
"""close an opened wandb object."""
"""Close an opened wandb object."""
if hasattr(self, '_wandb'):
self._wandb.join()
@ -629,7 +629,7 @@ class TensorboardVisBackend(BaseVisBackend):
self.add_scalar(key, value, step)
def close(self):
"""close an opened tensorboard object."""
"""Close an opened tensorboard object."""
if hasattr(self, '_tensorboard'):
self._tensorboard.close()
@ -669,6 +669,10 @@ class MLflowVisBackend(BaseVisBackend):
will be added to the experiment. If it is None, which means all
the config will be added. Defaults to None.
`New in version 0.7.4.`
artifact_location (str, optional): The location to store run artifacts.
If None, the server picks an appropriate default.
Defaults to None.
`New in version 0.10.4.`
"""
def __init__(self,
@ -680,7 +684,8 @@ class MLflowVisBackend(BaseVisBackend):
tracking_uri: Optional[str] = None,
artifact_suffix: SUFFIX_TYPE = ('.json', '.log', '.py',
'yaml'),
tracked_config_keys: Optional[dict] = None):
tracked_config_keys: Optional[dict] = None,
artifact_location: Optional[str] = None):
super().__init__(save_dir)
self._exp_name = exp_name
self._run_name = run_name
@ -689,6 +694,7 @@ class MLflowVisBackend(BaseVisBackend):
self._tracking_uri = tracking_uri
self._artifact_suffix = artifact_suffix
self._tracked_config_keys = tracked_config_keys
self._artifact_location = artifact_location
def _init_env(self):
"""Setup env for MLflow."""
@ -726,7 +732,8 @@ class MLflowVisBackend(BaseVisBackend):
self._exp_name = self._exp_name or 'Default'
if self._mlflow.get_experiment_by_name(self._exp_name) is None:
self._mlflow.create_experiment(self._exp_name)
self._mlflow.create_experiment(
self._exp_name, artifact_location=self._artifact_location)
self._mlflow.set_experiment(self._exp_name)
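
The new `artifact_location` is only forwarded to `mlflow.create_experiment` when the experiment does not exist yet; an existing experiment keeps its original location. A configuration sketch (the bucket URI is a placeholder):

from mmengine.visualization import MLflowVisBackend

vis_backend = MLflowVisBackend(
    save_dir='work_dirs/mlflow',
    exp_name='my_experiment',
    artifact_location='s3://my-bucket/mlflow-artifacts')

# or declared in a runner config:
# visualizer = dict(
#     type='Visualizer',
#     vis_backends=[dict(type='MLflowVisBackend',
#                        artifact_location='s3://my-bucket/mlflow-artifacts')])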
@ -1128,7 +1135,7 @@ class NeptuneVisBackend(BaseVisBackend):
self._neptune[k].append(v, step=step)
def close(self) -> None:
"""close an opened object."""
"""Close an opened object."""
if hasattr(self, '_neptune'):
self._neptune.stop()
@ -1275,7 +1282,7 @@ class DVCLiveVisBackend(BaseVisBackend):
self.add_scalar(key, value, step, **kwargs)
def close(self) -> None:
"""close an opened dvclive object."""
"""Close an opened dvclive object."""
if not hasattr(self, '_dvclive'):
return


@ -356,7 +356,7 @@ class Visualizer(ManagerMixin):
@master_only
def get_backend(self, name) -> 'BaseVisBackend':
"""get vis backend by name.
"""Get vis backend by name.
Args:
name (str): The name of vis backend
@ -879,7 +879,7 @@ class Visualizer(ManagerMixin):
if binary_masks.ndim == 2:
binary_masks = binary_masks[None]
assert img.shape[:2] == binary_masks.shape[
1:], '`binary_marks` must have ' \
1:], '`binary_masks` must have ' \
'the same shape with image'
binary_mask_len = binary_masks.shape[0]
@ -961,7 +961,7 @@ class Visualizer(ManagerMixin):
if topk <= 0, tensor_chw is assert to be one or three.
Defaults to 20.
arrangement (Tuple[int, int]): The arrangement of featmap when
channel_reduction is not None and topk > 0. Defaults to (4, 5).
channel_reduction is None and topk > 0. Defaults to (4, 5).
resize_shape (tuple, optional): The shape to scale the feature map.
Defaults to None.
alpha (Union[int, List[int]]): The transparency of featmap.
@ -989,7 +989,7 @@ class Visualizer(ManagerMixin):
f'overlaid_image: {overlaid_image.shape[:2]} and '
f'featmap: {featmap.shape[1:]} are not same, '
f'the feature map will be interpolated. '
f'This may cause mismatch problems ')
f'This may cause mismatch problems !')
if resize_shape is None:
featmap = F.interpolate(
featmap[None],
@ -1145,7 +1145,7 @@ class Visualizer(ManagerMixin):
pass
def close(self) -> None:
"""close an opened object."""
"""Close an opened object."""
for vis_backend in self._vis_backends.values():
vis_backend.close()


@ -1,8 +1,8 @@
docutils==0.17.1
docutils==0.18.1
myst-parser
opencv-python
-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==4.5.0
sphinx==6.2.1
sphinx-copybutton
sphinx-tabs
sphinx_markdown_tables


@ -843,8 +843,8 @@ class TestConfig:
assert cfg_dict['item4'] == 'test'
assert '_delete_' not in cfg_dict['item1']
assert type(cfg_dict['item1']) == ConfigDict
assert type(cfg_dict['item2']) == ConfigDict
assert type(cfg_dict['item1']) is ConfigDict
assert type(cfg_dict['item2']) is ConfigDict
def _merge_intermediate_variable(self):


@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
is_mps_available, is_npu_available)
is_mps_available, is_musa_available,
is_npu_available)
def test_get_device():
@ -13,5 +14,7 @@ def test_get_device():
assert device == 'mlu'
elif is_mps_available():
assert device == 'mps'
elif is_musa_available():
assert device == 'musa'
else:
assert device == 'cpu'
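
The test enumerates the fallback order of `get_device()`, with 'musa' now checked before falling back to 'cpu'. In user code the returned string can be passed straight to `.to()`:

import torch
from mmengine.device import get_device

device = get_device()  # 'cuda', 'npu', 'mlu', 'mps', 'musa' or 'cpu'
x = torch.randn(2, 3).to(device)
print(device, x.device)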


@ -11,6 +11,7 @@ import torch
import torch.distributed as torch_dist
import mmengine.dist as dist
from mmengine.device import is_musa_available
from mmengine.dist.dist import sync_random_seed
from mmengine.testing._internal import MultiProcessTestCase
from mmengine.utils import digit_version
@ -117,6 +118,7 @@ class TestDist(TestCase):
self.assertTrue(torch.allclose(item1, item2))
@unittest.skipIf(is_musa_available(), reason='musa do not support gloo yet')
class TestDistWithGLOOBackend(MultiProcessTestCase):
def _init_dist_env(self, rank, world_size):


@ -300,8 +300,8 @@ except ImportError:
get_inputs.append(filepath)
with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'put', side_effect=put),\
patch.object(backend, 'get', side_effect=get),\
patch.object(backend, 'put', side_effect=put), \
patch.object(backend, 'get', side_effect=get), \
patch.object(backend, 'exists', return_value=False):
tmp_dir = tmp_dir.replace('\\', '/')
dst = f'{tmp_dir}/dir'
@ -351,7 +351,7 @@ except ImportError:
with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'copyfile_from_local',
side_effect=copyfile_from_local),\
side_effect=copyfile_from_local), \
patch.object(backend, 'exists', return_value=False):
backend.copytree_from_local(tmp_dir, self.petrel_dir)
@ -427,7 +427,7 @@ except ImportError:
def remove(filepath):
inputs.append(filepath)
with build_temporary_directory() as tmp_dir,\
with build_temporary_directory() as tmp_dir, \
patch.object(backend, 'remove', side_effect=remove):
backend.rmtree(tmp_dir)


@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest
import torch
import torch.nn as nn
from mmengine.config import ConfigDict
from mmengine.device import is_musa_available
from mmengine.hooks import EMAHook
from mmengine.model import BaseModel, ExponentialMovingAverage
from mmengine.registry import MODELS
@ -45,6 +47,9 @@ class ToyModel3(ToyModel):
return super().forward(*args, **kwargs)
# TODO:haowen.han@mtheads.com
@unittest.skipIf(is_musa_available(),
"musa backend do not support 'aten::lerp.Scalar_out'")
class TestEMAHook(RunnerTestCase):
def setUp(self) -> None:


@ -1,11 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
import pytest
from mmengine.device import is_cuda_available
from mmengine.testing import RunnerTestCase
class TestEmptyCacheHook(RunnerTestCase):
@pytest.mark.skipif(
not is_cuda_available(), reason='cuda should be available')
def test_with_runner(self):
with patch('torch.cuda.empty_cache') as mock_empty_cache:
cfg = self.epoch_based_cfg


@ -13,7 +13,8 @@ from mmengine.testing import assert_allclose
class TestAveragedModel(TestCase):
"""Test the AveragedModel class.
Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
Some test cases are referenced from
https://github.com/pytorch/pytorch/blob/master/test/test_optim.py
""" # noqa: E501
def _test_swa_model(self, net_device, avg_device):


@ -549,7 +549,8 @@ class TestBuilder(TestCase):
weight_decay=self.base_wd,
momentum=self.momentum))
paramwise_cfg = dict()
optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(model)
self._check_default_optimizer(optim_wrapper.optimizer, model)
@ -591,23 +592,16 @@ class TestBuilder(TestCase):
dwconv_decay_mult=0.1,
dcn_offset_lr_mult=0.1)
for param in self.model.parameters():
param.requires_grad = False
self.model.conv1.requires_grad_(False)
optim_constructor = DefaultOptimWrapperConstructor(
optim_wrapper_cfg, paramwise_cfg)
optim_wrapper = optim_constructor(self.model)
optimizer = optim_wrapper.optimizer
param_groups = optimizer.param_groups
assert isinstance(optim_wrapper.optimizer, torch.optim.SGD)
assert optimizer.defaults['lr'] == self.base_lr
assert optimizer.defaults['momentum'] == self.momentum
assert optimizer.defaults['weight_decay'] == self.base_wd
for i, (name, param) in enumerate(self.model.named_parameters()):
param_group = param_groups[i]
assert torch.equal(param_group['params'][0], param)
assert param_group['momentum'] == self.momentum
assert param_group['lr'] == self.base_lr
assert param_group['weight_decay'] == self.base_wd
all_params = []
for pg in optim_wrapper.param_groups:
all_params.extend(map(id, pg['params']))
self.assertNotIn(id(self.model.conv1.weight), all_params)
self.assertIn(id(self.model.conv2.weight), all_params)
def test_default_optimizer_constructor_bypass_duplicate(self):
# paramwise_cfg with bypass_duplicate option
@ -663,10 +657,8 @@ class TestBuilder(TestCase):
optim_wrapper = optim_constructor(model)
model_parameters = list(model.parameters())
num_params = 14 if MMCV_FULL_AVAILABLE else 11
assert len(optim_wrapper.optimizer.param_groups) == len(
model_parameters) == num_params
self._check_sgd_optimizer(optim_wrapper.optimizer, model,
**paramwise_cfg)
assert len(optim_wrapper.optimizer.param_groups
) == len(model_parameters) - 1 == num_params - 1
def test_default_optimizer_constructor_custom_key(self):
# test DefaultOptimWrapperConstructor with custom_keys and

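
The reworked assertions check by parameter identity that anything frozen with `requires_grad_(False)` is dropped from the optimizer's param groups. The same id-based membership check can be reproduced with plain torch (a minimal stand-in model, not the ToyModel used by the test):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
model[0].requires_grad_(False)  # freeze the first conv

# build an optimizer over the trainable parameters only
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1)

all_params = []
for group in optimizer.param_groups:
    all_params.extend(map(id, group['params']))
assert id(model[0].weight) not in all_params
assert id(model[1].weight) in all_params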

@ -102,7 +102,7 @@ class TestLRScheduler(TestCase):
rtol=0)
def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer."""
"""Warns if scheduler is used before optimizer."""
def call_sch_before_optim():
scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3)


@ -120,7 +120,7 @@ class TestMomentumScheduler(TestCase):
rtol=0)
def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer."""
"""Warns if scheduler is used before optimizer."""
def call_sch_before_optim():
scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3)


@ -127,7 +127,7 @@ class TestParameterScheduler(TestCase):
rtol=0)
def test_scheduler_before_optim_warning(self):
"""warns if scheduler is used before optimizer."""
"""Warns if scheduler is used before optimizer."""
def call_sch_before_optim():
scheduler = StepParamScheduler(


@ -5,7 +5,8 @@ import torch
import torch.nn as nn
import mmengine
from mmengine.device import get_device, is_mlu_available, is_npu_available
from mmengine.device import (get_device, is_mlu_available, is_musa_available,
is_npu_available)
from mmengine.runner import autocast
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
@ -44,6 +45,21 @@ class TestAmp(unittest.TestCase):
layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32)
elif is_musa_available():
device = 'musa'
with autocast(device_type=device):
# torch.autocast supports musa mode.
layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
with autocast(enabled=False, device_type=device):
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32)
# Test with fp32_enabled
with autocast(enabled=False, device_type=device):
layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32)
elif not torch.cuda.is_available():
if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
# `torch.cuda.amp.autocast` is only support in gpu mode, if

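
The new branch exercises `mmengine.runner.autocast` on a MUSA device, expecting reduced-precision outputs unless the context is disabled. The same pattern on CUDA, as a hedged sketch:

import torch
import torch.nn as nn
from mmengine.runner import autocast

if torch.cuda.is_available():
    layer = nn.Conv2d(1, 1, 1).cuda()
    x = torch.randn(1, 1, 4, 4).cuda()
    with autocast(device_type='cuda'):
        out = layer(x)
    print(out.dtype)  # half precision inside the context
    with autocast(device_type='cuda', enabled=False):
        out = layer(x)
    print(out.dtype)  # torch.float32 when disabled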

@ -251,7 +251,7 @@ def test_load_checkpoint_metadata():
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
*args, **kwargs):
"""load checkpoints."""
"""Load checkpoints."""
# Names of some parameters in has been changed.
version = local_metadata.get('version', None)


@ -7,6 +7,7 @@ import pytest
import torch
from parameterized import parameterized
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
from mmengine.runner import LogProcessor
from mmengine.testing import RunnerTestCase
@ -113,7 +114,7 @@ class TestLogProcessor(RunnerTestCase):
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")
if torch.cuda.is_available():
if is_cuda_available() or is_musa_available():
log_str += 'memory: 100 '
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
@ -141,7 +142,7 @@ class TestLogProcessor(RunnerTestCase):
f"time: {train_logs['time']:.4f} "
f"data_time: {train_logs['data_time']:.4f} ")
if torch.cuda.is_available():
if is_cuda_available() or is_musa_available():
log_str += 'memory: 100 '
if mode == 'train':
@ -249,6 +250,7 @@ class TestLogProcessor(RunnerTestCase):
assert tag['metric1'] is metric1
assert tag['metric2'] is metric2
# TODO:haowen.han@mtheads.com MUSA does not support it yet!
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):


@ -2226,7 +2226,7 @@ class TestRunner(TestCase):
@HOOKS.register_module(force=True)
class TestWarmupHook(Hook):
"""test custom train loop."""
"""Test custom train loop."""
def before_warmup_iter(self, runner, data_batch=None):
before_warmup_iter_results.append('before')


@ -64,7 +64,7 @@ class TestBaseDataElement(TestCase):
return metainfo, data
def is_equal(self, x, y):
assert type(x) == type(y)
assert type(x) is type(y)
if isinstance(
x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
return x == y
@ -141,7 +141,7 @@ class TestBaseDataElement(TestCase):
# test new() with no arguments
new_instances = instances.new()
assert type(new_instances) == type(instances)
assert type(new_instances) is type(instances)
# After deepcopy, the address of new data'element will be same as
# origin, but when change new data' element will not effect the origin
# element and will have new address
@ -154,7 +154,7 @@ class TestBaseDataElement(TestCase):
# test new() with arguments
metainfo, data = self.setup_data()
new_instances = instances.new(metainfo=metainfo, **data)
assert type(new_instances) == type(instances)
assert type(new_instances) is type(instances)
assert id(new_instances.gt_instances) != id(instances.gt_instances)
_, new_data = self.setup_data()
new_instances.set_data(new_data)
@ -168,7 +168,7 @@ class TestBaseDataElement(TestCase):
metainfo, data = self.setup_data()
instances = BaseDataElement(metainfo=metainfo, **data)
new_instances = instances.clone()
assert type(new_instances) == type(instances)
assert type(new_instances) is type(instances)
def test_set_metainfo(self):
metainfo, _ = self.setup_data()


@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import warnings
import torch
from mmengine.utils.dl_utils import torch_meshgrid
@ -7,9 +8,8 @@ from mmengine.utils.dl_utils import torch_meshgrid
def test_torch_meshgrid():
# torch_meshgrid should not throw warning
with pytest.warns(None) as record:
with warnings.catch_warnings():
warnings.simplefilter('error')
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
grid_x, grid_y = torch_meshgrid(x, y)
assert len(record) == 0
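
`pytest.warns(None)` was removed in recent pytest releases, so the test now uses the standard-library idiom of promoting warnings to errors: any warning raised inside the block fails the test immediately. The same idiom in isolation (the helper name is illustrative):

import warnings

def assert_no_warnings(fn, *args, **kwargs):
    # simplefilter('error') turns every warning into an exception,
    # so a clean return proves fn emitted no warnings at all
    with warnings.catch_warnings():
        warnings.simplefilter('error')
        return fn(*args, **kwargs)

assert_no_warnings(sorted, [3, 1, 2])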


@ -282,6 +282,14 @@ class TestMLflowVisBackend:
mlflow_vis_backend = MLflowVisBackend('temp_dir')
assert mlflow_vis_backend.experiment == mlflow_vis_backend._mlflow
def test_create_experiment(self):
with patch('mlflow.create_experiment') as mock_create_experiment:
MLflowVisBackend(
'temp_dir', exp_name='test',
artifact_location='foo')._init_env()
mock_create_experiment.assert_any_call(
'test', artifact_location='foo')
def test_add_config(self):
cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
mlflow_vis_backend = MLflowVisBackend('temp_dir')

Some files were not shown because too many files have changed in this diff.