add upernet algo (#118)

* add upernet algo
* fix import onnx bug
pull/127/head
yhq 2022-07-12 11:14:09 +08:00 committed by GitHub
parent e4722c754f
commit 6b8b04db72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 478 additions and 9 deletions

View File

@ -0,0 +1,163 @@
# Base config file(s) this config builds on.
# NOTE(review): presumably merged by the config loader (mmcv-style `_base_`
# inheritance) -- confirm against the loader.
_base_ = ['configs/base.py']
# Class names handed to the data sources and the evaluator below; 21 entries,
# with 'background' first.
CLASSES = [
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
# model settings
# Number of output classes shared by both heads (CLASSES has 21 entries).
num_classes = 21
norm_cfg = dict(type='SyncBN', requires_grad=True)
# Encoder-decoder segmentor: ResNetV1c-50 backbone, UPerHead as the main
# decode head and an FCNHead as auxiliary head.
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3, 4),
        dilations=(1, 1, 1, 1),
        strides=(1, 2, 2, 2),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True,
    ),
    decode_head=dict(
        type='UPerHead',
        # One entry per selected backbone output.
        in_channels=[256, 512, 1024, 2048],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=num_classes,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        # Reads the third selected backbone output (index 2, 1024 channels).
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        # Use the shared variable so both heads stay in sync when the class
        # count changes (this was a hard-coded 21).
        num_classes=num_classes,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
# dataset settings
dataset_type = 'SegDataset'
# Dataset root; the image, label and split paths below are appended to it.
data_root = 'data/VOCdevkit/VOC2012/'
# Normalization parameters forwarded to MMNormalize in both pipelines.
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# Used as the SegRandomCrop crop size and the MMPad target size in training.
crop_size = (512, 512)
# Training transforms: resize with a ratio range, random crop, random flip,
# photometric distortion, normalize, pad, then format and collect.
train_pipeline = [
dict(type='MMResize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='SegRandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='MMRandomFlip', flip_ratio=0.5),
dict(type='MMPhotoMetricDistortion'),
dict(type='MMNormalize', **img_norm_cfg),
# Images are padded with 0 and segmentation maps with 255.
dict(type='MMPad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img', 'gt_semantic_seg'],
meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
'img_norm_cfg')),
]
# Test transforms wrapped in MMMultiScaleFlipAug with a single img_scale;
# flipping is disabled and the multi-ratio list is left commented out.
test_pipeline = [
dict(
type='MMMultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='MMResize', keep_ratio=True),
dict(type='MMRandomFlip'),
dict(type='MMNormalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(
type='Collect',
keys=['img'],
meta_keys=('filename', 'ori_filename', 'ori_shape',
'img_shape', 'pad_shape', 'scale_factor', 'flip',
'flip_direction', 'img_norm_cfg')),
])
]
# Dataset definitions for the train / val / test splits.
data = dict(
imgs_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ignore_index=255,
# Training data concatenates two sources that share JPEGImages:
# SegmentationClass labels listed in train.txt and
# SegmentationClassAug labels listed in aug.txt.
data_source=dict(
type='SourceConcat',
data_source_list=[
dict(
type='SegSourceRaw',
img_root=data_root + 'JPEGImages',
label_root=data_root + 'SegmentationClass',
split=data_root + 'ImageSets/Segmentation/train.txt',
classes=CLASSES),
dict(
type='SegSourceRaw',
img_root=data_root + 'JPEGImages',
label_root=data_root + 'SegmentationClassAug',
split=data_root + 'ImageSets/Segmentation/aug.txt',
classes=CLASSES),
]),
pipeline=train_pipeline),
# Validation sets its own imgs_per_gpu=1 and uses the test pipeline.
val=dict(
imgs_per_gpu=1,
ignore_index=255,
type=dataset_type,
data_source=dict(
type='SegSourceRaw',
img_root=data_root + 'JPEGImages',
label_root=data_root + 'SegmentationClass',
split=data_root + 'ImageSets/Segmentation/val.txt',
classes=CLASSES,
),
pipeline=test_pipeline),
# NOTE(review): unlike `val`, this entry sets neither imgs_per_gpu nor
# ignore_index -- confirm that the omission is intended.
test=dict(
type=dataset_type,
data_source=dict(
type='SegSourceRaw',
img_root=data_root + 'JPEGImages',
label_root=data_root + 'SegmentationClass',
split=data_root + 'ImageSets/Segmentation/test.txt',
classes=CLASSES,
),
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
# 'poly' schedule with power 0.9 and a floor of 1e-4; by_epoch=False.
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
total_epochs = 60
# NOTE(review): the two intervals below are presumably counted in epochs --
# confirm against the runner / hooks.
checkpoint_config = dict(interval=5)
eval_config = dict(interval=1, gpu_collect=False)
# One evaluation pipeline in 'test' mode reporting mIoU over CLASSES.
eval_pipelines = [
dict(
mode='test',
evaluators=[
dict(
type='SegmentationEvaluator',
classes=CLASSES,
metric_names=['mIoU'])
],
)
]

View File

@ -7,3 +7,10 @@ Pretrained on **Pascal VOC 2012 + Aug**.
| Algorithm | Config | mIoU | Download |
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ |
| fcn_r50_d8 | [fcn_r50-d8_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/fcn/fcn_r50-d8_512x512_8xb4_60e_voc12aug.py) | 69.01 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/fcn_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/fcn_r50/20220525_203606.log.json) |
## UPerNet
Pretrained on **Pascal VOC 2012 + Aug**.
| Algorithm | Config | mIoU | Download |
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ |
| upernet_r50 | [upernet_r50_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/upernet/upernet_r50_512x512_8xb4_60e_voc12aug.py) | 76.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/20220706_114712.log.json) |

View File

@ -7,7 +7,8 @@ from easycv.models import builder
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.models.utils.ops import resize_tensor
from easycv.utils.logger import print_log
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log
from easycv.utils.misc import add_prefix
@ -36,7 +37,7 @@ class EncoderDecoder(BaseModel):
self.neck = neck
self.auxiliary_head = auxiliary_head
self.pretrained = pretrained
if self.neck is not None:
self.neck = builder.build_neck(self.neck)
@ -55,12 +56,32 @@ class EncoderDecoder(BaseModel):
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
self.init_weights()
def init_weights(self, pretrained=None):
if pretrained is not None:
print_log('load model from: {}'.format(pretrained), logger='root')
self.backbone.init_weights(pretrained=pretrained)
def init_weights(self):
logger = get_root_logger()
if isinstance(self.pretrained, str):
load_checkpoint(
self.backbone, self.pretrained, strict=False, logger=logger)
elif self.pretrained:
if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
self.backbone.init_weights(pretrained=self.pretrained)
elif hasattr(self.backbone, 'default_pretrained_model_path'
) and self.backbone.default_pretrained_model_path:
print_log(
'load model from default path: {}'.format(
self.backbone.default_pretrained_model_path), logger)
load_checkpoint(
self.backbone,
self.backbone.default_pretrained_model_path,
strict=False,
logger=logger)
else:
print_log('load model from init weights')
self.backbone.init_weights()
else:
print_log('load model from init weights')
self.backbone.init_weights()
if hasattr(self.decode_head, 'init_weights'):
self.decode_head.init_weights()

View File

@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .fcn_head import FCNHead
from .uper_head import UPerHead
__all__ = ['FCNHead']
__all__ = ['FCNHead', 'UPerHead']

View File

@ -0,0 +1,194 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from easycv.models.builder import HEADS
from easycv.models.utils.ops import resize_tensor
from .base import BaseDecodeHead
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/uper_head.py
@HEADS.register_module()
class UPerHead(BaseDecodeHead):
    """Decode head of `UPerNet <https://arxiv.org/abs/1807.10221>`_
    (Unified Perceptual Parsing for Scene Understanding).

    The last selected feature map goes through a pooling pyramid module
    (PPM); the other feature maps get lateral 1x1 convs and are merged
    through a top-down path. All levels are then resized to the first
    level's spatial size, concatenated and fused by a 3x3 conv.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        # Conv/norm/activation settings shared by every ConvModule below.
        layer_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        top_channels = self.in_channels[-1]

        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            top_channels,
            self.channels,
            align_corners=self.align_corners,
            **layer_cfg)
        # Fuses the top feature map with one pooled branch per scale.
        self.bottleneck = ConvModule(
            top_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            **layer_cfg)

        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        # The top level is handled by the PSP branch, so it is skipped here.
        for level_channels in self.in_channels[:-1]:
            self.lateral_convs.append(
                ConvModule(
                    level_channels,
                    self.channels,
                    1,
                    inplace=False,
                    **layer_cfg))
            self.fpn_convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    3,
                    padding=1,
                    inplace=False,
                    **layer_cfg))

        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            **layer_cfg)

    def psp_forward(self, inputs):
        """Run the PSP branch on the last feature map of ``inputs``."""
        top_feat = inputs[-1]
        pooled_feats = self.psp_modules(top_feat)
        stacked = torch.cat([top_feat, *pooled_feats], dim=1)
        return self.bottleneck(stacked)

    def _forward_feature(self, inputs):
        """Compute the fused feature map that is fed to ``self.cls_seg``.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        inputs = self._transform_inputs(inputs)

        # Lateral connections for every level but the top, then the PSP
        # output as the top level.
        laterals = [
            conv(inputs[level])
            for level, conv in enumerate(self.lateral_convs)
        ]
        laterals.append(self.psp_forward(inputs))
        num_levels = len(laterals)

        # Top-down path: add each upsampled level into the level below it.
        for upper in range(num_levels - 1, 0, -1):
            lower = upper - 1
            upsampled = resize_tensor(
                laterals[upper],
                size=laterals[lower].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            laterals[lower] = laterals[lower] + upsampled

        # Output convs on all levels except the PSP one, which is used as is.
        fpn_outs = [
            fpn_conv(lateral)
            for fpn_conv, lateral in zip(self.fpn_convs, laterals)
        ]
        fpn_outs.append(laterals[-1])

        # Bring every level to the spatial size of the first one.
        target_size = fpn_outs[0].shape[2:]
        resized = [
            resize_tensor(
                level_out,
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for level_out in fpn_outs[1:]
        ]
        merged = torch.cat([fpn_outs[0], *resized], dim=1)
        return self.fpn_bottleneck(merged)

    def forward(self, inputs):
        """Forward function."""
        return self.cls_seg(self._forward_feature(inputs))
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Holds one ``AdaptiveAvgPool2d`` + 1x1 ``ConvModule`` branch per pooling
    scale.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): Forwarded to ``resize_tensor`` when the pooled
            maps are upsampled.
        **kwargs: Extra keyword arguments passed to every ``ConvModule``.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        super().__init__()
        self.pool_scales = pool_scales
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        for scale in pool_scales:
            pooling = nn.AdaptiveAvgPool2d(scale)
            reduce_conv = ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                **kwargs)
            self.append(nn.Sequential(pooling, reduce_conv))

    def forward(self, x):
        """Return one pooled map per scale, each resized to ``x``'s spatial
        size."""
        spatial_size = x.size()[2:]
        return [
            resize_tensor(
                branch(x),
                size=spatial_size,
                mode='bilinear',
                align_corners=self.align_corners) for branch in self
        ]

View File

@ -0,0 +1,83 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models.builder import build_head
class UperHeadTest(unittest.TestCase):
    """Smoke tests for ``UPerHead`` built through ``build_head``.

    Both tests share one head config and one set of random four-level
    inputs. They run on CUDA when it is available and on CPU otherwise, so
    they no longer error out on hosts without a GPU.
    """

    NUM_CLASSES = 19
    BATCH_SIZE = 2
    # (channels, spatial size) of each dummy feature level; the channel
    # counts match ``in_channels`` in the head config.
    FEATURE_SHAPES = ((256, 128), (512, 64), (1024, 32), (2048, 16))

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _build_head(self):
        """Build the UPerHead under test and move it to the test device."""
        norm_cfg = dict(type='BN', requires_grad=True)
        uper_head_config = dict(
            type='UPerHead',
            in_channels=[256, 512, 1024, 2048],
            in_index=[0, 1, 2, 3],
            pool_scales=(1, 2, 3, 6),
            channels=512,
            dropout_ratio=0.1,
            num_classes=self.NUM_CLASSES,
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
        return build_head(uper_head_config).to(self.device)

    def _dummy_inputs(self):
        """Return one random feature map per entry of ``FEATURE_SHAPES``."""
        return [
            torch.rand(self.BATCH_SIZE, channels, size,
                       size).to(self.device)
            for channels, size in self.FEATURE_SHAPES
        ]

    def test_forward_train(self):
        head = self._build_head()
        dummy_inputs = self._dummy_inputs()
        gt_semantic_seg = torch.randint(
            low=0,
            high=self.NUM_CLASSES,
            size=(self.BATCH_SIZE, 1, 512, 512)).to(self.device)

        train_output = head.forward_train(
            dummy_inputs,
            img_metas=None,
            gt_semantic_seg=gt_semantic_seg,
            train_cfg=None)

        self.assertIn('loss_ce', train_output)
        self.assertIn('acc_seg', train_output)
        self.assertEqual(train_output['acc_seg'].shape, torch.Size([1]))

    def test_forward_test(self):
        head = self._build_head()
        dummy_inputs = self._dummy_inputs()

        with torch.no_grad():
            test_output = head.forward_test(
                dummy_inputs, img_metas=None, test_cfg=None)

        # The output keeps the spatial size of the first feature level.
        self.assertEqual(
            test_output.shape,
            torch.Size([self.BATCH_SIZE, self.NUM_CLASSES, 128, 128]))


if __name__ == '__main__':
    unittest.main()

View File

@ -26,7 +26,6 @@ from easycv.datasets.utils import is_dali_dataset_type
from easycv.file import io
from easycv.models import build_model
from easycv.utils.collect_env import collect_env
from easycv.utils.flops_counter import get_model_info
from easycv.utils.logger import get_root_logger
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
from easycv.utils.config_tools import traverse_replace
@ -214,6 +213,7 @@ def main():
print(model)
if 'stage' in cfg.model and cfg.model['stage'] == 'EDGE':
from easycv.utils.flops_counter import get_model_info
get_model_info(model, cfg.img_scale, cfg.model, logger)
assert len(cfg.workflow) == 1, 'Validation is called by hook.'