From bd29c20778697ca00ad8a8b575ac595c9492fa0d Mon Sep 17 00:00:00 2001
From: unrealMJ <45420156+unrealMJ@users.noreply.github.com>
Date: Wed, 4 Jan 2023 20:39:03 +0800
Subject: [PATCH] CodeCamp #150 [Feature] Add ISNet (#2400)

## Motivation

Support ISNet. Paper link: [ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation](https://openaccess.thecvf.com/content/ICCV2021/papers/Jin_ISNet_Integrate_Image-Level_and_Semantic-Level_Context_for_Semantic_Segmentation_ICCV_2021_paper.pdf)

## Modification

- Add ISNet decode head.
- Add ISNet config.

---
 projects/isnet/README.md                      |  57 +++
 ...et_r50-d8_8xb2-160k_cityscapes-512x1024.py |  80 +++++
 projects/isnet/decode_heads/__init__.py       |   3 +
 projects/isnet/decode_heads/isnet_head.py     | 337 ++++++++++++++++++
 4 files changed, 477 insertions(+)
 create mode 100644 projects/isnet/README.md
 create mode 100644 projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py
 create mode 100644 projects/isnet/decode_heads/__init__.py
 create mode 100644 projects/isnet/decode_heads/isnet_head.py

diff --git a/projects/isnet/README.md b/projects/isnet/README.md
new file mode 100644
index 000000000..b2623e39f
--- /dev/null
+++ b/projects/isnet/README.md
@@ -0,0 +1,57 @@
+# ISNet
+
+[ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation](https://arxiv.org/pdf/2108.12382.pdf)
+
+## Description
+
+This is an implementation of [ISNet](https://arxiv.org/pdf/2108.12382.pdf).
+[Official Repo](https://github.com/SegmentationBLWX/sssegmentation)
+
+## Usage
+
+### Prerequisites
+
+- Python 3.7
+- PyTorch 1.6 or higher
+- [MIM](https://github.com/open-mmlab/mim) v0.3.3 or higher
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) v1.0.0rc2 or higher
+
+All the commands below rely on `PYTHONPATH` being configured correctly: it must include the project directory so that Python can locate the module files. From the `isnet/` root directory, run the following command to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Training commands
+
+```shell
+mim train mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet
+```
+
+To train on multiple GPUs, e.g. 8 GPUs, run the following command:
+
+```shell
+mim train mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet --launcher pytorch --gpus 8
+```
+
+### Testing commands
+
+```shell
+mim test mmsegmentation configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py --work-dir work_dirs/isnet --checkpoint ${CHECKPOINT_PATH}
+```
+
+## Results
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) |  mIoU | mIoU (ms+flip) | config | download |
+| ------ | -------- | --------- | ------: | -------- | -------------- | ----: | -------------: | ------ | -------- |
+| ISNet  | R-50-D8  | 512x1024  |       - | -        | -              | 79.32 |          80.88 | [config](configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/isnet/isnet_r50-d8_cityscapes-512x1024_20230104-a7a8ccf2.pth) |
+
+## Citation
+
+```bibtex
+@inproceedings{Jin2021ISNetII,
+  title={ISNet: Integrate Image-Level and Semantic-Level Context for Semantic Segmentation},
+  author={Zhenchao Jin and Bin Liu and Qi Chu and Nenghai Yu},
+  booktitle={2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
+  year={2021},
+  pages={7169-7178}
+}
+```
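One usage note beyond the README: after training, a checkpoint can also be exercised directly from Python. The sketch below is illustrative and not part of the patch; it assumes mmseg 1.x's `init_model`/`inference_model` APIs, placeholder checkpoint and image paths, and `PYTHONPATH` set as described above so that the config's `custom_imports` can resolve the project module.

```python
# Illustrative sketch (not part of the patch): single-image inference with a
# trained ISNet checkpoint using mmseg's high-level APIs.
from mmseg.apis import inference_model, init_model

config_file = 'configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py'
checkpoint_file = 'work_dirs/isnet/iter_160000.pth'  # placeholder path

# Build the model and load weights (use device='cpu' if no GPU is available).
model = init_model(config_file, checkpoint_file, device='cuda:0')

# `result` is a SegDataSample; the predicted map lives in `pred_sem_seg`.
result = inference_model(model, 'demo.png')  # placeholder image path
print(result.pred_sem_seg.data.shape)
```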
diff --git a/projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py b/projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py
new file mode 100644
index 000000000..a00d39237
--- /dev/null
+++ b/projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py
@@ -0,0 +1,80 @@
+_base_ = [
+    '../../../configs/_base_/datasets/cityscapes.py',
+    '../../../configs/_base_/default_runtime.py',
+    '../../../configs/_base_/schedules/schedule_160k.py'
+]
+
+data_root = '../../data/cityscapes/'
+train_dataloader = dict(dataset=dict(data_root=data_root))
+val_dataloader = dict(dataset=dict(data_root=data_root))
+test_dataloader = dict(dataset=dict(data_root=data_root))
+
+# Import the project's decode head so `ISNetHead` is registered in MODELS.
+custom_imports = dict(imports=['projects.isnet.decode_heads'])
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    seg_pad_val=255)
+
+model = dict(
+    type='EncoderDecoder',
+    data_preprocessor=data_preprocessor,
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='ISNetHead',
+        in_channels=(256, 512, 1024, 2048),
+        input_transform='multiple_select',
+        in_index=(0, 1, 2, 3),
+        channels=512,
+        dropout_ratio=0.1,
+        transform_channels=256,
+        concat_input=True,
+        with_shortcut=False,
+        shortcut_in_channels=256,
+        shortcut_feat_channels=48,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=[
+            dict(
+                type='CrossEntropyLoss',
+                use_sigmoid=False,
+                loss_weight=1.0,
+                loss_name='loss_o'),
+            dict(
+                type='CrossEntropyLoss',
+                use_sigmoid=False,
+                loss_weight=0.4,
+                loss_name='loss_d'),
+        ]),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=512,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    train_cfg=dict(),
+    # test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
+    test_cfg=dict(mode='whole'))
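Because this config depends on `custom_imports` to register the project's head, it can be worth confirming that the import chain resolves before launching a full run. A minimal sketch, not part of the patch, assuming it runs from the mmsegmentation repository root with `PYTHONPATH` set as in the README, and that `Config.fromfile` imports `custom_imports` modules by default (mmengine behavior):

```python
# Illustrative sketch (not part of the patch): build the model from the config
# to confirm that `projects.isnet.decode_heads.ISNetHead` is registered.
from mmengine.config import Config
from mmengine.registry import init_default_scope

from mmseg.registry import MODELS

cfg = Config.fromfile(
    'projects/isnet/configs/isnet_r50-d8_8xb2-160k_cityscapes-512x1024.py')
init_default_scope('mmseg')  # resolve unscoped types against mmseg's registry

model = MODELS.build(cfg.model)
assert type(model.decode_head).__name__ == 'ISNetHead'
```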
diff --git a/projects/isnet/decode_heads/__init__.py b/projects/isnet/decode_heads/__init__.py
new file mode 100644
index 000000000..a451629c4
--- /dev/null
+++ b/projects/isnet/decode_heads/__init__.py
@@ -0,0 +1,3 @@
+from .isnet_head import ISNetHead
+
+__all__ = ['ISNetHead']
diff --git a/projects/isnet/decode_heads/isnet_head.py b/projects/isnet/decode_heads/isnet_head.py
new file mode 100644
index 000000000..9c8df540e
--- /dev/null
+++ b/projects/isnet/decode_heads/isnet_head.py
@@ -0,0 +1,337 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from torch import Tensor
+
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.models.losses import accuracy
+from mmseg.models.utils import SelfAttentionBlock, resize
+from mmseg.registry import MODELS
+from mmseg.utils import SampleList
+
+
+class ImageLevelContext(nn.Module):
+    """Image-Level Context Module.
+
+    Args:
+        feats_channels (int): Input channels of query/key feature.
+        transform_channels (int): Output channels of key/query transform.
+        concat_input (bool): Whether to concat the input feature with the
+            attended feature before the bottleneck.
+        align_corners (bool): align_corners argument of F.interpolate.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+    """
+
+    def __init__(self,
+                 feats_channels,
+                 transform_channels,
+                 concat_input=False,
+                 align_corners=False,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=None):
+        super().__init__()
+        self.align_corners = align_corners
+        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.correlate_net = SelfAttentionBlock(
+            key_in_channels=feats_channels * 2,
+            query_in_channels=feats_channels,
+            channels=transform_channels,
+            out_channels=feats_channels,
+            share_key_query=False,
+            query_downsample=None,
+            key_downsample=None,
+            key_query_num_convs=2,
+            value_out_num_convs=1,
+            key_query_norm=True,
+            value_out_norm=True,
+            matmul_norm=True,
+            with_out=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+        )
+        if concat_input:
+            self.bottleneck = ConvModule(
+                feats_channels * 2,
+                feats_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+            )
+
+    def forward(self, x):
+        # Broadcast the globally pooled feature back to the input resolution
+        # and use it, concatenated with `x`, as the key/value of the
+        # self-attention; `x` itself is the query.
+        x_global = self.global_avgpool(x)
+        x_global = resize(
+            x_global,
+            size=x.shape[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        feats_il = self.correlate_net(x, torch.cat([x_global, x], dim=1))
+        if hasattr(self, 'bottleneck'):
+            feats_il = self.bottleneck(torch.cat([x, feats_il], dim=1))
+        return feats_il
+ """ + + def __init__(self, + feats_channels, + transform_channels, + concat_input=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None): + super().__init__() + self.correlate_net = SelfAttentionBlock( + key_in_channels=feats_channels, + query_in_channels=feats_channels, + channels=transform_channels, + out_channels=feats_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=2, + value_out_num_convs=1, + key_query_norm=True, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + if concat_input: + self.bottleneck = ConvModule( + feats_channels * 2, + feats_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + '''forward''' + + def forward(self, x, preds, feats_il): + inputs = x + batch_size, num_channels, h, w = x.size() + num_classes = preds.size(1) + feats_sl = torch.zeros(batch_size, h * w, num_channels).type_as(x) + for batch_idx in range(batch_size): + # (C, H, W), (num_classes, H, W) --> (H*W, C), (H*W, num_classes) + feats_iter, preds_iter = x[batch_idx], preds[batch_idx] + feats_iter, preds_iter = feats_iter.reshape( + num_channels, -1), preds_iter.reshape(num_classes, -1) + feats_iter, preds_iter = feats_iter.permute(1, + 0), preds_iter.permute( + 1, 0) + # (H*W, ) + argmax = preds_iter.argmax(1) + for clsid in range(num_classes): + mask = (argmax == clsid) + if mask.sum() == 0: + continue + feats_iter_cls = feats_iter[mask] + preds_iter_cls = preds_iter[:, clsid][mask] + weight = torch.softmax(preds_iter_cls, dim=0) + feats_iter_cls = feats_iter_cls * weight.unsqueeze(-1) + feats_iter_cls = feats_iter_cls.sum(0) + feats_sl[batch_idx][mask] = feats_iter_cls + feats_sl = feats_sl.reshape(batch_size, h, w, num_channels) + feats_sl = feats_sl.permute(0, 3, 1, 2).contiguous() + feats_sl = self.correlate_net(inputs, feats_sl) + if hasattr(self, 'bottleneck'): + feats_sl = self.bottleneck(torch.cat([feats_il, feats_sl], dim=1)) + return feats_sl + + +@MODELS.register_module() +class ISNetHead(BaseDecodeHead): + """ISNet: Integrate Image-Level and Semantic-Level + Context for Semantic Segmentation + + This head is the implementation of `ISNet` + `_. + + Args: + transform_channels (int): Output channels of key/query transform. + concat_input (bool): whether to concat input feature. + with_shortcut (bool): whether to use shortcut connection. + shortcut_in_channels (int): Input channels of shortcut. + shortcut_feat_channels (int): Output channels of shortcut. + dropout_ratio (float): Ratio of dropout. 
+ """ + + def __init__(self, transform_channels, concat_input, with_shortcut, + shortcut_in_channels, shortcut_feat_channels, dropout_ratio, + **kwargs): + super().__init__(**kwargs) + + self.in_channels = self.in_channels[-1] + + self.bottleneck = ConvModule( + self.in_channels, + self.channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.ilc_net = ImageLevelContext( + feats_channels=self.channels, + transform_channels=transform_channels, + concat_input=concat_input, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.slc_net = SemanticLevelContext( + feats_channels=self.channels, + transform_channels=transform_channels, + concat_input=concat_input, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.decoder_stage1 = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Dropout2d(dropout_ratio), + nn.Conv2d( + self.channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True), + ) + + if with_shortcut: + self.shortcut = ConvModule( + shortcut_in_channels, + shortcut_feat_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.decoder_stage2 = nn.Sequential( + ConvModule( + self.channels + shortcut_feat_channels, + self.channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Dropout2d(dropout_ratio), + nn.Conv2d( + self.channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True), + ) + else: + self.decoder_stage2 = nn.Sequential( + nn.Dropout2d(dropout_ratio), + nn.Conv2d( + self.channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True), + ) + + self.conv_seg = None + self.dropout = None + + def forward(self, inputs): + x = self._transform_inputs(inputs) + feats = self.bottleneck(x[-1]) + + feats_il = self.ilc_net(feats) + + preds_stage1 = self.decoder_stage1(feats) + preds_stage1 = resize( + preds_stage1, + size=feats.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + + feats_sl = self.slc_net(feats, preds_stage1, feats_il) + + if hasattr(self, 'shortcut'): + shortcut_out = self.shortcut(x[0]) + feats_sl = resize( + feats_sl, + size=shortcut_out.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + feats_sl = torch.cat([feats_sl, shortcut_out], dim=1) + preds_stage2 = self.decoder_stage2(feats_sl) + + return preds_stage1, preds_stage2 + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logits[-1], seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + for seg_logit, loss_decode in zip(seg_logits, self.loss_decode): + seg_logit = resize( + input=seg_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + loss[loss_decode.name] = loss_decode( + seg_logit, + seg_label, + seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logits[-1], seg_label, ignore_index=self.ignore_index) + return loss + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + _, seg_logits_stage2 = 
+
+    def predict_by_feat(self, seg_logits: Tuple[Tensor, Tensor],
+                        batch_img_metas: List[dict]) -> Tensor:
+        # Only the final (stage-2) logits are used at inference time.
+        _, seg_logits_stage2 = seg_logits
+        return super().predict_by_feat(seg_logits_stage2, batch_img_metas)
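To make the head's two-stage output contract concrete, i.e. the `(preds_stage1, preds_stage2)` tuple consumed by `loss_by_feat` and `predict_by_feat`, here is a standalone sketch that feeds random features shaped like the config's dilated ResNet-50 outputs through the head. It is illustrative only: plain `BN` stands in for `SyncBN` so it runs on a single device, the input size is arbitrary, and the repository root is assumed to be on `PYTHONPATH`.

```python
# Illustrative sketch (not part of the patch): feed dummy backbone features
# through ISNetHead and inspect the stage-1/stage-2 logits.
import torch

from projects.isnet.decode_heads import ISNetHead

head = ISNetHead(
    in_channels=(256, 512, 1024, 2048),
    input_transform='multiple_select',
    in_index=(0, 1, 2, 3),
    channels=512,
    dropout_ratio=0.1,
    transform_channels=256,
    concat_input=True,
    with_shortcut=False,
    shortcut_in_channels=256,
    shortcut_feat_channels=48,
    num_classes=19,
    norm_cfg=dict(type='BN', requires_grad=True),  # BN instead of SyncBN
    align_corners=False,
    loss_decode=[
        dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
             loss_name='loss_o'),
        dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4,
             loss_name='loss_d'),
    ])

# Feature pyramid for a 2 x 3 x 256 x 512 input under the config's
# strides=(1, 2, 1, 1) / dilations=(1, 1, 2, 4): output strides 4, 8, 8, 8.
feats = [
    torch.randn(2, 256, 64, 128),
    torch.randn(2, 512, 32, 64),
    torch.randn(2, 1024, 32, 64),
    torch.randn(2, 2048, 32, 64),
]
preds_stage1, preds_stage2 = head(feats)
print(preds_stage1.shape, preds_stage2.shape)  # both: (2, 19, 32, 64)
```

With `with_shortcut=False`, both outputs stay at the 1/8 feature resolution; the `EncoderDecoder` wrapper upsamples the stage-2 logits to the input size during inference.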