# CodeCamp #150 [Feature] Add ISNet (#2400)

## Motivation

Support ISNet.

Paper link: [ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation](https://openaccess.thecvf.com/content/ICCV2021/papers/Jin_ISNet_Integrate_Image-Level_and_Semantic-Level_Context_for_Semantic_Segmentation_ICCV_2021_paper.pdf)

## Modification

- Add ISNet decode head.
- Add ISNet config.

projects/isnet/README.md

# ISNet
[ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation](https://arxiv.org/pdf/2108.12382.pdf)
## Description
This is an implementation of [ISNet](https://arxiv.org/pdf/2108.12382.pdf).
[Official Repo](https://github.com/SegmentationBLWX/sssegmentation)
## Usage
### Prerequisites
- Python 3.7
- PyTorch 1.6 or higher
- [MIM](https://github.com/open-mmlab/mim) v0.33 or higher
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc2 or higher
All the commands below rely on `PYTHONPATH` being configured correctly: it should point to the project's directory so that Python can locate the module files. In the `isnet/` root directory, run the following command to add the current directory to `PYTHONPATH`:
```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
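As a quick sanity check (a minimal sketch, assuming the configured path makes `projects.isnet.decode_heads` importable, which is the module path used by the config's `custom_imports`), you can confirm that the head registers itself with MMSegmentation:
```python
# Hedged sanity check: importing the project package should register
# ISNetHead in MMSeg's MODELS registry.
from mmseg.registry import MODELS

import projects.isnet.decode_heads  # noqa: F401

assert 'ISNetHead' in MODELS.module_dict
```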
### Training commands
```shell
mim train mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet
```
To train on multiple GPUs, e.g., 8 GPUs, run the following command:
```shell
mim train mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet --launcher pytorch --gpus 8
```
### Testing commands
```shell
mim test mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet --checkpoint ${CHECKPOINT_PATH}
```
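To exercise the decode head directly, the sketch below builds `ISNetHead` with the hyper-parameters from the config, swapping `SyncBN` for plain `BN` so it runs in a single process, and feeds it random output-stride-8 backbone features for a small 256x512 crop. This is an illustrative smoke test, not part of the training or testing pipeline:
```python
import torch

from projects.isnet.decode_heads import ISNetHead

# Build the head as configured, but with BN instead of SyncBN so the sketch
# runs outside distributed training.
head = ISNetHead(
    in_channels=(256, 512, 1024, 2048),
    input_transform='multiple_select',
    in_index=(0, 1, 2, 3),
    channels=512,
    dropout_ratio=0.1,
    transform_channels=256,
    concat_input=True,
    with_shortcut=False,
    shortcut_in_channels=256,
    shortcut_feat_channels=48,
    num_classes=19,
    norm_cfg=dict(type='BN', requires_grad=True),
    align_corners=False)
head.eval()

# Random ResNet-50-D8 features for one 256x512 crop: strides (1, 2, 1, 1)
# give stages at 1/4, 1/8, 1/8, 1/8 resolution.
feats = [
    torch.randn(1, 256, 64, 128),
    torch.randn(1, 512, 32, 64),
    torch.randn(1, 1024, 32, 64),
    torch.randn(1, 2048, 32, 64),
]
with torch.no_grad():
    preds_stage1, preds_stage2 = head(feats)
print(preds_stage1.shape, preds_stage2.shape)  # (1, 19, 32, 64) each
```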
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | --------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------ |
| ISNet | R-50-D8 | 512x1024 | - | - | - | 79.32 | 80.88 | [config](configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/isnet/isnet_r50-d8_cityscapes-512x1024_20230104-a7a8ccf2.pth) |
## Citation
```bibtex
@article{Jin2021ISNetII,
title={ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation},
author={Zhenchao Jin and B. Liu and Qi Chu and Nenghai Yu},
journal={2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2021},
pages={7169-7178}
}
```

projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py

_base_ = [
'../../../configs/_base_/datasets/cityscapes.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_80k.py'
]
data_root = '../../data/cityscapes/'
train_dataloader = dict(dataset=dict(data_root=data_root))
val_dataloader = dict(dataset=dict(data_root=data_root))
test_dataloader = dict(dataset=dict(data_root=data_root))
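# Importing this package registers ISNetHead in MMSeg's MODELS registry.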
custom_imports = dict(imports=['projects.isnet.decode_heads'])
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='ISNetHead',
in_channels=(256, 512, 1024, 2048),
input_transform='multiple_select',
in_index=(0, 1, 2, 3),
channels=512,
dropout_ratio=0.1,
transform_channels=256,
concat_input=True,
with_shortcut=False,
shortcut_in_channels=256,
shortcut_feat_channels=48,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
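        # The two losses pair in order with the head's two outputs in
        # loss_by_feat: 'loss_o' supervises stage 1, 'loss_d' stage 2.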
loss_decode=[
dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0,
loss_name='loss_o'),
dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=0.4,
loss_name='loss_d'),
]),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=512,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
train_cfg=dict(),
# test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
test_cfg=dict(mode='whole'))

projects/isnet/decode_heads/__init__.py

from .isnet_head import ISNetHead

__all__ = ['ISNetHead']

projects/isnet/decode_heads/isnet_head.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from torch import Tensor
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import SelfAttentionBlock, resize
from mmseg.registry import MODELS
from mmseg.utils import SampleList


class ImageLevelContext(nn.Module):
""" Image-Level Context Module
Args:
feats_channels (int): Input channels of query/key feature.
transform_channels (int): Output channels of key/query transform.
        concat_input (bool): Whether to concatenate the input feature.
align_corners (bool): align_corners argument of F.interpolate.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self,
feats_channels,
transform_channels,
concat_input=False,
align_corners=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None):
super().__init__()
self.align_corners = align_corners
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.correlate_net = SelfAttentionBlock(
key_in_channels=feats_channels * 2,
query_in_channels=feats_channels,
channels=transform_channels,
out_channels=feats_channels,
share_key_query=False,
query_downsample=None,
key_downsample=None,
key_query_num_convs=2,
value_out_num_convs=1,
key_query_norm=True,
value_out_norm=True,
matmul_norm=True,
with_out=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
if concat_input:
self.bottleneck = ConvModule(
feats_channels * 2,
feats_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)

    # Fuse image-level (globally pooled) context into each position via
    # self-attention.
def forward(self, x):
x_global = self.global_avgpool(x)
x_global = resize(
x_global,
size=x.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
feats_il = self.correlate_net(x, torch.cat([x_global, x], dim=1))
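        # `bottleneck` only exists when concat_input=True (see __init__).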
if hasattr(self, 'bottleneck'):
feats_il = self.bottleneck(torch.cat([x, feats_il], dim=1))
return feats_il


class SemanticLevelContext(nn.Module):
""" Semantic-Level Context Module
Args:
feats_channels (int): Input channels of query/key feature.
transform_channels (int): Output channels of key/query transform.
        concat_input (bool): Whether to concatenate the input feature.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self,
feats_channels,
transform_channels,
concat_input=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None):
super().__init__()
self.correlate_net = SelfAttentionBlock(
key_in_channels=feats_channels,
query_in_channels=feats_channels,
channels=transform_channels,
out_channels=feats_channels,
share_key_query=False,
query_downsample=None,
key_downsample=None,
key_query_num_convs=2,
value_out_num_convs=1,
key_query_norm=True,
value_out_norm=True,
matmul_norm=True,
with_out=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
if concat_input:
self.bottleneck = ConvModule(
feats_channels * 2,
feats_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)

    # Build class centers from the stage-1 prediction and attend each pixel
    # to its semantic-level context.
def forward(self, x, preds, feats_il):
inputs = x
batch_size, num_channels, h, w = x.size()
num_classes = preds.size(1)
feats_sl = torch.zeros(batch_size, h * w, num_channels).type_as(x)
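        # Per-pixel semantic-level features, filled in class-by-class below.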
for batch_idx in range(batch_size):
# (C, H, W), (num_classes, H, W) --> (H*W, C), (H*W, num_classes)
feats_iter, preds_iter = x[batch_idx], preds[batch_idx]
feats_iter, preds_iter = feats_iter.reshape(
num_channels, -1), preds_iter.reshape(num_classes, -1)
feats_iter, preds_iter = feats_iter.permute(1,
0), preds_iter.permute(
1, 0)
# (H*W, )
argmax = preds_iter.argmax(1)
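            # Pool the features of all pixels predicted as each class into a
            # single class center, weighted by prediction confidence.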
for clsid in range(num_classes):
mask = (argmax == clsid)
if mask.sum() == 0:
continue
feats_iter_cls = feats_iter[mask]
preds_iter_cls = preds_iter[:, clsid][mask]
weight = torch.softmax(preds_iter_cls, dim=0)
feats_iter_cls = feats_iter_cls * weight.unsqueeze(-1)
feats_iter_cls = feats_iter_cls.sum(0)
feats_sl[batch_idx][mask] = feats_iter_cls
feats_sl = feats_sl.reshape(batch_size, h, w, num_channels)
feats_sl = feats_sl.permute(0, 3, 1, 2).contiguous()
feats_sl = self.correlate_net(inputs, feats_sl)
if hasattr(self, 'bottleneck'):
feats_sl = self.bottleneck(torch.cat([feats_il, feats_sl], dim=1))
return feats_sl


@MODELS.register_module()
class ISNetHead(BaseDecodeHead):
"""ISNet: Integrate Image-Level and Semantic-Level
Context for Semantic Segmentation
This head is the implementation of `ISNet`
<https://arxiv.org/pdf/2108.12382.pdf>`_.
Args:
transform_channels (int): Output channels of key/query transform.
        concat_input (bool): Whether to concatenate the input feature.
        with_shortcut (bool): Whether to use a shortcut connection.
shortcut_in_channels (int): Input channels of shortcut.
shortcut_feat_channels (int): Output channels of shortcut.
dropout_ratio (float): Ratio of dropout.
"""
def __init__(self, transform_channels, concat_input, with_shortcut,
shortcut_in_channels, shortcut_feat_channels, dropout_ratio,
**kwargs):
super().__init__(**kwargs)
self.in_channels = self.in_channels[-1]
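        # Only the last (2048-channel) backbone stage feeds the context path;
        # the earlier stages are kept for the optional shortcut (x[0]).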
self.bottleneck = ConvModule(
self.in_channels,
self.channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.ilc_net = ImageLevelContext(
feats_channels=self.channels,
transform_channels=transform_channels,
concat_input=concat_input,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.slc_net = SemanticLevelContext(
feats_channels=self.channels,
transform_channels=transform_channels,
concat_input=concat_input,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.decoder_stage1 = nn.Sequential(
ConvModule(
self.channels,
self.channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Dropout2d(dropout_ratio),
nn.Conv2d(
self.channels,
self.num_classes,
kernel_size=1,
stride=1,
padding=0,
bias=True),
)
if with_shortcut:
self.shortcut = ConvModule(
shortcut_in_channels,
shortcut_feat_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.decoder_stage2 = nn.Sequential(
ConvModule(
self.channels + shortcut_feat_channels,
self.channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Dropout2d(dropout_ratio),
nn.Conv2d(
self.channels,
self.num_classes,
kernel_size=1,
stride=1,
padding=0,
bias=True),
)
else:
self.decoder_stage2 = nn.Sequential(
nn.Dropout2d(dropout_ratio),
nn.Conv2d(
self.channels,
self.num_classes,
kernel_size=1,
stride=1,
padding=0,
bias=True),
)
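        # Each decoder stage has its own classifier, so the base class's
        # conv_seg/dropout are dropped to avoid unused parameters.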
self.conv_seg = None
self.dropout = None
def forward(self, inputs):
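        # bottleneck -> image-level context -> stage-1 logits ->
        # semantic-level context -> (optional shortcut) -> stage-2 logits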
x = self._transform_inputs(inputs)
feats = self.bottleneck(x[-1])
feats_il = self.ilc_net(feats)
preds_stage1 = self.decoder_stage1(feats)
preds_stage1 = resize(
preds_stage1,
size=feats.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
feats_sl = self.slc_net(feats, preds_stage1, feats_il)
if hasattr(self, 'shortcut'):
shortcut_out = self.shortcut(x[0])
feats_sl = resize(
feats_sl,
size=shortcut_out.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
feats_sl = torch.cat([feats_sl, shortcut_out], dim=1)
preds_stage2 = self.decoder_stage2(feats_sl)
return preds_stage1, preds_stage2
def loss_by_feat(self, seg_logits: Tensor,
batch_data_samples: SampleList) -> dict:
seg_label = self._stack_batch_gt(batch_data_samples)
loss = dict()
if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logits[-1], seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
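        # Supervise both stages: the i-th logits use the i-th configured loss
        # ('loss_o' for stage 1, 'loss_d' for stage 2).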
for seg_logit, loss_decode in zip(seg_logits, self.loss_decode):
seg_logit = resize(
input=seg_logit,
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
            loss[loss_decode.loss_name] = loss_decode(
seg_logit,
seg_label,
seg_weight,
ignore_index=self.ignore_index)
loss['acc_seg'] = accuracy(
seg_logits[-1], seg_label, ignore_index=self.ignore_index)
return loss
def predict_by_feat(self, seg_logits: Tensor,
batch_img_metas: List[dict]) -> Tensor:
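        # Inference uses only the stage-2 (refined) logits.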
_, seg_logits_stage2 = seg_logits
return super().predict_by_feat(seg_logits_stage2, batch_img_metas)