[Feature] Support PointRend (#109)

* [Feature] Support PointRend

* add previous test

* update modelzoo
This commit is contained in:
Jerry Jiarui XU 2020-09-07 19:59:44 +08:00 committed by GitHub
parent e807773a64
commit ff98229a3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 554 additions and 4 deletions

View File

@ -73,6 +73,7 @@ Supported methods:
- [x] [OCRNet](configs/ocrnet) - [x] [OCRNet](configs/ocrnet)
- [x] [Fast-SCNN](configs/fastscnn) - [x] [Fast-SCNN](configs/fastscnn)
- [x] [Semantic FPN](configs/sem_fpn) - [x] [Semantic FPN](configs/sem_fpn)
- [x] [PointRend](configs/point_rend)
- [x] [EMANet](configs/emanet) - [x] [EMANet](configs/emanet)
- [x] [DNLNet](configs/dnlnet) - [x] [DNLNet](configs/dnlnet)
- [x] [Mixed Precision (FP16) Training](configs/fp16/README.md) - [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)

View File

@ -0,0 +1,56 @@
# model settings
# PointRend base model: ResNet-50 (v1c) backbone and FPN neck feeding a
# two-stage cascade decoder. The first decode head (FPNHead) produces the
# coarse prediction; the second (PointHead) refines it point by point.
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='CascadeEncoderDecoder',
# one stage per entry in ``decode_head`` below
num_stages=2,
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
decode_head=[
# stage 1: coarse segmentation from all four FPN levels
dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=-1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# stage 2: point-wise refinement; only input index 0 is selected for the
# fine-grained point features
dict(
type='PointHead',
in_channels=[256],
in_index=[0],
channels=256,
num_fcs=3,
coarse_pred_each_layer=True,
dropout_ratio=-1,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
])
# model training and testing settings
# num_points: points per image used for the point loss.
# oversample_ratio: num_points * oversample_ratio random candidates are drawn.
# importance_sample_ratio: fraction of num_points chosen from the most
# uncertain candidates; the remainder is sampled uniformly at random.
train_cfg = dict(
num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75)
# subdivision_steps: number of upsample-and-refine iterations at test time.
# scale_factor: upsampling factor applied at every subdivision step.
# subdivision_num_points: upper bound on points refined per step.
# NOTE(review): 8196 looks like a typo for 8192 (2**13) - confirm before
# changing, since released checkpoints may have been evaluated with 8196.
test_cfg = dict(
mode='whole',
subdivision_steps=2,
subdivision_num_points=8196,
scale_factor=2)

View File

@ -0,0 +1,27 @@
# PointRend: Image Segmentation as Rendering
## Introduction
```
@misc{alex2019pointrend,
title={PointRend: Image Segmentation as Rendering},
author={Alexander Kirillov and Yuxin Wu and Kaiming He and Ross Girshick},
year={2019},
eprint={1912.08193},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Results and models
### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|-----------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| PointRend | R-50 | 512x1024 | 80000 | 3.1 | 8.48 | 76.47 | 78.13 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r50_512x1024_80k_cityscapes/pointrend_r50_512x1024_80k_cityscapes_20200711_015821-bb1ff523.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r50_512x1024_80k_cityscapes/pointrend_r50_512x1024_80k_cityscapes-20200715_214714.log.json) |
| PointRend | R-101 | 512x1024 | 80000 | 4.2 | 7.00 | 78.30 | 79.97 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r101_512x1024_80k_cityscapes/pointrend_r101_512x1024_80k_cityscapes_20200711_170850-d0ca84be.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r101_512x1024_80k_cityscapes/pointrend_r101_512x1024_80k_cityscapes-20200715_214824.log.json) |
### ADE20K
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|-----------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| PointRend | R-50 | 512x512 | 160000 | 5.1 | 17.31 | 37.64 | 39.17 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r50_512x512_160k_ade20k/pointrend_r50_512x512_160k_ade20k_20200807_232644-ac3febf2.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r50_512x512_160k_ade20k/pointrend_r50_512x512_160k_ade20k-20200807_232644.log.json) |
| PointRend | R-101 | 512x512 | 160000 | 6.1 | 15.50 | 40.02 | 41.60 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r101_512x512_160k_ade20k/pointrend_r101_512x512_160k_ade20k_20200808_030852-8834902a.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/point_rend/pointrend_r101_512x512_160k_ade20k/pointrend_r101_512x512_160k_ade20k-20200808_030852.log.json) |

View File

@ -0,0 +1,2 @@
# ResNet-101 variant: inherits the R-50 Cityscapes PointRend config and only
# overrides the pretrained checkpoint and the backbone depth.
_base_ = './pointrend_r50_512x1024_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

View File

@ -0,0 +1,2 @@
# ResNet-101 variant: inherits the R-50 ADE20K PointRend config and only
# overrides the pretrained checkpoint and the backbone depth.
_base_ = './pointrend_r50_512x512_160k_ade20k.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))

View File

@ -0,0 +1,5 @@
# PointRend R-50 on Cityscapes with the 80k-iteration schedule: composed from
# the base model, dataset, runtime and schedule configs listed below.
_base_ = [
'../_base_/models/pointrend_r50.py', '../_base_/datasets/cityscapes.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
# add a linear learning-rate warmup over the first 200 iterations
lr_config = dict(warmup='linear', warmup_iters=200)

View File

@ -0,0 +1,32 @@
# PointRend R-50 on ADE20K with the 160k-iteration schedule: composed from
# the base model, dataset, runtime and schedule configs listed below.
_base_ = [
'../_base_/models/pointrend_r50.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
# Both decode heads are spelled out in full with num_classes=150 (the base
# model config uses 19); every other setting repeats the base values.
# NOTE(review): presumably the whole list has to be restated because a list
# value replaces the inherited one instead of being merged - confirm.
model = dict(decode_head=[
dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=-1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='PointHead',
in_channels=[256],
in_index=[0],
channels=256,
num_fcs=3,
coarse_pred_each_layer=True,
dropout_ratio=-1,
num_classes=150,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
])
# add a linear learning-rate warmup over the first 200 iterations
lr_config = dict(warmup='linear', warmup_iters=200)

View File

@ -89,6 +89,22 @@ Please refer to [Fast-SCNN](https://github.com/open-mmlab/mmsegmentation/blob/ma
Please refer to [ResNeSt](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest) for details. Please refer to [ResNeSt](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/resnest) for details.
### Semantic FPN
Please refer to [Semantic FPN](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/sem_fpn) for details.
### PointRend
Please refer to [PointRend](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/point_rend) for details.
### EMANet
Please refer to [EMANet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/emanet) for details.
### DNLNet
Please refer to [DNLNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/dnlnet) for details.
### Mixed Precision (FP16) Training ### Mixed Precision (FP16) Training
Please refer [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fp16/README.md) for details. Please refer [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fp16/README.md) for details.

View File

@ -10,6 +10,7 @@ from .fpn_head import FPNHead
from .gc_head import GCHead from .gc_head import GCHead
from .nl_head import NLHead from .nl_head import NLHead
from .ocr_head import OCRHead from .ocr_head import OCRHead
from .point_head import PointHead
from .psa_head import PSAHead from .psa_head import PSAHead
from .psp_head import PSPHead from .psp_head import PSPHead
from .sep_aspp_head import DepthwiseSeparableASPPHead from .sep_aspp_head import DepthwiseSeparableASPPHead
@ -19,5 +20,6 @@ from .uper_head import UPerHead
__all__ = [ __all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead' 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead'
] ]

View File

@ -0,0 +1,349 @@
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmcv.ops import point_sample
from mmseg.models.builder import HEADS
from mmseg.ops import resize
from ..losses import accuracy
from .cascade_decode_head import BaseCascadeDecodeHead
def calculate_uncertainty(seg_logits):
    """Estimate uncertainty based on seg logits.

    For multi-class logits, the uncertainty of a location is the difference
    between its second-highest and its highest class logit. The score is
    therefore never positive and is closest to zero where the two best
    classes are nearly tied.

    For single-channel (class-agnostic) logits there is no second class to
    compare against, so the uncertainty is ``-abs(logit)``: a logit close to
    zero is the least confident prediction.

    Args:
        seg_logits (Tensor): Semantic segmentation logits with the class
            dimension at index 1, e.g. shape (batch_size, num_classes,
            height, width) for dense maps or (batch_size, num_classes,
            num_points) for sampled points.

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest score. Same shape as
            ``seg_logits`` except that the class dimension has size 1.
    """
    if seg_logits.shape[1] == 1:
        # topk(k=2) would raise on a single channel; use the distance of
        # the logit from the decision boundary instead.
        return -torch.abs(seg_logits)
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@HEADS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """A point head used in PointRend.

    ``PointHead`` uses a shared multi-layer perceptron (built from
    kernel-size-1 ``Conv1d`` layers) to predict the logits of sampled
    points. For every point, the fine-grained features and the coarse
    prediction of the previous decode head are concatenated and used for
    the prediction.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        coarse_pred_each_layer (bool): Whether to concatenate the coarse
            prediction with the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict|None): Dictionary to construct and config norm layer.
            Default: None.
        act_cfg (dict): Dictionary to construct and config the activation
            layer. Default: dict(type='ReLU', inplace=False).
        **kwargs: Remaining arguments, forwarded to
            ``BaseCascadeDecodeHead``. This head reads ``in_channels`` (as
            a sequence, it is summed), ``channels``, ``num_classes``,
            ``dropout_ratio``, ``align_corners``, ``ignore_index`` and
            ``loss_decode`` from ``self``; they are expected to be set by
            the base class.
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super(PointHead, self).__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        # The first fc layer sees the fine-grained features of all selected
        # inputs concatenated with the coarse per-class prediction.
        fc_in_channels = sum(self.in_channels) + self.num_classes
        fc_channels = self.channels
        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            # ``forward`` re-appends the coarse prediction after every fc
            # layer when ``coarse_pred_each_layer`` is set, so the next
            # layer's input grows by ``num_classes`` channels.
            fc_in_channels = fc_channels
            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
                else 0
        self.fc_seg = nn.Conv1d(
            fc_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        # NOTE(review): when dropout_ratio <= 0, ``cls_seg`` relies on the
        # base class having already defined ``self.dropout`` - confirm.
        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout(self.dropout_ratio)
        # ``fc_seg`` is the classifier of this head; ``conv_seg``
        # (presumably created by the base class) is not used.
        delattr(self, 'conv_seg')

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.fc_seg, std=0.001)

    def cls_seg(self, feat):
        """Classify each point with the final fc layer.

        Args:
            feat (Tensor): Point features produced by the fc stack.

        Returns:
            Tensor: Output of ``fc_seg``, with dropout applied to ``feat``
                first when a dropout module is configured.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.fc_seg(feat)
        return output

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        """Predict point logits from fine-grained and coarse features.

        Args:
            fine_grained_point_feats (Tensor): Sampled fine-grained
                features, shape (batch_size, sum(in_channels), num_points).
            coarse_point_feats (Tensor): Sampled coarse prediction, shape
                (batch_size, num_classes, num_points).

        Returns:
            Tensor: Point logits, shape (batch_size, num_classes,
                num_points).
        """
        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                x = torch.cat((x, coarse_point_feats), dim=1)
        return self.cls_seg(x)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine grained features.

        Args:
            x (list[Tensor]): Feature pyramid from neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """
        fine_grained_feats_list = [
            point_sample(feat, points, align_corners=self.align_corners)
            for feat in x
        ]
        if len(fine_grained_feats_list) > 1:
            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
        else:
            fine_grained_feats = fine_grained_feats_list[0]

        return fine_grained_feats

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from the coarse prediction of the previous decode head.

        Args:
            prev_output (Tensor): Prediction of previous decode head.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """
        coarse_feats = point_sample(
            prev_output, points, align_corners=self.align_corners)

        return coarse_feats

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this head.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self._transform_inputs(inputs)
        # Point selection is not differentiated through.
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, points)
        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_grained_point_feats,
                                    coarse_point_feats)
        # Labels are class indices, so they are sampled with nearest
        # interpolation rather than being blended.
        point_label = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = point_label.squeeze(1).long()

        losses = self.losses(point_logits, point_label)

        return losses

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        The coarse prediction is repeatedly upsampled; after each
        upsampling step the most uncertain locations are re-predicted by
        the point head and written back into the logits.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this head.
            test_cfg (dict): The testing config. ``subdivision_steps``,
                ``scale_factor`` and ``subdivision_num_points`` are read.

        Returns:
            Tensor: Output segmentation map.
        """
        x = self._transform_inputs(inputs)
        # Clone so that the in-place scatter below never touches
        # ``prev_output``.
        refined_seg_logits = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined_seg_logits = resize(
                refined_seg_logits,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined_seg_logits.shape
            point_indices, points = self.get_points_test(
                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, points)
            # Coarse features are always sampled from the original coarse
            # prediction, not from the progressively refined logits.
            coarse_point_feats = self._get_coarse_point_feats(
                prev_output, points)
            point_logits = self.forward(fine_grained_point_feats,
                                        coarse_point_feats)

            # Overwrite the logits of the selected locations, addressing
            # them by their flattened (row-major) index.
            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_seg_logits = refined_seg_logits.reshape(
                batch_size, channels, height * width)
            refined_seg_logits = refined_seg_logits.scatter_(
                2, point_indices, point_logits)
            refined_seg_logits = refined_seg_logits.view(
                batch_size, channels, height, width)

        return refined_seg_logits

    def losses(self, point_logits, point_label):
        """Compute segmentation loss.

        Args:
            point_logits (Tensor): Predicted point logits.
            point_label (Tensor): Ground-truth label of every point.

        Returns:
            dict[str, Tensor]: ``loss_point`` and ``acc_point``.
        """
        loss = dict()
        loss['loss_point'] = self.loss_decode(
            point_logits, point_label, ignore_index=self.ignore_index)
        loss['acc_point'] = accuracy(point_logits, point_label)
        return loss

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        'uncertainty_func' function that takes point's logit prediction as
        input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head. ``num_points``,
                ``oversample_ratio`` and ``importance_sample_ratio`` are
                read.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = seg_logits.shape[0]
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(
            batch_size, num_sampled, 2, device=seg_logits.device)
        point_logits = point_sample(seg_logits, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results. To illustrate this: assume uncertainty func(
        # logits)=-abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and sampled point will get -1 uncertainty.
        point_uncertainties = uncertainty_func(point_logits)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(
            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        # Turn per-sample indices into indices of the flattened
        # (batch_size * num_sampled) coordinate table.
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=seg_logits.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            rand_point_coords = torch.rand(
                batch_size, num_random_points, 2, device=seg_logits.device)
            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
        return point_coords

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find the ``num_points`` most uncertain points of the uncertainty
        map computed from ``seg_logits``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic
                prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.
                ``subdivision_num_points`` is read.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the ``height x width`` grid.
        """
        num_points = cfg.subdivision_num_points
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        h_step = 1.0 / height
        w_step = 1.0 / width

        uncertainty_map = uncertainty_map.view(batch_size, height * width)
        # Never ask for more points than there are grid locations.
        num_points = min(height * width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        # Convert flattened indices to normalized (x, y) cell centres.
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                width).float() * h_step
        return point_indices, point_coords

View File

@ -157,6 +157,11 @@ def test_sem_fpn_forward():
_test_encoder_decoder_forward('sem_fpn/fpn_r50_512x1024_80k_cityscapes.py') _test_encoder_decoder_forward('sem_fpn/fpn_r50_512x1024_80k_cityscapes.py')
def test_point_rend_forward():
    """Run the shared encoder-decoder forward check on the PointRend
    R-50 Cityscapes config."""
    config_path = 'point_rend/pointrend_r50_512x1024_80k_cityscapes.py'
    _test_encoder_decoder_forward(config_path)
def test_mobilenet_v2_forward(): def test_mobilenet_v2_forward():
_test_encoder_decoder_forward( _test_encoder_decoder_forward(
'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py') 'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')

View File

@ -3,13 +3,15 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.utils import ConfigDict
from mmcv.utils.parrots_wrapper import SyncBatchNorm from mmcv.utils.parrots_wrapper import SyncBatchNorm
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead, from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
DepthwiseSeparableASPPHead, DNLHead, DepthwiseSeparableASPPHead,
DepthwiseSeparableFCNHead, DNLHead,
EMAHead, EncHead, FCNHead, GCHead, EMAHead, EncHead, FCNHead, GCHead,
NLHead, OCRHead, PSAHead, PSPHead, NLHead, OCRHead, PointHead, PSAHead,
UPerHead) PSPHead, UPerHead)
from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.models.decode_heads.decode_head import BaseDecodeHead
@ -542,6 +544,40 @@ def test_dw_aspp_head():
assert outputs.shape == (1, head.num_classes, 45, 45) assert outputs.shape == (1, head.num_classes, 45, 45)
def test_sep_fcn_head():
    """Check ``DepthwiseSeparableFCNHead`` with and without input concat."""
    # Imported once for both cases below.
    from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule

    # test sep_fcn_head with concat_input=False
    head = DepthwiseSeparableFCNHead(
        in_channels=128,
        channels=128,
        concat_input=False,
        num_classes=19,
        in_index=-1,
        norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
    x = [torch.rand(2, 128, 32, 32)]
    output = head(x)
    assert output.shape == (2, head.num_classes, 32, 32)
    assert not head.concat_input
    assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
    assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
    assert head.conv_seg.kernel_size == (1, 1)

    # test sep_fcn_head with concat_input=True
    head = DepthwiseSeparableFCNHead(
        in_channels=64,
        channels=64,
        concat_input=True,
        num_classes=19,
        in_index=-1,
        norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
    x = [torch.rand(3, 64, 32, 32)]
    output = head(x)
    assert output.shape == (3, head.num_classes, 32, 32)
    assert head.concat_input
    assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
    assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
def test_dnl_head(): def test_dnl_head():
# DNL with 'embedded_gaussian' mode # DNL with 'embedded_gaussian' mode
head = DNLHead(in_channels=32, channels=16, num_classes=19) head = DNLHead(in_channels=32, channels=16, num_classes=19)
@ -598,3 +634,20 @@ def test_emanet_head():
head, inputs = to_cuda(head, inputs) head, inputs = to_cuda(head, inputs)
outputs = head(inputs) outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 45, 45) assert outputs.shape == (1, head.num_classes, 45, 45)
def test_point_head():
    """Check that ``PointHead.forward_test`` refines a coarse prediction.

    With two subdivision steps and a scale factor of 2, the 45x45 coarse
    prediction is expected to come back as a 180x180 map.
    """
    inputs = [torch.randn(1, 32, 45, 45)]
    point_head = PointHead(
        in_channels=[32], in_index=[0], channels=16, num_classes=19)
    assert len(point_head.fcs) == 3
    fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19)
    if torch.cuda.is_available():
        # Rebind the modules returned by ``to_cuda`` rather than relying
        # on the call moving them in place.
        point_head, inputs = to_cuda(point_head, inputs)
        fcn_head, inputs = to_cuda(fcn_head, inputs)
    prev_output = fcn_head(inputs)
    test_cfg = ConfigDict(
        subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
    output = point_head.forward_test(inputs, prev_output, None, test_cfg)
    assert output.shape == (1, point_head.num_classes, 180, 180)