Merge branch 'dev'

pull/1249/head v0.25.0
mzr1996 2022-12-06 18:25:47 +08:00
commit 2495400a98
17 changed files with 157 additions and 81 deletions

View File

@ -29,9 +29,9 @@ repos:
rev: 0.7.9 rev: 0.7.9
hooks: hooks:
- id: mdformat - id: mdformat
args: ["--number", "--table-width", "200"] args: ["--number", "--table-width", "200", '--disable-escape', 'backslash', '--disable-escape', 'link-enclosure']
additional_dependencies: additional_dependencies:
- mdformat-openmmlab - "mdformat-openmmlab>=0.0.4"
- mdformat_frontmatter - mdformat_frontmatter
- linkify-it-py - linkify-it-py
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell

View File

@ -64,6 +64,12 @@ The MMClassification 1.0 has released! It's still unstable and in release candid
to [the 1.x branch](https://github.com/open-mmlab/mmclassification/tree/1.x) and discuss it with us in to [the 1.x branch](https://github.com/open-mmlab/mmclassification/tree/1.x) and discuss it with us in
[the discussion](https://github.com/open-mmlab/mmclassification/discussions). [the discussion](https://github.com/open-mmlab/mmclassification/discussions).
v0.25.0 was released in 06/12/2022.
Highlights of the new version:
- Support MLU backend.
- Add `dist_train_arm.sh` for ARM device.
v0.24.1 was released in 31/10/2022. v0.24.1 was released in 31/10/2022.
Highlights of the new version: Highlights of the new version:
@ -75,13 +81,6 @@ Highlights of the new version:
- Support **HorNet**, **EfficientFormerm**, **SwinTransformer V2** and **MViT** backbones. - Support **HorNet**, **EfficientFormerm**, **SwinTransformer V2** and **MViT** backbones.
- Support Standford Cars dataset. - Support Standford Cars dataset.
v0.23.0 was released in 1/5/2022.
Highlights of the new version:
- Support **DenseNet**, **VAN** and **PoolFormer**, and provide pre-trained models.
- Support training on IPU.
- New style API docs, welcome [view it](https://mmclassification.readthedocs.io/en/master/api/models.html).
Please refer to [changelog.md](docs/en/changelog.md) for more details and other release history. Please refer to [changelog.md](docs/en/changelog.md) for more details and other release history.
## Installation ## Installation

View File

@ -63,6 +63,11 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
MMClassification 1.0 已经发布!目前仍在公测中,如果希望试用,请切换到 [1.x 分支](https://github.com/open-mmlab/mmclassification/tree/1.x),并在[讨论版](https://github.com/open-mmlab/mmclassification/discussions) 参加开发讨论! MMClassification 1.0 已经发布!目前仍在公测中,如果希望试用,请切换到 [1.x 分支](https://github.com/open-mmlab/mmclassification/tree/1.x),并在[讨论版](https://github.com/open-mmlab/mmclassification/discussions) 参加开发讨论!
2022/12/06 发布了 v0.25.0 版本
- 支持 MLU 设备
- 添加了用于 ARM 设备训练的 `dist_train_arm.sh`
2022/10/31 发布了 v0.24.1 版本 2022/10/31 发布了 v0.24.1 版本
- 支持了华为昇腾 NPU 设备。 - 支持了华为昇腾 NPU 设备。

View File

@ -6,7 +6,7 @@
## Abstract ## Abstract
Transformers, which are popular for language modeling, have been explored for solving vision tasks recently, \\eg, the Vision Transformer (ViT) for image classification. The ViT model splits each image into a sequence of tokens with fixed length and then applies multiple Transformer layers to model their global relation for classification. However, ViT achieves inferior performance to CNNs when trained from scratch on a midsize dataset like ImageNet. We find it is because: 1) the simple tokenization of input images fails to model the important local structure such as edges and lines among neighboring pixels, leading to low training sample efficiency; 2) the redundant attention backbone design of ViT leads to limited feature richness for fixed computation budgets and limited training samples. To overcome such limitations, we propose a new Tokens-To-Token Vision Transformer (T2T-ViT), which incorporates 1) a layer-wise Tokens-to-Token (T2T) transformation to progressively structurize the image to tokens by recursively aggregating neighboring Tokens into one Token (Tokens-to-Token), such that local structure represented by surrounding tokens can be modeled and tokens length can be reduced; 2) an efficient backbone with a deep-narrow structure for vision transformer motivated by CNN architecture design after empirical study. Notably, T2T-ViT reduces the parameter count and MACs of vanilla ViT by half, while achieving more than 3.0% improvement when trained from scratch on ImageNet. It also outperforms ResNets and achieves comparable performance with MobileNets by directly training on ImageNet. For example, T2T-ViT with comparable size to ResNet50 (21.5M parameters) can achieve 83.3% top1 accuracy in image resolution 384×384 on ImageNet. Transformers, which are popular for language modeling, have been explored for solving vision tasks recently, e.g., the Vision Transformer (ViT) for image classification. The ViT model splits each image into a sequence of tokens with fixed length and then applies multiple Transformer layers to model their global relation for classification. However, ViT achieves inferior performance to CNNs when trained from scratch on a midsize dataset like ImageNet. We find it is because: 1) the simple tokenization of input images fails to model the important local structure such as edges and lines among neighboring pixels, leading to low training sample efficiency; 2) the redundant attention backbone design of ViT leads to limited feature richness for fixed computation budgets and limited training samples. To overcome such limitations, we propose a new Tokens-To-Token Vision Transformer (T2T-ViT), which incorporates 1) a layer-wise Tokens-to-Token (T2T) transformation to progressively structurize the image to tokens by recursively aggregating neighboring Tokens into one Token (Tokens-to-Token), such that local structure represented by surrounding tokens can be modeled and tokens length can be reduced; 2) an efficient backbone with a deep-narrow structure for vision transformer motivated by CNN architecture design after empirical study. Notably, T2T-ViT reduces the parameter count and MACs of vanilla ViT by half, while achieving more than 3.0% improvement when trained from scratch on ImageNet. It also outperforms ResNets and achieves comparable performance with MobileNets by directly training on ImageNet. For example, T2T-ViT with comparable size to ResNet50 (21.5M parameters) can achieve 83.3% top1 accuracy in image resolution 384×384 on ImageNet.
<div align=center> <div align=center>
<img src="https://user-images.githubusercontent.com/26739999/142578381-e9040610-05d9-457c-8bf5-01c2fa94add2.png" width="60%"/> <img src="https://user-images.githubusercontent.com/26739999/142578381-e9040610-05d9-457c-8bf5-01c2fa94add2.png" width="60%"/>

View File

@ -4,7 +4,7 @@ ARG CUDNN="7"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ARG MMCV="1.7.0" ARG MMCV="1.7.0"
ARG MMCLS="0.24.1" ARG MMCLS="0.25.0"
ENV PYTHONUNBUFFERED TRUE ENV PYTHONUNBUFFERED TRUE

View File

@ -1,5 +1,33 @@
# Changelog # Changelog
## v0.25.0(06/12/2022)
### Highlights
- Support MLU backend.
### New Features
- Support MLU backend. ([#1159](https://github.com/open-mmlab/mmclassification/pull/1159))
- Support Activation Checkpointing for ConvNeXt. ([#1152](https://github.com/open-mmlab/mmclassification/pull/1152))
### Improvements
- Add `dist_train_arm.sh` for ARM device and update NPU results. ([#1218](https://github.com/open-mmlab/mmclassification/pull/1218))
### Bug Fixes
- Fix a bug caused `MMClsWandbHook` stuck. ([#1242](https://github.com/open-mmlab/mmclassification/pull/1242))
- Fix the redundant `device_ids` in `tools/test.py`. ([#1215](https://github.com/open-mmlab/mmclassification/pull/1215))
### Docs Update
- Add version banner and version warning in master docs. ([#1216](https://github.com/open-mmlab/mmclassification/pull/1216))
- Update NPU support doc. ([#1198](https://github.com/open-mmlab/mmclassification/pull/1198))
- Fixed typo in `pytorch2torchscript.md`. ([#1173](https://github.com/open-mmlab/mmclassification/pull/1173))
- Fix typo in `miscellaneous.md`. ([#1137](https://github.com/open-mmlab/mmclassification/pull/1137))
- further detail for the doc for `ClassBalancedDataset`. ([#901](https://github.com/open-mmlab/mmclassification/pull/901))
## v0.24.1(31/10/2022) ## v0.24.1(31/10/2022)
### New Features ### New Features
@ -28,14 +56,14 @@
### Improvements ### Improvements
- \[Improve\] replace loop of progressbar in api/test. ([#878](https://github.com/open-mmlab/mmclassification/pull/878)) - [Improve] replace loop of progressbar in api/test. ([#878](https://github.com/open-mmlab/mmclassification/pull/878))
- \[Enhance\] RepVGG for YOLOX-PAI. ([#1025](https://github.com/open-mmlab/mmclassification/pull/1025)) - [Enhance] RepVGG for YOLOX-PAI. ([#1025](https://github.com/open-mmlab/mmclassification/pull/1025))
- \[Enhancement\] Update VAN. ([#1017](https://github.com/open-mmlab/mmclassification/pull/1017)) - [Enhancement] Update VAN. ([#1017](https://github.com/open-mmlab/mmclassification/pull/1017))
- \[Refactor\] Re-write `get_sinusoid_encoding` from third-party implementation. ([#965](https://github.com/open-mmlab/mmclassification/pull/965)) - [Refactor] Re-write `get_sinusoid_encoding` from third-party implementation. ([#965](https://github.com/open-mmlab/mmclassification/pull/965))
- \[Improve\] Upgrade onnxsim to v0.4.0. ([#915](https://github.com/open-mmlab/mmclassification/pull/915)) - [Improve] Upgrade onnxsim to v0.4.0. ([#915](https://github.com/open-mmlab/mmclassification/pull/915))
- \[Improve\] Fixed typo in `RepVGG`. ([#985](https://github.com/open-mmlab/mmclassification/pull/985)) - [Improve] Fixed typo in `RepVGG`. ([#985](https://github.com/open-mmlab/mmclassification/pull/985))
- \[Improve\] Using `train_step` instead of `forward` in PreciseBNHook ([#964](https://github.com/open-mmlab/mmclassification/pull/964)) - [Improve] Using `train_step` instead of `forward` in PreciseBNHook ([#964](https://github.com/open-mmlab/mmclassification/pull/964))
- \[Improve\] Use `forward_dummy` to calculate FLOPS. ([#953](https://github.com/open-mmlab/mmclassification/pull/953)) - [Improve] Use `forward_dummy` to calculate FLOPS. ([#953](https://github.com/open-mmlab/mmclassification/pull/953))
### Bug Fixes ### Bug Fixes
@ -102,13 +130,13 @@
### New Features ### New Features
- \[Feature\] Support resize relative position embedding in `SwinTransformer`. ([#749](https://github.com/open-mmlab/mmclassification/pull/749)) - [Feature] Support resize relative position embedding in `SwinTransformer`. ([#749](https://github.com/open-mmlab/mmclassification/pull/749))
- \[Feature\] Add PoolFormer backbone and checkpoints. ([#746](https://github.com/open-mmlab/mmclassification/pull/746)) - [Feature] Add PoolFormer backbone and checkpoints. ([#746](https://github.com/open-mmlab/mmclassification/pull/746))
### Improvements ### Improvements
- \[Enhance\] Improve CPE performance by reduce memory copy. ([#762](https://github.com/open-mmlab/mmclassification/pull/762)) - [Enhance] Improve CPE performance by reduce memory copy. ([#762](https://github.com/open-mmlab/mmclassification/pull/762))
- \[Enhance\] Add extra dataloader settings in configs. ([#752](https://github.com/open-mmlab/mmclassification/pull/752)) - [Enhance] Add extra dataloader settings in configs. ([#752](https://github.com/open-mmlab/mmclassification/pull/752))
## v0.22.0(30/3/2022) ## v0.22.0(30/3/2022)
@ -120,29 +148,29 @@
### New Features ### New Features
- \[Feature\] Add CSPNet and backbone and checkpoints ([#735](https://github.com/open-mmlab/mmclassification/pull/735)) - [Feature] Add CSPNet and backbone and checkpoints ([#735](https://github.com/open-mmlab/mmclassification/pull/735))
- \[Feature\] Add `CustomDataset`. ([#738](https://github.com/open-mmlab/mmclassification/pull/738)) - [Feature] Add `CustomDataset`. ([#738](https://github.com/open-mmlab/mmclassification/pull/738))
- \[Feature\] Add diff seeds to diff ranks. ([#744](https://github.com/open-mmlab/mmclassification/pull/744)) - [Feature] Add diff seeds to diff ranks. ([#744](https://github.com/open-mmlab/mmclassification/pull/744))
- \[Feature\] Support ConvMixer. ([#716](https://github.com/open-mmlab/mmclassification/pull/716)) - [Feature] Support ConvMixer. ([#716](https://github.com/open-mmlab/mmclassification/pull/716))
- \[Feature\] Our `dist_train` & `dist_test` tools support distributed training on multiple machines. ([#734](https://github.com/open-mmlab/mmclassification/pull/734)) - [Feature] Our `dist_train` & `dist_test` tools support distributed training on multiple machines. ([#734](https://github.com/open-mmlab/mmclassification/pull/734))
- \[Feature\] Add RepMLP backbone and checkpoints. ([#709](https://github.com/open-mmlab/mmclassification/pull/709)) - [Feature] Add RepMLP backbone and checkpoints. ([#709](https://github.com/open-mmlab/mmclassification/pull/709))
- \[Feature\] Support CUB dataset. ([#703](https://github.com/open-mmlab/mmclassification/pull/703)) - [Feature] Support CUB dataset. ([#703](https://github.com/open-mmlab/mmclassification/pull/703))
- \[Feature\] Support ResizeMix. ([#676](https://github.com/open-mmlab/mmclassification/pull/676)) - [Feature] Support ResizeMix. ([#676](https://github.com/open-mmlab/mmclassification/pull/676))
### Improvements ### Improvements
- \[Enhance\] Use `--a-b` instead of `--a_b` in arguments. ([#754](https://github.com/open-mmlab/mmclassification/pull/754)) - [Enhance] Use `--a-b` instead of `--a_b` in arguments. ([#754](https://github.com/open-mmlab/mmclassification/pull/754))
- \[Enhance\] Add `get_cat_ids` and `get_gt_labels` to KFoldDataset. ([#721](https://github.com/open-mmlab/mmclassification/pull/721)) - [Enhance] Add `get_cat_ids` and `get_gt_labels` to KFoldDataset. ([#721](https://github.com/open-mmlab/mmclassification/pull/721))
- \[Enhance\] Set torch seed in `worker_init_fn`. ([#733](https://github.com/open-mmlab/mmclassification/pull/733)) - [Enhance] Set torch seed in `worker_init_fn`. ([#733](https://github.com/open-mmlab/mmclassification/pull/733))
### Bug Fixes ### Bug Fixes
- \[Fix\] Fix the discontiguous output feature map of ConvNeXt. ([#743](https://github.com/open-mmlab/mmclassification/pull/743)) - [Fix] Fix the discontiguous output feature map of ConvNeXt. ([#743](https://github.com/open-mmlab/mmclassification/pull/743))
### Docs Update ### Docs Update
- \[Docs\] Add brief installation steps in README for copy&paste. ([#755](https://github.com/open-mmlab/mmclassification/pull/755)) - [Docs] Add brief installation steps in README for copy&paste. ([#755](https://github.com/open-mmlab/mmclassification/pull/755))
- \[Docs\] fix logo url link from mmocr to mmcls. ([#732](https://github.com/open-mmlab/mmclassification/pull/732)) - [Docs] fix logo url link from mmocr to mmcls. ([#732](https://github.com/open-mmlab/mmclassification/pull/732))
## v0.21.0(04/03/2022) ## v0.21.0(04/03/2022)
@ -245,18 +273,18 @@
### Improvements ### Improvements
- \[Reproduction\] Reproduce RegNetX training accuracy. ([#587](https://github.com/open-mmlab/mmclassification/pull/587)) - [Reproduction] Reproduce RegNetX training accuracy. ([#587](https://github.com/open-mmlab/mmclassification/pull/587))
- \[Reproduction\] Reproduce training results of T2T-ViT. ([#610](https://github.com/open-mmlab/mmclassification/pull/610)) - [Reproduction] Reproduce training results of T2T-ViT. ([#610](https://github.com/open-mmlab/mmclassification/pull/610))
- \[Enhance\] Provide high-acc training settings of ResNet. ([#572](https://github.com/open-mmlab/mmclassification/pull/572)) - [Enhance] Provide high-acc training settings of ResNet. ([#572](https://github.com/open-mmlab/mmclassification/pull/572))
- \[Enhance\] Set a random seed when the user does not set a seed. ([#554](https://github.com/open-mmlab/mmclassification/pull/554)) - [Enhance] Set a random seed when the user does not set a seed. ([#554](https://github.com/open-mmlab/mmclassification/pull/554))
- \[Enhance\] Added `NumClassCheckHook` and unit tests. ([#559](https://github.com/open-mmlab/mmclassification/pull/559)) - [Enhance] Added `NumClassCheckHook` and unit tests. ([#559](https://github.com/open-mmlab/mmclassification/pull/559))
- \[Enhance\] Enhance feature extraction function. ([#593](https://github.com/open-mmlab/mmclassification/pull/593)) - [Enhance] Enhance feature extraction function. ([#593](https://github.com/open-mmlab/mmclassification/pull/593))
- \[Enhance\] Improve efficiency of precision, recall, f1_score and support. ([#595](https://github.com/open-mmlab/mmclassification/pull/595)) - [Enhance] Improve efficiency of precision, recall, f1_score and support. ([#595](https://github.com/open-mmlab/mmclassification/pull/595))
- \[Enhance\] Improve accuracy calculation performance. ([#592](https://github.com/open-mmlab/mmclassification/pull/592)) - [Enhance] Improve accuracy calculation performance. ([#592](https://github.com/open-mmlab/mmclassification/pull/592))
- \[Refactor\] Refactor `analysis_log.py`. ([#529](https://github.com/open-mmlab/mmclassification/pull/529)) - [Refactor] Refactor `analysis_log.py`. ([#529](https://github.com/open-mmlab/mmclassification/pull/529))
- \[Refactor\] Use new API of matplotlib to handle blocking input in visualization. ([#568](https://github.com/open-mmlab/mmclassification/pull/568)) - [Refactor] Use new API of matplotlib to handle blocking input in visualization. ([#568](https://github.com/open-mmlab/mmclassification/pull/568))
- \[CI\] Cancel previous runs that are not completed. ([#583](https://github.com/open-mmlab/mmclassification/pull/583)) - [CI] Cancel previous runs that are not completed. ([#583](https://github.com/open-mmlab/mmclassification/pull/583))
- \[CI\] Skip build CI if only configs or docs modification. ([#575](https://github.com/open-mmlab/mmclassification/pull/575)) - [CI] Skip build CI if only configs or docs modification. ([#575](https://github.com/open-mmlab/mmclassification/pull/575))
### Bug Fixes ### Bug Fixes

View File

@ -18,7 +18,8 @@ and make sure you fill in all required information in the template.
| MMClassification version | MMCV version | | MMClassification version | MMCV version |
| :----------------------: | :--------------------: | | :----------------------: | :--------------------: |
| dev | mmcv>=1.7.0, \<1.9.0 | | dev | mmcv>=1.7.0, \<1.9.0 |
| 0.24.1 (master) | mmcv>=1.4.2, \<1.9.0 | | 0.25.0 (master) | mmcv>=1.4.2, \<1.9.0 |
| 0.24.1 | mmcv>=1.4.2, \<1.9.0 |
| 0.23.2 | mmcv>=1.4.2, \<1.7.0 | | 0.23.2 | mmcv>=1.4.2, \<1.7.0 |
| 0.22.1 | mmcv>=1.4.2, \<1.6.0 | | 0.22.1 | mmcv>=1.4.2, \<1.6.0 |
| 0.21.0 | mmcv>=1.4.2, \<=1.5.0 | | 0.21.0 | mmcv>=1.4.2, \<=1.5.0 |

View File

@ -16,7 +16,8 @@
| MMClassification version | MMCV version | | MMClassification version | MMCV version |
| :----------------------: | :--------------------: | | :----------------------: | :--------------------: |
| dev | mmcv>=1.7.0, \<1.9.0 | | dev | mmcv>=1.7.0, \<1.9.0 |
| 0.24.1 (master) | mmcv>=1.4.2, \<1.9.0 | | 0.25.0 (master) | mmcv>=1.4.2, \<1.9.0 |
| 0.24.1 | mmcv>=1.4.2, \<1.9.0 |
| 0.23.2 | mmcv>=1.4.2, \<1.7.0 | | 0.23.2 | mmcv>=1.4.2, \<1.7.0 |
| 0.22.1 | mmcv>=1.4.2, \<1.6.0 | | 0.22.1 | mmcv>=1.4.2, \<1.6.0 |
| 0.21.0 | mmcv>=1.4.2, \<=1.5.0 | | 0.21.0 | mmcv>=1.4.2, \<=1.5.0 |

View File

@ -10,11 +10,11 @@ from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
from mmcls.datasets import build_dataloader, build_dataset from mmcls.datasets import build_dataloader, build_dataset
from mmcls.utils import (get_root_logger, wrap_distributed_model, from mmcls.utils import (auto_select_device, get_root_logger,
wrap_non_distributed_model) wrap_distributed_model, wrap_non_distributed_model)
def init_random_seed(seed=None, device='cuda'): def init_random_seed(seed=None, device=None):
"""Initialize random seed. """Initialize random seed.
If the seed is not set, the seed will be automatically randomized, If the seed is not set, the seed will be automatically randomized,
@ -30,7 +30,8 @@ def init_random_seed(seed=None, device='cuda'):
""" """
if seed is not None: if seed is not None:
return seed return seed
if device is None:
device = auto_select_device()
# Make sure all ranks share the same random seed to prevent # Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to # some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339 # https://github.com/open-mmlab/mmdetection/issues/6339

View File

@ -3,7 +3,7 @@ import os.path as osp
import numpy as np import numpy as np
from mmcv.runner import HOOKS, BaseRunner from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import master_only from mmcv.runner.dist_utils import get_dist_info, master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
@ -190,7 +190,6 @@ class MMClsWandbHook(WandbLoggerHook):
# Log the evaluation table # Log the evaluation table
self._log_eval_table(runner.epoch + 1) self._log_eval_table(runner.epoch + 1)
@master_only
def after_train_iter(self, runner): def after_train_iter(self, runner):
if self.get_mode(runner) == 'train': if self.get_mode(runner) == 'train':
# An ugly patch. The iter-based eval hook will call the # An ugly patch. The iter-based eval hook will call the
@ -201,6 +200,10 @@ class MMClsWandbHook(WandbLoggerHook):
else: else:
super(MMClsWandbHook, self).after_train_iter(runner) super(MMClsWandbHook, self).after_train_iter(runner)
rank, _ = get_dist_info()
if rank != 0:
return
if self.by_epoch: if self.by_epoch:
return return

View File

@ -8,6 +8,8 @@ from mmcv.runner import OptimizerHook, get_dist_info
from torch._utils import (_flatten_dense_tensors, _take_tensors, from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors) _unflatten_dense_tensors)
from mmcls.utils import auto_select_device
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
if bucket_size_mb > 0: if bucket_size_mb > 0:
@ -59,7 +61,7 @@ class DistOptimizerHook(OptimizerHook):
runner.optimizer.step() runner.optimizer.step()
def sync_random_seed(seed=None, device='cuda'): def sync_random_seed(seed=None, device=None):
"""Make sure different ranks share the same seed. """Make sure different ranks share the same seed.
All workers must call this function, otherwise it will deadlock. All workers must call this function, otherwise it will deadlock.
@ -81,6 +83,8 @@ def sync_random_seed(seed=None, device='cuda'):
Returns: Returns:
int: Seed to be used. int: Seed to be used.
""" """
if device is None:
device = auto_select_device()
if seed is None: if seed is None:
seed = np.random.randint(2**31) seed = np.random.randint(2**31)
assert isinstance(seed, int) assert isinstance(seed, int)

View File

@ -4,7 +4,6 @@ from torch.utils.data import DistributedSampler as _DistributedSampler
from mmcls.core.utils import sync_random_seed from mmcls.core.utils import sync_random_seed
from mmcls.datasets import SAMPLERS from mmcls.datasets import SAMPLERS
from mmcls.utils import auto_select_device
@SAMPLERS.register_module() @SAMPLERS.register_module()
@ -31,7 +30,7 @@ class DistributedSampler(_DistributedSampler):
# in the same order based on the same seed. Then different ranks # in the same order based on the same seed. Then different ranks
# could use different indices to select non-overlapped data from the # could use different indices to select non-overlapped data from the
# same data list. # same data list.
self.seed = sync_random_seed(seed, device=auto_select_device()) self.seed = sync_random_seed(seed)
def __iter__(self): def __iter__(self):
# deterministically shuffle based on epoch # deterministically shuffle based on epoch

View File

@ -6,6 +6,7 @@ from typing import Sequence
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer, from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
build_norm_layer) build_norm_layer)
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
@ -77,8 +78,11 @@ class ConvNeXtBlock(BaseModule):
mlp_ratio=4., mlp_ratio=4.,
linear_pw_conv=True, linear_pw_conv=True,
drop_path_rate=0., drop_path_rate=0.,
layer_scale_init_value=1e-6): layer_scale_init_value=1e-6,
with_cp=False):
super().__init__() super().__init__()
self.with_cp = with_cp
self.depthwise_conv = nn.Conv2d( self.depthwise_conv = nn.Conv2d(
in_channels, in_channels,
in_channels, in_channels,
@ -108,24 +112,33 @@ class ConvNeXtBlock(BaseModule):
drop_path_rate) if drop_path_rate > 0. else nn.Identity() drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(self, x): def forward(self, x):
shortcut = x
x = self.depthwise_conv(x)
x = self.norm(x)
if self.linear_pw_conv: def _inner_forward(x):
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) shortcut = x
x = self.depthwise_conv(x)
x = self.norm(x)
x = self.pointwise_conv1(x) if self.linear_pw_conv:
x = self.act(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.pointwise_conv2(x)
if self.linear_pw_conv: x = self.pointwise_conv1(x)
x = x.permute(0, 3, 1, 2) # permute back x = self.act(x)
x = self.pointwise_conv2(x)
if self.gamma is not None: if self.linear_pw_conv:
x = x.mul(self.gamma.view(1, -1, 1, 1)) x = x.permute(0, 3, 1, 2) # permute back
if self.gamma is not None:
x = x.mul(self.gamma.view(1, -1, 1, 1))
x = shortcut + self.drop_path(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
x = shortcut + self.drop_path(x)
return x return x
@ -169,6 +182,8 @@ class ConvNeXt(BaseBackbone):
gap_before_final_norm (bool): Whether to globally average the feature gap_before_final_norm (bool): Whether to globally average the feature
map before the final norm layer. In the official repo, it's only map before the final norm layer. In the official repo, it's only
used in classification task. Defaults to True. used in classification task. Defaults to True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): Initialization config dict init_cfg (dict, optional): Initialization config dict
""" # noqa: E501 """ # noqa: E501
arch_settings = { arch_settings = {
@ -206,6 +221,7 @@ class ConvNeXt(BaseBackbone):
out_indices=-1, out_indices=-1,
frozen_stages=0, frozen_stages=0,
gap_before_final_norm=True, gap_before_final_norm=True,
with_cp=False,
init_cfg=None): init_cfg=None):
super().__init__(init_cfg=init_cfg) super().__init__(init_cfg=init_cfg)
@ -288,8 +304,8 @@ class ConvNeXt(BaseBackbone):
norm_cfg=norm_cfg, norm_cfg=norm_cfg,
act_cfg=act_cfg, act_cfg=act_cfg,
linear_pw_conv=linear_pw_conv, linear_pw_conv=linear_pw_conv,
layer_scale_init_value=layer_scale_init_value) layer_scale_init_value=layer_scale_init_value,
for j in range(depth) with_cp=with_cp) for j in range(depth)
]) ])
block_idx += depth block_idx += depth

View File

@ -19,6 +19,9 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
if device == 'npu': if device == 'npu':
from mmcv.device.npu import NPUDataParallel from mmcv.device.npu import NPUDataParallel
model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs) model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
elif device == 'mlu':
from mmcv.device.mlu import MLUDataParallel
model = MLUDataParallel(model.mlu(), dim=dim, *args, **kwargs)
elif device == 'cuda': elif device == 'cuda':
from mmcv.parallel import MMDataParallel from mmcv.parallel import MMDataParallel
model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs) model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
@ -57,6 +60,15 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
from torch.npu import current_device from torch.npu import current_device
model = NPUDistributedDataParallel( model = NPUDistributedDataParallel(
model.npu(), *args, device_ids=[current_device()], **kwargs) model.npu(), *args, device_ids=[current_device()], **kwargs)
elif device == 'mlu':
import os
from mmcv.device.mlu import MLUDistributedDataParallel
model = MLUDistributedDataParallel(
model.mlu(),
*args,
device_ids=[int(os.environ['LOCAL_RANK'])],
**kwargs)
elif device == 'cuda': elif device == 'cuda':
from mmcv.parallel import MMDistributedDataParallel from mmcv.parallel import MMDistributedDataParallel
from torch.cuda import current_device from torch.cuda import current_device

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved # Copyright (c) OpenMMLab. All rights reserved
__version__ = '0.24.1' __version__ = '0.25.0'
def parse_version_info(version_str): def parse_version_info(version_str):

View File

@ -84,3 +84,13 @@ def test_convnext():
for i in range(2, 4): for i in range(2, 4):
assert model.downsample_layers[i].training assert model.downsample_layers[i].training
assert model.stages[i].training assert model.stages[i].training
# Test Activation Checkpointing
model = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 1
assert feat[0].shape == torch.Size([1, 768])

View File

@ -195,10 +195,7 @@ def main():
**show_kwargs) **show_kwargs)
else: else:
model = wrap_distributed_model( model = wrap_distributed_model(
model, model, device=cfg.device, broadcast_buffers=False)
device=cfg.device,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir, outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect) args.gpu_collect)