Merge branch 'open-mmlab:master' into custom/face_occlusion

pull/2194/head
whooray 2022-10-11 14:58:01 +09:00 committed by GitHub
commit d5e79fa0ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 406 additions and 46 deletions

View File

@ -61,17 +61,15 @@ jobs:
name: Install mmseg dependencies
command: |
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch<< parameters.torch >>/index.html
python -m pip install mmdet
python -m pip install -r requirements.txt
- run:
name: Build and install
command: |
python -m pip install -e .
- run:
name: Run unittests
name: Run unittests but skip timm unittests
command: |
python -m pip install timm
python -m coverage run --branch --source mmseg -m pytest tests/
python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
python -m coverage xml
python -m coverage report -m
@ -102,7 +100,6 @@ jobs:
# python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch${{matrix.torch_version}}/index.html
command: |
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
python -m pip install mmdet
python -m pip install -r requirements.txt
- run:
name: Build and install
@ -110,10 +107,9 @@ jobs:
python setup.py check -m -s
TORCH_CUDA_ARCH_LIST=7.0 python -m pip install -e .
- run:
name: Run unittests
name: Run unittests but skip timm unittests
command: |
python -m pip install timm
python -m pytest tests/
python -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
workflows:
unit_tests:

View File

@ -70,13 +70,14 @@ jobs:
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
# timm from v0.6.11 requires torch>=1.7
if: ${{matrix.torch >= '1.7.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
if: ${{matrix.torch < '1.7.0'}}
build_cuda101:
runs-on: ubuntu-18.04
@ -144,13 +145,14 @@ jobs:
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
# timm from v0.6.11 requires torch>=1.7
if: ${{matrix.torch >= '1.7.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
if: ${{matrix.torch < '1.7.0'}}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1.0.10
with:
@ -249,7 +251,7 @@ jobs:
run: pip install -e .
- name: Run unittests
run: |
python -m pip install timm
python -m pip install 'timm<0.6.11'
coverage run --branch --source mmseg -m pytest tests/
- name: Generate coverage report
run: |

View File

@ -1,6 +1,6 @@
repos:
- repo: https://gitlab.com/pycqa/flake8.git
rev: 3.8.3
rev: 5.0.4
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
@ -8,11 +8,11 @@ repos:
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.30.0
rev: v0.32.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
rev: v4.3.0
hooks:
- id: trailing-whitespace
- id: check-yaml
@ -34,7 +34,7 @@ repos:
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter

View File

@ -77,10 +77,9 @@ The master branch works with **PyTorch 1.5+**.
### 💎 Stable version
v0.28.0 was released in 9/08/2022:
v0.29.0 was released on 10/10/2022:
- Support Tversky Loss
- Fix binary segmentation
- Support PoolFormer (CVPR'2022)
Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
@ -130,6 +129,7 @@ Supported backbones:
- [x] [BEiT (ICLR'2022)](configs/beit)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
Supported methods:

View File

@ -74,10 +74,9 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
### 💎 稳定版本
最新版本 v0.28.0 在 2022.9.8 发布:
最新版本 v0.29.0 在 2022.10.10 发布:
- 支持 Tversky Loss
- 修复二值分割
- 支持 PoolFormer (CVPR'2022)
如果想了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/changelog.md)。
@ -127,6 +126,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [BEiT (ICLR'2022)](configs/beit)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
已支持的算法:

View File

@ -0,0 +1,42 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth' # noqa
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
model = dict(
type='EncoderDecoder',
backbone=dict(
type='mmcls.PoolFormer',
arch='s12',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file, prefix='backbone.'),
in_patch_size=7,
in_stride=4,
in_pad=2,
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0.,
drop_path_rate=0.,
out_indices=(0, 2, 4, 6),
frozen_stages=0,
),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
decode_head=dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@ -0,0 +1,63 @@
# PoolFormer
[MetaFormer is Actually What You Need for Vision](https://arxiv.org/abs/2111.11418)
## Introduction
<!-- [BACKBONE] -->
<a href="https://github.com/sail-sg/poolformer/tree/main/segmentation">Official Repo</a>
<a href="https://github.com/open-mmlab/mmclassification/blob/v0.23.0/mmcls/models/backbones/poolformer.py#L198">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
Transformers have shown great potential in computer vision tasks. A common belief is their attention-based token mixer module contributes most to their competence. However, recent works show the attention-based module in transformers can be replaced by spatial MLPs and the resulted models still perform quite well. Based on this observation, we hypothesize that the general architecture of the transformers, instead of the specific token mixer module, is more essential to the model's performance. To verify this, we deliberately replace the attention module in transformers with an embarrassingly simple spatial pooling operator to conduct only the most basic token mixing. Surprisingly, we observe that the derived model, termed as PoolFormer, achieves competitive performance on multiple computer vision tasks. For example, on ImageNet-1K, PoolFormer achieves 82.1% top-1 accuracy, surpassing well-tuned vision transformer/MLP-like baselines DeiT-B/ResMLP-B24 by 0.3%/1.1% accuracy with 35%/52% fewer parameters and 48%/60% fewer MACs. The effectiveness of PoolFormer verifies our hypothesis and urges us to initiate the concept of "MetaFormer", a general architecture abstracted from transformers without specifying the token mixer. Based on the extensive experiments, we argue that MetaFormer is the key player in achieving superior results for recent transformer and MLP-like models on vision tasks. This work calls for more future research dedicated to improving MetaFormer instead of focusing on the token mixer modules. Additionally, our proposed PoolFormer could serve as a starting baseline for future MetaFormer architecture design. Code is available at [this https URL](https://github.com/sail-sg/poolformer)
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/15921929/144710761-1635f59a-abde-4946-984c-a2c3f22a19d2.png" width="70%"/>
</div>
## Citation
```bibtex
@inproceedings{yu2022metaformer,
title={Metaformer is actually what you need for vision},
author={Yu, Weihao and Luo, Mi and Zhou, Pan and Si, Chenyang and Zhou, Yichen and Wang, Xinchao and Feng, Jiashi and Yan, Shuicheng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={10819--10829},
year={2022}
}
```
### Usage
- PoolFormer backbone needs to install [MMClassification](https://github.com/open-mmlab/mmclassification) first, which has abundant backbones for downstream tasks.
```shell
pip install mmcls>=0.23.0
```
- The pretrained models could also be downloaded from [PoolFormer config of MMClassification](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer).
## Results and models
### ADE20K
| Method | Backbone | Crop Size | pretrain | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | mIoU\* | mIoU\*(ms+flip) | config | download |
| ------ | -------------- | --------- | ----------- | ---------- | ------- | -------- | -------------- | ----- | ------------: | ------ | --------------: | ---------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FPN | PoolFormer-S12 | 512x512 | ImageNet-1K | 32 | 40000 | 4.17 | 23.48 | 36.0 | 36.42 | 37.07 | 38.44 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154-b5aa2f49.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154.log.json) |
| FPN | PoolFormer-S24 | 512x512 | ImageNet-1K | 32 | 40000 | 5.47 | 15.74 | 39.35 | 39.73 | 40.36 | 41.08 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049-394a7cf7.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049.log.json) |
| FPN | PoolFormer-S36 | 512x512 | ImageNet-1K | 32 | 40000 | 6.77 | 11.34 | 40.64 | 40.99 | 41.81 | 42.72 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k/fpn_poolformer_s36_8x4_512x512_40k_ade20k_20220501_151122-b47e607d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k/fpn_poolformer_s36_8x4_512x512_40k_ade20k_20220501_151122.log.json) |
| FPN | PoolFormer-M36 | 512x512 | ImageNet-1K | 32 | 40000 | 8.59 | 8.97 | 40.91 | 41.28 | 42.35 | 43.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230-3dc83921.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230.log.json) |
| FPN | PoolFormer-M48 | 512x512 | ImageNet-1K | 32 | 40000 | 10.48 | 6.69 | 41.82 | 42.2 | 42.76 | 43.57 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923-64168d3b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923.log.json) |
Note:
- We replace `AlignedResize` in original PoolFormer implementation to `Resize + ResizeToMultiple`.
- `mIoU` with * is collected when `Resize + ResizeToMultiple` is adopted in `test_pipeline`, so do `mIoU` in logs.

View File

@ -0,0 +1,11 @@
_base_ = './fpn_poolformer_s12_8x4_512x512_40k_ade20k.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m36_3rdparty_32xb128_in1k_20220414-c55e0949.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='m36',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(in_channels=[96, 192, 384, 768]))

View File

@ -0,0 +1,11 @@
_base_ = './fpn_poolformer_s12_8x4_512x512_40k_ade20k.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-m48_3rdparty_32xb128_in1k_20220414-9378f3eb.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='m48',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')),
neck=dict(in_channels=[96, 192, 384, 768]))

View File

@ -0,0 +1,74 @@
_base_ = [
'../_base_/models/fpn_poolformer_s12.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_40k.py'
]
# model settings
model = dict(
neck=dict(in_channels=[64, 128, 320, 512]),
decode_head=dict(num_classes=150))
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, weight_decay=0.0001)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False)
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='ResizeToMultiple', size_divisor=32),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=50,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))

View File

@ -0,0 +1,9 @@
_base_ = './fpn_poolformer_s12_8x4_512x512_40k_ade20k.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s24_3rdparty_32xb128_in1k_20220414-d7055904.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='s24',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')))

View File

@ -0,0 +1,10 @@
_base_ = './fpn_poolformer_s12_8x4_512x512_40k_ade20k.py'
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s36_3rdparty_32xb128_in1k_20220414-d78ff3e8.pth' # noqa
# model settings
model = dict(
backbone=dict(
arch='s36',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_file,
prefix='backbone.')))

View File

@ -0,0 +1,111 @@
Models:
- Name: fpn_poolformer_s12_8x4_512x512_40k_ade20k
In Collection: FPN
Metadata:
backbone: PoolFormer-S12
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 42.59
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 4.17
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 36.0
mIoU(ms+flip): 36.42
Config: configs/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s12_8x4_512x512_40k_ade20k/fpn_poolformer_s12_8x4_512x512_40k_ade20k_20220501_115154-b5aa2f49.pth
- Name: fpn_poolformer_s24_8x4_512x512_40k_ade20k
In Collection: FPN
Metadata:
backbone: PoolFormer-S24
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 63.53
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 5.47
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 39.35
mIoU(ms+flip): 39.73
Config: configs/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s24_8x4_512x512_40k_ade20k/fpn_poolformer_s24_8x4_512x512_40k_ade20k_20220503_222049-394a7cf7.pth
- Name: fpn_poolformer_s36_8x4_512x512_40k_ade20k
In Collection: FPN
Metadata:
backbone: PoolFormer-S36
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 88.18
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 6.77
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 40.64
mIoU(ms+flip): 40.99
Config: configs/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_s36_8x4_512x512_40k_ade20k/fpn_poolformer_s36_8x4_512x512_40k_ade20k_20220501_151122-b47e607d.pth
- Name: fpn_poolformer_m36_8x4_512x512_40k_ade20k
In Collection: FPN
Metadata:
backbone: PoolFormer-M36
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 111.48
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 8.59
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 40.91
mIoU(ms+flip): 41.28
Config: configs/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m36_8x4_512x512_40k_ade20k/fpn_poolformer_m36_8x4_512x512_40k_ade20k_20220501_164230-3dc83921.pth
- Name: fpn_poolformer_m48_8x4_512x512_40k_ade20k
In Collection: FPN
Metadata:
backbone: PoolFormer-M48
crop size: (512,512)
lr schd: 40000
inference time (ms/im):
- value: 149.48
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 10.48
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 41.82
mIoU(ms+flip): 42.2
Config: configs/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/poolformer/fpn_poolformer_m48_8x4_512x512_40k_ade20k/fpn_poolformer_m48_8x4_512x512_40k_ade20k_20220504_003923-64168d3b.pth

View File

@ -4,10 +4,14 @@ ARG CUDNN="8"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
ARG MMCV="1.5.0"
ARG MMSEG="0.28.0"
ARG MMSEG="0.29.0"
ENV PYTHONUNBUFFERED TRUE
# NVIDIA APT KEYS
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
ca-certificates \
@ -21,7 +25,7 @@ ENV PATH="/opt/conda/bin:$PATH"
RUN export FORCE_CUDA=1
# TORCHSEVER
RUN pip install torchserve torch-model-archiver
RUN pip install torchserve torch-model-archiver nvgpu
# MMLAB
ARG PYTORCH

View File

@ -1,5 +1,32 @@
## Changelog
### V0.29.0 (10/10/2022)
**New Features**
- Support PoolFormer (CVPR'2022) ([#1537](https://github.com/open-mmlab/mmsegmentation/pull/1537))
**Enhancement**
- Improve structure and readability for FCNHead ([#2142](https://github.com/open-mmlab/mmsegmentation/pull/2142))
- Support IterableDataset in distributed training ([#2151](https://github.com/open-mmlab/mmsegmentation/pull/2151))
- Upgrade .dev scripts ([#2020](https://github.com/open-mmlab/mmsegmentation/pull/2020))
- Upgrade pre-commit hooks ([#2155](https://github.com/open-mmlab/mmsegmentation/pull/2155))
**Bug Fixes**
- Fix mmseg.api.inference inference_segmentor ([#1849](https://github.com/open-mmlab/mmsegmentation/pull/1849))
- fix bug about label_map in evaluation part ([#2075](https://github.com/open-mmlab/mmsegmentation/pull/2075))
- Add missing dependencies to torchserve docker file ([#2133](https://github.com/open-mmlab/mmsegmentation/pull/2133))
- Fix ddp unittest ([#2060](https://github.com/open-mmlab/mmsegmentation/pull/2060))
**Contributors**
- @jinwonkim93 made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/1849
- @rlatjcj made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2075
- @ShirleyWangCVR made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2151
- @mangelroman made their first contribution in https://github.com/open-mmlab/mmsegmentation/pull/2133
### V0.28.0 (9/8/2022)
**New Features**

View File

@ -9,6 +9,7 @@ The compatible MMSegmentation and MMCV versions are as below. Please install the
| MMSegmentation version | MMCV version | MMClassification version |
| :--------------------: | :-------------------------: | :----------------------: |
| master | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.29.0 | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.28.0 | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.27.0 | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.26.0 | mmcv-full>=1.5.0, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 |
@ -53,7 +54,7 @@ Briefly, it is a deep supervision trick to improve the accuracy. In the training
## Why is the log file not created
In the train script, we call `get_root_logger`at Line 167, and `get_root_logger` in mmseg calls `get_logger` in mmcv, mmcv will return the same logger which has beed initialized in 'mmsegmentation/tools/train.py' with the parameter `log_file`. There is only one logger (initialized with `log_file`) during training.
In the train script, we call `get_root_logger`at Line 167, and `get_root_logger` in mmseg calls `get_logger` in mmcv, mmcv will return the same logger which has been initialized in 'mmsegmentation/tools/train.py' with the parameter `log_file`. There is only one logger (initialized with `log_file`) during training.
Ref: [https://github.com/open-mmlab/mmcv/blob/21bada32560c7ed7b15b017dc763d862789e29a8/mmcv/utils/logging.py#L9-L16](https://github.com/open-mmlab/mmcv/blob/21bada32560c7ed7b15b017dc763d862789e29a8/mmcv/utils/logging.py#L9-L16)
If you find the log file not been created, you might check if `mmcv.utils.get_logger` is called elsewhere.

View File

@ -33,7 +33,7 @@ data = dict(
- `train`, `val` and `test`: The [`config`](https://github.com/open-mmlab/mmcv/blob/master/docs/en/understand_mmcv/config.md)s to build dataset instances for model training, validation and testing by
using [`build and registry`](https://github.com/open-mmlab/mmcv/blob/master/docs/en/understand_mmcv/registry.md) mechanism.
- `samples_per_gpu`: How many samples per batch and per gpu to load during model training, and the `batch_size` of training is equal to `samples_per_gpu` times gpu number, e.g. when using 8 gpus for distributed data parallel trainig and `samples_per_gpu=4`, the `batch_size` is `8*4=32`.
- `samples_per_gpu`: How many samples per batch and per gpu to load during model training, and the `batch_size` of training is equal to `samples_per_gpu` times gpu number, e.g. when using 8 gpus for distributed data parallel training and `samples_per_gpu=4`, the `batch_size` is `8*4=32`.
If you would like to define `batch_size` for testing and validation, please use `test_dataloaser` and
`val_dataloader` with mmseg >=0.24.1.

View File

@ -9,6 +9,7 @@
| MMSegmentation version | MMCV version | MMClassification version |
| :--------------------: | :-------------------------: | :----------------------: |
| master | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.29.0 | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.28.0 | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.27.0 | mmcv-full>=1.5.0, \<1.7.0 | mmcls>=0.20.1, \<=1.0.0 |
| 0.26.0 | mmcv-full>=1.5.0, \<=1.6.0 | mmcls>=0.20.1, \<=1.0.0 |

View File

@ -9,7 +9,7 @@ import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, IterableDataset
from .samplers import DistributedSampler
@ -129,12 +129,17 @@ def build_dataloader(dataset,
DataLoader: A PyTorch dataloader.
"""
rank, world_size = get_dist_info()
if dist:
if dist and not isinstance(dataset, IterableDataset):
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=shuffle, seed=seed)
shuffle = False
batch_size = samples_per_gpu
num_workers = workers_per_gpu
elif dist:
sampler = None
shuffle = False
batch_size = samples_per_gpu
num_workers = workers_per_gpu
else:
sampler = None
batch_size = num_gpus * samples_per_gpu

View File

@ -337,7 +337,7 @@ class VisionTransformer(BaseModule):
constant_init(m, val=1.0, bias=0.)
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
"""Positiong embeding method.
"""Positioning embeding method.
Resize the pos_embed, if the input image size doesn't match
the training size.

View File

@ -37,20 +37,11 @@ class FCNHead(BaseDecodeHead):
conv_padding = (kernel_size // 2) * dilation
convs = []
convs.append(
ConvModule(
self.in_channels,
self.channels,
kernel_size=kernel_size,
padding=conv_padding,
dilation=dilation,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
for i in range(num_convs - 1):
for i in range(num_convs):
_in_channels = self.in_channels if i == 0 else self.channels
convs.append(
ConvModule(
self.channels,
_in_channels,
self.channels,
kernel_size=kernel_size,
padding=conv_padding,
@ -58,7 +49,8 @@ class FCNHead(BaseDecodeHead):
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
if num_convs == 0:
if len(convs) == 0:
self.convs = nn.Identity()
else:
self.convs = nn.Sequential(*convs)

View File

@ -78,7 +78,7 @@ def sigmoid_focal_loss(pred,
valid_mask=None,
reduction='mean',
avg_factor=None):
r"""A warpper of cuda version `Focal Loss
r"""A wrapper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number

View File

@ -1,6 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
__version__ = '0.28.0'
__version__ = '0.29.0'
def parse_version_info(version_str):

View File

@ -30,6 +30,7 @@ Import:
- configs/nonlocal_net/nonlocal_net.yml
- configs/ocrnet/ocrnet.yml
- configs/point_rend/point_rend.yml
- configs/poolformer/poolformer.yml
- configs/psanet/psanet.yml
- configs/pspnet/pspnet.yml
- configs/resnest/resnest.yml

View File

@ -19,4 +19,4 @@ default_section = THIRDPARTY
skip = *.po,*.ts,*.ipynb
count =
quiet-level = 3
ignore-words-list = formating,sur,hist,dota,ba
ignore-words-list = formating,sur,hist,dota,ba,warmup