Support APCNet (#299)

* Support APCNet
* code optimization
* add apcnet configs
* add benchmark
* add readme and model zoo
* fix doc

parent 5c6e65759f
commit e3f6f655d6
@@ -73,6 +73,7 @@ Supported methods:

- [x] [EncNet](configs/encnet)
- [x] [CCNet](configs/ccnet)
- [x] [DANet](configs/danet)
- [x] [APCNet](configs/apcnet)
- [x] [GCNet](configs/gcnet)
- [x] [ANN](configs/ann)
- [x] [OCRNet](configs/ocrnet)
@@ -0,0 +1,44 @@

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='APCHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')
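As a quick sanity check of the base model config above, the segmentor can be built and exercised with MMSegmentation's builder. This is a minimal sketch, not part of this PR: the config path matches the file added above, the ImageNet-pretrained weights are skipped, SyncBN is swapped for plain BN so it runs in a single non-distributed process, and the input is a dummy tensor.

```python
# Minimal sketch (not part of this PR): build the APCNet segmentor from the
# base config above and run a dummy whole-image inference pass.
import torch
from mmcv import Config
from mmseg.models import build_segmentor

cfg = Config.fromfile('configs/_base_/models/apcnet_r50-d8.py')
cfg.model.pretrained = None  # skip downloading ImageNet weights for this check
# SyncBN needs a distributed process group; use plain BN for a local check.
for key in ('backbone', 'decode_head', 'auxiliary_head'):
    cfg.model[key].norm_cfg = dict(type='BN', requires_grad=True)

model = build_segmentor(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
model.eval()

img = torch.randn(1, 3, 512, 1024)  # dummy Cityscapes-sized input
img_meta = [dict(
    ori_shape=(512, 1024, 3),
    img_shape=(512, 1024, 3),
    pad_shape=(512, 1024, 3),
    scale_factor=1.0,
    flip=False)]
with torch.no_grad():
    seg = model.simple_test(img, img_meta, rescale=False)
print(seg[0].shape)  # (512, 1024) array of per-pixel class indices
```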
@@ -0,0 +1,37 @@

# Adaptive Pyramid Context Network for Semantic Segmentation

## Introduction

```latex
@InProceedings{He_2019_CVPR,
    author = {He, Junjun and Deng, Zhongying and Zhou, Lei and Wang, Yali and Qiao, Yu},
    title = {Adaptive Pyramid Context Network for Semantic Segmentation},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2019}
}
```

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|----------|----------------|-----:|--------------:|----------|
| APCNet | R-50-D8  | 512x1024  | 40000   | 7.7      | 3.57           | 78.02 | 79.26 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes/apcnet_r50-d8_512x1024_40k_cityscapes_20201214_115717-5e88fa33.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes/apcnet_r50-d8_512x1024_40k_cityscapes-20201214_115717.log.json) |
| APCNet | R-101-D8 | 512x1024  | 40000   | 11.2     | 2.15           | 79.08 | 80.34 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes/apcnet_r101-d8_512x1024_40k_cityscapes_20201214_115716-abc9d111.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes/apcnet_r101-d8_512x1024_40k_cityscapes-20201214_115716.log.json) |
| APCNet | R-50-D8  | 769x769   | 40000   | 8.7      | 1.52           | 77.89 | 79.75 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_40k_cityscapes/apcnet_r50-d8_769x769_40k_cityscapes_20201214_115717-2a2628d7.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_40k_cityscapes/apcnet_r50-d8_769x769_40k_cityscapes-20201214_115717.log.json) |
| APCNet | R-101-D8 | 769x769   | 40000   | 12.7     | 1.03           | 77.96 | 79.24 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_40k_cityscapes/apcnet_r101-d8_769x769_40k_cityscapes_20201214_115718-b650de90.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_40k_cityscapes/apcnet_r101-d8_769x769_40k_cityscapes-20201214_115718.log.json) |
| APCNet | R-50-D8  | 512x1024  | 80000   | -        | -              | 78.96 | 79.94 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes/apcnet_r50-d8_512x1024_80k_cityscapes_20201214_115716-987f51e3.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x1024_80k_cityscapes/apcnet_r50-d8_512x1024_80k_cityscapes-20201214_115716.log.json) |
| APCNet | R-101-D8 | 512x1024  | 80000   | -        | -              | 79.64 | 80.61 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes/apcnet_r101-d8_512x1024_80k_cityscapes_20201214_115705-b1ff208a.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x1024_80k_cityscapes/apcnet_r101-d8_512x1024_80k_cityscapes-20201214_115705.log.json) |
| APCNet | R-50-D8  | 769x769   | 80000   | -        | -              | 78.79 | 80.35 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_80k_cityscapes/apcnet_r50-d8_769x769_80k_cityscapes_20201214_115718-7ea9fa12.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_769x769_80k_cityscapes/apcnet_r50-d8_769x769_80k_cityscapes-20201214_115718.log.json) |
| APCNet | R-101-D8 | 769x769   | 80000   | -        | -              | 78.45 | 79.91 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_80k_cityscapes/apcnet_r101-d8_769x769_80k_cityscapes_20201214_115716-a7fbc2ab.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_769x769_80k_cityscapes/apcnet_r101-d8_769x769_80k_cityscapes-20201214_115716.log.json) |

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|----------|----------------|-----:|--------------:|----------|
| APCNet | R-50-D8  | 512x512   | 80000   | 10.1     | 19.61          | 42.20 | 43.30 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_80k_ade20k/apcnet_r50-d8_512x512_80k_ade20k_20201214_115705-a8626293.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_80k_ade20k/apcnet_r50-d8_512x512_80k_ade20k-20201214_115705.log.json) |
| APCNet | R-101-D8 | 512x512   | 80000   | 13.6     | 13.10          | 45.54 | 46.65 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_80k_ade20k/apcnet_r101-d8_512x512_80k_ade20k_20201214_115704-c656c3fb.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_80k_ade20k/apcnet_r101-d8_512x512_80k_ade20k-20201214_115704.log.json) |
| APCNet | R-50-D8  | 512x512   | 160000  | -        | -              | 43.40 | 43.94 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_160k_ade20k/apcnet_r50-d8_512x512_160k_ade20k_20201214_115706-25fb92c2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r50-d8_512x512_160k_ade20k/apcnet_r50-d8_512x512_160k_ade20k-20201214_115706.log.json) |
| APCNet | R-101-D8 | 512x512   | 160000  | -        | -              | 45.41 | 46.63 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_160k_ade20k/apcnet_r101-d8_512x512_160k_ade20k_20201214_115705-73f9a8d7.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/apcnet_r101-d8_512x512_160k_ade20k/apcnet_r101-d8_512x512_160k_ade20k-20201214_115705.log.json) |
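For reference, any checkpoint in the tables above can be used for single-image inference through the high-level API. A minimal sketch, not part of this PR; it assumes a CUDA device and uses the demo image shipped with the repository (any RGB image path works).

```python
# Minimal sketch (not part of this PR): single-image inference with one of the
# APCNet checkpoints listed above. Assumes a CUDA device is available.
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes.py'
checkpoint = ('https://download.openmmlab.com/mmsegmentation/v0.5/apcnet/'
              'apcnet_r50-d8_512x1024_40k_cityscapes/'
              'apcnet_r50-d8_512x1024_40k_cityscapes_20201214_115717-5e88fa33.pth')

# init_segmentor accepts either a local checkpoint path or a download URL.
model = init_segmentor(config_file, checkpoint, device='cuda:0')
result = inference_segmentor(model, 'demo/demo.png')
print(result[0].shape)  # per-pixel class indices at the original image size
```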
@@ -0,0 +1,2 @@

_base_ = './apcnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
@@ -0,0 +1,2 @@

_base_ = './apcnet_r50-d8_512x1024_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
@@ -0,0 +1,2 @@

_base_ = './apcnet_r50-d8_512x512_160k_ade20k.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
@@ -0,0 +1,2 @@

_base_ = './apcnet_r50-d8_512x512_80k_ade20k.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
@@ -0,0 +1,2 @@

_base_ = './apcnet_r50-d8_769x769_40k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
@@ -0,0 +1,2 @@

_base_ = './apcnet_r50-d8_769x769_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
@@ -0,0 +1,4 @@

_base_ = [
    '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]
@@ -0,0 +1,4 @@

_base_ = [
    '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/cityscapes.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
@@ -0,0 +1,7 @@

_base_ = [
    '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
model = dict(
    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
test_cfg = dict(mode='whole')
@@ -0,0 +1,7 @@

_base_ = [
    '../_base_/models/apcnet_r50-d8.py', '../_base_/datasets/ade20k.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))
test_cfg = dict(mode='whole')
@@ -0,0 +1,9 @@

_base_ = [
    '../_base_/models/apcnet_r50-d8.py',
    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40k.py'
]
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True))
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
@@ -0,0 +1,9 @@

_base_ = [
    '../_base_/models/apcnet_r50-d8.py',
    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py'
]
model = dict(
    decode_head=dict(align_corners=True),
    auxiliary_head=dict(align_corners=True))
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
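Because each config above is assembled from `_base_` files, the effective settings of the two-line R-101 variants can be inspected by letting mmcv resolve the inheritance chain. A minimal sketch, not part of this PR; the file name follows the checkpoint names in the tables above.

```python
# Minimal sketch (not part of this PR): inspect how the two-line R-101 config
# inherits and overrides the R-50 base config via the _base_ mechanism.
from mmcv import Config

cfg = Config.fromfile('configs/apcnet/apcnet_r101-d8_512x1024_40k_cityscapes.py')
print(cfg.model.backbone.depth)    # 101, overriding the base value of 50
print(cfg.model.pretrained)        # open-mmlab://resnet101_v1c
print(cfg.model.decode_head.type)  # APCHead, inherited from apcnet_r50-d8.py
```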
@@ -67,6 +67,10 @@ Please refer to [CCNet](https://github.com/open-mmlab/mmsegmentation/blob/master

Please refer to [DANet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/danet) for details.

### APCNet

Please refer to [APCNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/apcnet) for details.

### HRNet

Please refer to [HRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet) for details.
@@ -1,4 +1,5 @@

 from .ann_head import ANNHead
+from .apc_head import APCHead
 from .aspp_head import ASPPHead
 from .cc_head import CCHead
 from .da_head import DAHead

@@ -21,5 +22,5 @@ __all__ = [

     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
     'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
-    'PointHead'
+    'PointHead', 'APCHead'
 ]
@@ -0,0 +1,158 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


class ACM(nn.Module):
    """Adaptive Context Module used in APCNet.

    Args:
        pool_scale (int): Pooling scale used in Adaptive Context
            Module to extract region features.
        fusion (bool): Add one conv to fuse residual feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super(ACM, self).__init__()
        self.pool_scale = pool_scale
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.pooled_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.global_info = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)

        self.residual_conv = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function."""
        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
        # [batch_size, channels, h, w]
        x = self.input_redu_conv(x)
        # [batch_size, channels, pool_scale, pool_scale]
        pooled_x = self.pooled_redu_conv(pooled_x)
        batch_size = x.size(0)
        # [batch_size, pool_scale * pool_scale, channels]
        pooled_x = pooled_x.view(batch_size, self.channels,
                                 -1).permute(0, 2, 1).contiguous()
        # [batch_size, h * w, pool_scale * pool_scale]
        affinity_matrix = self.gla(x + resize(
            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
                                   ).permute(0, 2, 3, 1).reshape(
                                       batch_size, -1, self.pool_scale**2)
        affinity_matrix = F.sigmoid(affinity_matrix)
        # [batch_size, h * w, channels]
        z_out = torch.matmul(affinity_matrix, pooled_x)
        # [batch_size, channels, h * w]
        z_out = z_out.permute(0, 2, 1).contiguous()
        # [batch_size, channels, h, w]
        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
        z_out = self.residual_conv(z_out)
        z_out = F.relu(z_out + x)
        if self.fusion:
            z_out = self.fusion_conv(z_out)

        return z_out


@HEADS.register_module()
class APCHead(BaseDecodeHead):
    """Adaptive Pyramid Context Network for Semantic Segmentation.

    This head is the implementation of
    `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
CVPR_2019_paper.pdf>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Adaptive Context
            Module. Default: (1, 2, 3, 6).
        fusion (bool): Add one conv to fuse residual feature.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
        super(APCHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.fusion = fusion
        acm_modules = []
        for pool_scale in self.pool_scales:
            acm_modules.append(
                ACM(pool_scale,
                    self.fusion,
                    self.in_channels,
                    self.channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.acm_modules = nn.ModuleList(acm_modules)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        acm_outs = [x]
        for acm_module in self.acm_modules:
            acm_outs.append(acm_module(x))
        acm_outs = torch.cat(acm_outs, dim=1)
        output = self.bottleneck(acm_outs)
        output = self.cls_seg(output)
        return output
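The core of ACM is an affinity-weighted aggregation of pooled region features; the shape bookkeeping in `forward` above can be seen in isolation with plain tensor ops. An illustrative sketch, not part of this PR, with arbitrary example sizes:

```python
# Illustrative shape walk-through of ACM's affinity-weighted aggregation
# (not part of this PR); sizes are arbitrary examples.
import torch
import torch.nn.functional as F

batch_size, channels, h, w, pool_scale = 2, 512, 64, 128, 3

x = torch.randn(batch_size, channels, h, w)      # channel-reduced input feature
pooled_x = F.adaptive_avg_pool2d(x, pool_scale)  # [B, C, s, s] region features

# Flatten region features to [B, s*s, C]
pooled_x = pooled_x.view(batch_size, channels, -1).permute(0, 2, 1)

# Affinity between every pixel and every region: [B, h*w, s*s].
# In APCHead this comes from the 1x1 'gla' conv followed by a sigmoid.
affinity = torch.sigmoid(torch.randn(batch_size, h * w, pool_scale ** 2))

# Each pixel aggregates region features weighted by its affinities: [B, h*w, C]
z = torch.matmul(affinity, pooled_x)

# Back to a feature map [B, C, h, w], which APCHead then fuses with x through
# the residual conv and, optionally, the fusion conv.
z = z.permute(0, 2, 1).reshape(batch_size, channels, h, w)
print(z.shape)  # torch.Size([2, 512, 64, 128])
```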
@@ -6,8 +6,8 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

 from mmcv.utils import ConfigDict
 from mmcv.utils.parrots_wrapper import SyncBatchNorm

-from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
-                                       DepthwiseSeparableASPPHead,
+from mmseg.models.decode_heads import (ANNHead, APCHead, ASPPHead, CCHead,
+                                       DAHead, DepthwiseSeparableASPPHead,
                                        DepthwiseSeparableFCNHead, DNLHead,
                                        EMAHead, EncHead, FCNHead, GCHead,
                                        NLHead, OCRHead, PointHead, PSAHead,
@@ -223,6 +223,59 @@ def test_psp_head():

    assert outputs.shape == (1, head.num_classes, 45, 45)


def test_apc_head():

    with pytest.raises(AssertionError):
        # pool_scales must be list|tuple
        APCHead(in_channels=32, channels=16, num_classes=19, pool_scales=1)

    # test no norm_cfg
    head = APCHead(in_channels=32, channels=16, num_classes=19)
    assert not _conv_has_norm(head, sync_bn=False)

    # test with norm_cfg
    head = APCHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(head, sync_bn=True)

    # fusion=True
    inputs = [torch.randn(1, 32, 45, 45)]
    head = APCHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        pool_scales=(1, 2, 3),
        fusion=True)
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    assert head.fusion is True
    assert head.acm_modules[0].pool_scale == 1
    assert head.acm_modules[1].pool_scale == 2
    assert head.acm_modules[2].pool_scale == 3
    outputs = head(inputs)
    assert outputs.shape == (1, head.num_classes, 45, 45)

    # fusion=False
    inputs = [torch.randn(1, 32, 45, 45)]
    head = APCHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        pool_scales=(1, 2, 3),
        fusion=False)
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    assert head.fusion is False
    assert head.acm_modules[0].pool_scale == 1
    assert head.acm_modules[1].pool_scale == 2
    assert head.acm_modules[2].pool_scale == 3
    outputs = head(inputs)
    assert outputs.shape == (1, head.num_classes, 45, 45)


def test_aspp_head():

    with pytest.raises(AssertionError):