mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[New model] Support CGNet (#223)
* added cgnet * added testing for cgnet * git test * add cgnet * fix __init__ * rename FGlo with GlobalContextExtractor * add readme.md and rename bn with norm * delete cg_head * fix a language mistake * rename cgnet_m3n21.py to cgnet.py * modify README.md * modify list to tuple * add fcn_head test * add assert to fcn_head * blank * fix fcn_head assert bug * add * add cgnet to README.md and model_zoo.md * modify cgnet README.md Co-authored-by: KID <wps_@mail.nankai.edu.cn>
This commit is contained in:
parent
294a1f377a
commit
86d473002f
@ -81,6 +81,7 @@ Supported methods:
|
||||
- [x] [PointRend](configs/point_rend)
|
||||
- [x] [EMANet](configs/emanet)
|
||||
- [x] [DNLNet](configs/dnlnet)
|
||||
- [x] [CGNet](configs/cgnet)
|
||||
- [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)
|
||||
|
||||
## Installation
|
||||
|
35
configs/_base_/models/cgnet.py
Normal file
35
configs/_base_/models/cgnet.py
Normal file
@ -0,0 +1,35 @@
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
backbone=dict(
|
||||
type='CGNet',
|
||||
norm_cfg=norm_cfg,
|
||||
in_channels=3,
|
||||
num_channels=(32, 64, 128),
|
||||
num_blocks=(3, 21),
|
||||
dilations=(2, 4),
|
||||
reductions=(8, 16)),
|
||||
decode_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=256,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=0,
|
||||
concat_input=False,
|
||||
dropout_ratio=0,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0,
|
||||
class_weight=[
|
||||
2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
|
||||
10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
|
||||
10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
|
||||
10.396974, 10.055647
|
||||
])))
|
||||
# model training and testing settings
|
||||
train_cfg = dict(sampler=None)
|
||||
test_cfg = dict(mode='whole')
|
21
configs/cgnet/README.md
Normal file
21
configs/cgnet/README.md
Normal file
@ -0,0 +1,21 @@
|
||||
# CGNet: A Light-weight Context Guided Network for Semantic Segmentation
|
||||
|
||||
## Introduction
|
||||
|
||||
```latext
|
||||
@article{wu2018cgnet,
|
||||
title={CGNet: A Light-weight Context Guided Network for Semantic Segmentation},
|
||||
author={Wu, Tianyi and Tang, Sheng and Zhang, Rui and Zhang, Yongdong},
|
||||
journal={arXiv preprint arXiv:1811.08201},
|
||||
year={2018}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|-----------|----------|-----------|--------:|----------|----------------|------:|--------------:|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| CGNet | M3N21 | 680x680 | 60000 | 7.5 | 30.51 | 65.63 | 68.04 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/cgnet/cgnet_680x680_60k_cityscapes/cgnet_680x680_60k_cityscapes_20201101_110253-4c0b2f2d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/cgnet/cgnet_680x680_60k_cityscapes/cgnet_680x680_60k_cityscapes-20201101_110253.log.json) |
|
||||
| CGNet | M3N21 | 512x1024 | 60000 | 8.3 | 31.14 | 68.27 | 70.33 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/cgnet/cgnet_512x1024_60k_cityscapes/cgnet_512x1024_60k_cityscapes_20201101_110254-124ea03b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/cgnet/cgnet_512x1024_60k_cityscapes/cgnet_512x1024_60k_cityscapes-20201101_110254.log.json) |
|
66
configs/cgnet/cgnet_512x1024_60k_cityscapes.py
Normal file
66
configs/cgnet/cgnet_512x1024_60k_cityscapes.py
Normal file
@ -0,0 +1,66 @@
|
||||
_base_ = ['../_base_/models/cgnet.py', '../_base_/default_runtime.py']
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='Adam', lr=0.001, eps=1e-08, weight_decay=0.0005)
|
||||
optimizer_config = dict()
|
||||
# learning policy
|
||||
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
||||
# runtime settings
|
||||
total_iters = 60000
|
||||
checkpoint_config = dict(by_epoch=False, interval=4000)
|
||||
evaluation = dict(interval=4000, metric='mIoU')
|
||||
|
||||
# dataset settings
|
||||
dataset_type = 'CityscapesDataset'
|
||||
data_root = 'data/cityscapes/'
|
||||
img_norm_cfg = dict(
|
||||
mean=[72.39239876, 82.90891754, 73.15835921], std=[1, 1, 1], to_rgb=True)
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 1024),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=8,
|
||||
workers_per_gpu=8,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='leftImg8bit/train',
|
||||
ann_dir='gtFine/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='leftImg8bit/val',
|
||||
ann_dir='gtFine/val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='leftImg8bit/val',
|
||||
ann_dir='gtFine/val',
|
||||
pipeline=test_pipeline))
|
50
configs/cgnet/cgnet_680x680_60k_cityscapes.py
Normal file
50
configs/cgnet/cgnet_680x680_60k_cityscapes.py
Normal file
@ -0,0 +1,50 @@
|
||||
_base_ = [
|
||||
'../_base_/models/cgnet.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='Adam', lr=0.001, eps=1e-08, weight_decay=0.0005)
|
||||
optimizer_config = dict()
|
||||
# learning policy
|
||||
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
||||
# runtime settings
|
||||
total_iters = 60000
|
||||
checkpoint_config = dict(by_epoch=False, interval=4000)
|
||||
evaluation = dict(interval=4000, metric='mIoU')
|
||||
|
||||
img_norm_cfg = dict(
|
||||
mean=[72.39239876, 82.90891754, 73.15835921], std=[1, 1, 1], to_rgb=True)
|
||||
crop_size = (680, 680)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size),
|
||||
dict(type='RandomFlip', flip_ratio=0.5),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 1024),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=8,
|
||||
workers_per_gpu=8,
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline),
|
||||
test=dict(pipeline=test_pipeline))
|
@ -111,6 +111,10 @@ Please refer to [EMANet](https://github.com/open-mmlab/mmsegmentation/blob/maste
|
||||
|
||||
Please refer to [DNLNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dnlnet) for details.
|
||||
|
||||
### CGNet
|
||||
|
||||
Please refer to [CGNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/cgnet) for details.
|
||||
|
||||
### Mixed Precision (FP16) Training
|
||||
|
||||
Please refer [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fp16/README.md) for details.
|
||||
|
@ -1,3 +1,4 @@
|
||||
from .cgnet import CGNet
|
||||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
@ -8,5 +9,5 @@ from .unet import UNet
|
||||
|
||||
__all__ = [
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet'
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet'
|
||||
]
|
||||
|
367
mmseg/models/backbones/cgnet.py
Normal file
367
mmseg/models/backbones/cgnet.py
Normal file
@ -0,0 +1,367 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
|
||||
constant_init, kaiming_init)
|
||||
from mmcv.runner import load_checkpoint
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
|
||||
|
||||
class GlobalContextExtractor(nn.Module):
|
||||
"""Global Context Extractor for CGNet.
|
||||
|
||||
This class is employed to refine the joFint feature of both local feature
|
||||
and surrounding context.
|
||||
|
||||
Args:
|
||||
channel (int): Number of input feature channels.
|
||||
reduction (int): Reductions for global context extractor. Default: 16.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, channel, reduction=16, with_cp=False):
|
||||
super(GlobalContextExtractor, self).__init__()
|
||||
self.channel = channel
|
||||
self.reduction = reduction
|
||||
assert reduction >= 1 and channel >= reduction
|
||||
self.with_cp = with_cp
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
num_batch, num_channel = x.size()[:2]
|
||||
y = self.avg_pool(x).view(num_batch, num_channel)
|
||||
y = self.fc(y).view(num_batch, num_channel, 1, 1)
|
||||
return x * y
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ContextGuidedBlock(nn.Module):
|
||||
"""Context Guided Block for CGNet.
|
||||
|
||||
This class consists of four components: local feature extractor,
|
||||
surrounding feature extractor, joint feature extractor and global
|
||||
context extractor.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input feature channels.
|
||||
out_channels (int): Number of output feature channels.
|
||||
dilation (int): Dilation rate for surrounding context extractor.
|
||||
Default: 2.
|
||||
reduction (int): Reduction for global context extractor. Default: 16.
|
||||
skip_connect (bool): Add input to output or not. Default: True.
|
||||
downsample (bool): Downsample the input to 1/2 or not. Default: False.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='PReLU').
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dilation=2,
|
||||
reduction=16,
|
||||
skip_connect=True,
|
||||
downsample=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='PReLU'),
|
||||
with_cp=False):
|
||||
super(ContextGuidedBlock, self).__init__()
|
||||
self.with_cp = with_cp
|
||||
self.downsample = downsample
|
||||
|
||||
channels = out_channels if downsample else out_channels // 2
|
||||
if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
|
||||
act_cfg['num_parameters'] = channels
|
||||
kernel_size = 3 if downsample else 1
|
||||
stride = 2 if downsample else 1
|
||||
padding = (kernel_size - 1) // 2
|
||||
|
||||
self.conv1x1 = ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.f_loc = build_conv_layer(
|
||||
conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=channels,
|
||||
bias=False)
|
||||
self.f_sur = build_conv_layer(
|
||||
conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=dilation,
|
||||
groups=channels,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
|
||||
self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
|
||||
self.activate = nn.PReLU(2 * channels)
|
||||
|
||||
if downsample:
|
||||
self.bottleneck = build_conv_layer(
|
||||
conv_cfg,
|
||||
2 * channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
|
||||
self.skip_connect = skip_connect and not downsample
|
||||
self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
out = self.conv1x1(x)
|
||||
loc = self.f_loc(out)
|
||||
sur = self.f_sur(out)
|
||||
|
||||
joi_feat = torch.cat([loc, sur], 1) # the joint feature
|
||||
joi_feat = self.bn(joi_feat)
|
||||
joi_feat = self.activate(joi_feat)
|
||||
if self.downsample:
|
||||
joi_feat = self.bottleneck(joi_feat) # channel = out_channels
|
||||
# f_glo is employed to refine the joint feature
|
||||
out = self.f_glo(joi_feat)
|
||||
|
||||
if self.skip_connect:
|
||||
return x + out
|
||||
else:
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class InputInjection(nn.Module):
|
||||
"""Downsampling module for CGNet."""
|
||||
|
||||
def __init__(self, num_downsampling):
|
||||
super(InputInjection, self).__init__()
|
||||
self.pool = nn.ModuleList()
|
||||
for i in range(num_downsampling):
|
||||
self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
for pool in self.pool:
|
||||
x = pool(x)
|
||||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class CGNet(nn.Module):
|
||||
"""CGNet backbone.
|
||||
|
||||
A Light-weight Context Guided Network for Semantic Segmentation
|
||||
arXiv: https://arxiv.org/abs/1811.08201
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
num_channels (tuple[int]): Numbers of feature channels at each stages.
|
||||
Default: (32, 64, 128).
|
||||
num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
|
||||
Default: (3, 21).
|
||||
dilations (tuple[int]): Dilation rate for surrounding context
|
||||
extractors at stage 1 and stage 2. Default: (2, 4).
|
||||
reductions (tuple[int]): Reductions for global context extractors at
|
||||
stage 1 and stage 2. Default: (8, 16).
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='PReLU').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
num_channels=(32, 64, 128),
|
||||
num_blocks=(3, 21),
|
||||
dilations=(2, 4),
|
||||
reductions=(8, 16),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='PReLU'),
|
||||
norm_eval=False,
|
||||
with_cp=False):
|
||||
|
||||
super(CGNet, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.num_channels = num_channels
|
||||
assert isinstance(self.num_channels, tuple) and len(
|
||||
self.num_channels) == 3
|
||||
self.num_blocks = num_blocks
|
||||
assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
|
||||
self.dilations = dilations
|
||||
assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
|
||||
self.reductions = reductions
|
||||
assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
|
||||
self.act_cfg['num_parameters'] = num_channels[0]
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
cur_channels = in_channels
|
||||
self.stem = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.stem.append(
|
||||
ConvModule(
|
||||
cur_channels,
|
||||
num_channels[0],
|
||||
3,
|
||||
2 if i == 0 else 1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
cur_channels = num_channels[0]
|
||||
|
||||
self.inject_2x = InputInjection(1) # down-sample for Input, factor=2
|
||||
self.inject_4x = InputInjection(2) # down-sample for Input, factor=4
|
||||
|
||||
cur_channels += in_channels
|
||||
self.norm_prelu_0 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
# stage 1
|
||||
self.level1 = nn.ModuleList()
|
||||
for i in range(num_blocks[0]):
|
||||
self.level1.append(
|
||||
ContextGuidedBlock(
|
||||
cur_channels if i == 0 else num_channels[1],
|
||||
num_channels[1],
|
||||
dilations[0],
|
||||
reductions[0],
|
||||
downsample=(i == 0),
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
with_cp=with_cp)) # CG block
|
||||
|
||||
cur_channels = 2 * num_channels[1] + in_channels
|
||||
self.norm_prelu_1 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
# stage 2
|
||||
self.level2 = nn.ModuleList()
|
||||
for i in range(num_blocks[1]):
|
||||
self.level2.append(
|
||||
ContextGuidedBlock(
|
||||
cur_channels if i == 0 else num_channels[2],
|
||||
num_channels[2],
|
||||
dilations[1],
|
||||
reductions[1],
|
||||
downsample=(i == 0),
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
with_cp=with_cp)) # CG block
|
||||
|
||||
cur_channels = 2 * num_channels[2]
|
||||
self.norm_prelu_2 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
def forward(self, x):
|
||||
output = []
|
||||
|
||||
# stage 0
|
||||
inp_2x = self.inject_2x(x)
|
||||
inp_4x = self.inject_4x(x)
|
||||
for layer in self.stem:
|
||||
x = layer(x)
|
||||
x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
|
||||
output.append(x)
|
||||
|
||||
# stage 1
|
||||
for i, layer in enumerate(self.level1):
|
||||
x = layer(x)
|
||||
if i == 0:
|
||||
down1 = x
|
||||
x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
|
||||
output.append(x)
|
||||
|
||||
# stage 2
|
||||
for i, layer in enumerate(self.level2):
|
||||
x = layer(x)
|
||||
if i == 0:
|
||||
down2 = x
|
||||
x = self.norm_prelu_2(torch.cat([down2, x], 1))
|
||||
output.append(x)
|
||||
|
||||
return output
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Args:
|
||||
pretrained (str, optional): Path to pre-trained weights.
|
||||
Defaults to None.
|
||||
"""
|
||||
if isinstance(pretrained, str):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
kaiming_init(m)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
elif isinstance(m, nn.PReLU):
|
||||
constant_init(m, 0)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode whill keeping the normalization
|
||||
layer freezed."""
|
||||
super(CGNet, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
@ -24,11 +24,14 @@ class FCNHead(BaseDecodeHead):
|
||||
kernel_size=3,
|
||||
concat_input=True,
|
||||
**kwargs):
|
||||
assert num_convs > 0
|
||||
assert num_convs >= 0
|
||||
self.num_convs = num_convs
|
||||
self.concat_input = concat_input
|
||||
self.kernel_size = kernel_size
|
||||
super(FCNHead, self).__init__(**kwargs)
|
||||
if num_convs == 0:
|
||||
assert self.in_channels == self.channels
|
||||
|
||||
convs = []
|
||||
convs.append(
|
||||
ConvModule(
|
||||
@ -49,7 +52,10 @@ class FCNHead(BaseDecodeHead):
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.convs = nn.Sequential(*convs)
|
||||
if num_convs == 0:
|
||||
self.convs = nn.Identity()
|
||||
else:
|
||||
self.convs = nn.Sequential(*convs)
|
||||
if self.concat_input:
|
||||
self.conv_cat = ConvModule(
|
||||
self.in_channels + self.channels,
|
||||
|
@ -4,8 +4,10 @@ from mmcv.ops import DeformConv2dPack
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from torch.nn.modules import AvgPool2d, GroupNorm
|
||||
|
||||
from mmseg.models.backbones import (FastSCNN, ResNeSt, ResNet, ResNetV1d,
|
||||
ResNeXt)
|
||||
from mmseg.models.backbones import (CGNet, FastSCNN, ResNeSt, ResNet,
|
||||
ResNetV1d, ResNeXt)
|
||||
from mmseg.models.backbones.cgnet import (ContextGuidedBlock,
|
||||
GlobalContextExtractor)
|
||||
from mmseg.models.backbones.resnest import Bottleneck as BottleneckS
|
||||
from mmseg.models.backbones.resnet import BasicBlock, Bottleneck
|
||||
from mmseg.models.backbones.resnext import Bottleneck as BottleneckX
|
||||
@ -729,3 +731,147 @@ def test_resnest_backbone():
|
||||
assert feat[1].shape == torch.Size([2, 512, 28, 28])
|
||||
assert feat[2].shape == torch.Size([2, 1024, 14, 14])
|
||||
assert feat[3].shape == torch.Size([2, 2048, 7, 7])
|
||||
|
||||
|
||||
def test_cgnet_GlobalContextExtractor():
|
||||
block = GlobalContextExtractor(16, 16, with_cp=True)
|
||||
x = torch.randn(2, 16, 64, 64, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([2, 16, 64, 64])
|
||||
|
||||
|
||||
def test_cgnet_context_guided_block():
|
||||
with pytest.raises(AssertionError):
|
||||
# cgnet ContextGuidedBlock GlobalContextExtractor channel and reduction
|
||||
# constraints.
|
||||
ContextGuidedBlock(8, 8)
|
||||
|
||||
# test cgnet ContextGuidedBlock with checkpoint forward
|
||||
block = ContextGuidedBlock(
|
||||
16, 16, act_cfg=dict(type='PReLU'), with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(2, 16, 64, 64, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([2, 16, 64, 64])
|
||||
|
||||
# test cgnet ContextGuidedBlock without checkpoint forward
|
||||
block = ContextGuidedBlock(32, 32)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(3, 32, 32, 32)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([3, 32, 32, 32])
|
||||
|
||||
# test cgnet ContextGuidedBlock with down sampling
|
||||
block = ContextGuidedBlock(32, 32, downsample=True)
|
||||
assert block.conv1x1.conv.in_channels == 32
|
||||
assert block.conv1x1.conv.out_channels == 32
|
||||
assert block.conv1x1.conv.kernel_size == (3, 3)
|
||||
assert block.conv1x1.conv.stride == (2, 2)
|
||||
assert block.conv1x1.conv.padding == (1, 1)
|
||||
|
||||
assert block.f_loc.in_channels == 32
|
||||
assert block.f_loc.out_channels == 32
|
||||
assert block.f_loc.kernel_size == (3, 3)
|
||||
assert block.f_loc.stride == (1, 1)
|
||||
assert block.f_loc.padding == (1, 1)
|
||||
assert block.f_loc.groups == 32
|
||||
assert block.f_loc.dilation == (1, 1)
|
||||
assert block.f_loc.bias is None
|
||||
|
||||
assert block.f_sur.in_channels == 32
|
||||
assert block.f_sur.out_channels == 32
|
||||
assert block.f_sur.kernel_size == (3, 3)
|
||||
assert block.f_sur.stride == (1, 1)
|
||||
assert block.f_sur.padding == (2, 2)
|
||||
assert block.f_sur.groups == 32
|
||||
assert block.f_sur.dilation == (2, 2)
|
||||
assert block.f_sur.bias is None
|
||||
|
||||
assert block.bottleneck.in_channels == 64
|
||||
assert block.bottleneck.out_channels == 32
|
||||
assert block.bottleneck.kernel_size == (1, 1)
|
||||
assert block.bottleneck.stride == (1, 1)
|
||||
assert block.bottleneck.bias is None
|
||||
|
||||
x = torch.randn(1, 32, 32, 32)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 16, 16])
|
||||
|
||||
# test cgnet ContextGuidedBlock without down sampling
|
||||
block = ContextGuidedBlock(32, 32, downsample=False)
|
||||
assert block.conv1x1.conv.in_channels == 32
|
||||
assert block.conv1x1.conv.out_channels == 16
|
||||
assert block.conv1x1.conv.kernel_size == (1, 1)
|
||||
assert block.conv1x1.conv.stride == (1, 1)
|
||||
assert block.conv1x1.conv.padding == (0, 0)
|
||||
|
||||
assert block.f_loc.in_channels == 16
|
||||
assert block.f_loc.out_channels == 16
|
||||
assert block.f_loc.kernel_size == (3, 3)
|
||||
assert block.f_loc.stride == (1, 1)
|
||||
assert block.f_loc.padding == (1, 1)
|
||||
assert block.f_loc.groups == 16
|
||||
assert block.f_loc.dilation == (1, 1)
|
||||
assert block.f_loc.bias is None
|
||||
|
||||
assert block.f_sur.in_channels == 16
|
||||
assert block.f_sur.out_channels == 16
|
||||
assert block.f_sur.kernel_size == (3, 3)
|
||||
assert block.f_sur.stride == (1, 1)
|
||||
assert block.f_sur.padding == (2, 2)
|
||||
assert block.f_sur.groups == 16
|
||||
assert block.f_sur.dilation == (2, 2)
|
||||
assert block.f_sur.bias is None
|
||||
|
||||
x = torch.randn(1, 32, 32, 32)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 32, 32])
|
||||
|
||||
|
||||
def test_cgnet_backbone():
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid num_channels
|
||||
CGNet(num_channels=(32, 64, 128, 256))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid num_blocks
|
||||
CGNet(num_blocks=(3, 21, 3))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid dilation
|
||||
CGNet(num_blocks=2)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid reduction
|
||||
CGNet(reductions=16)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# check invalid num_channels and reduction
|
||||
CGNet(num_channels=(32, 64, 128), reductions=(64, 129))
|
||||
|
||||
# Test CGNet with default settings
|
||||
model = CGNet()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([2, 35, 112, 112])
|
||||
assert feat[1].shape == torch.Size([2, 131, 56, 56])
|
||||
assert feat[2].shape == torch.Size([2, 256, 28, 28])
|
||||
|
||||
# Test CGNet with norm_eval True and with_cp True
|
||||
model = CGNet(norm_eval=True, with_cp=True)
|
||||
with pytest.raises(TypeError):
|
||||
# check invalid pretrained
|
||||
model.init_weights(pretrained=8)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(2, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size([2, 35, 112, 112])
|
||||
assert feat[1].shape == torch.Size([2, 131, 56, 56])
|
||||
assert feat[2].shape == torch.Size([2, 256, 28, 28])
|
||||
|
@ -105,8 +105,8 @@ def test_decode_head():
|
||||
def test_fcn_head():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# num_convs must be larger than 0
|
||||
FCNHead(num_classes=19, num_convs=0)
|
||||
# num_convs must be not less than 0
|
||||
FCNHead(num_classes=19, num_convs=-1)
|
||||
|
||||
# test no norm_cfg
|
||||
head = FCNHead(in_channels=32, channels=16, num_classes=19)
|
||||
@ -178,6 +178,20 @@ def test_fcn_head():
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
# test num_conv = 0
|
||||
inputs = [torch.randn(1, 32, 45, 45)]
|
||||
head = FCNHead(
|
||||
in_channels=32,
|
||||
channels=32,
|
||||
num_classes=19,
|
||||
num_convs=0,
|
||||
concat_input=False)
|
||||
if torch.cuda.is_available():
|
||||
head, inputs = to_cuda(head, inputs)
|
||||
assert isinstance(head.convs, torch.nn.Identity)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_psp_head():
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user