mirror of https://github.com/alibaba/EasyCV.git
parent 2fe73eee91 · commit 26cd12ab42
@@ -0,0 +1,198 @@
_base_ = ['configs/base.py']

# warning: batch_size needs to be >= 2

# model
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='STDCContextPathNet',
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            with_final_conv=False),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
    decode_head=dict(
        type='FCNHead',
        in_channels=256,
        channels=256,
        num_convs=1,
        num_classes=19,
        in_index=3,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=True,
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=2,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=1,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='STDCHead',
            in_channels=256,
            channels=64,
            num_convs=1,
            num_classes=2,
            boundary_threshold=0.1,
            in_index=0,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=True,
            loss_decode=[
                dict(
                    type='CrossEntropyLoss',
                    loss_name='loss_ce',
                    use_sigmoid=True,
                    loss_weight=1.0),
                dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
            ]),
    ],
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
    pretrained=
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc1_easycv.pth'
)

# dataset
CLASSES = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)

train_pipeline = [
    dict(type='MMResize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='SegRandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(type='MMPhotoMetricDistortion'),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='MMPad', size=crop_size),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_semantic_seg'],
        meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'img_norm_cfg')),
]

test_pipeline = [
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=True),
            dict(type='MMRandomFlip'),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=('filename', 'ori_filename', 'ori_shape',
                           'img_shape', 'pad_shape', 'scale_factor', 'flip',
                           'flip_direction', 'img_norm_cfg')),
        ])
]
dataset_type = 'SegDataset'
data_root = '../Cityscapes/'

train_img_root = data_root + 'leftImg8bit/train/'
train_label_root = data_root + 'gtFine/train/'

val_img_root = data_root + 'leftImg8bit/val/'
val_label_root = data_root + 'gtFine/val/'
data = dict(
    imgs_per_gpu=6,
    workers_per_gpu=4,
    persistent_workers=True,
    train=dict(
        type=dataset_type,
        ignore_index=255,
        data_source=dict(
            type='SegSourceCityscapes',
            img_root=train_img_root,
            label_root=train_label_root,
            classes=CLASSES),
        pipeline=train_pipeline),
    val=dict(
        imgs_per_gpu=1,
        ignore_index=255,
        type=dataset_type,
        data_source=dict(
            type='SegSourceCityscapes',
            img_root=val_img_root,
            label_root=val_label_root,
            classes=CLASSES),
        pipeline=test_pipeline))

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=1e-4,
    warmup='linear',
    warmup_iters=10,
    warmup_ratio=0.0001,
    warmup_by_epoch=True,
    by_epoch=False)

# runtime settings
total_epochs = 1290
checkpoint_config = dict(interval=10)
eval_config = dict(interval=10, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        evaluators=[
            dict(
                type='SegmentationEvaluator',
                classes=CLASSES,
                metric_names=['mIoU'])
        ],
    )
]

# export config
export = dict(export_neck=True)
checkpoint_sync_export = True
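Since the config is executable Python, its `model` dict can be pulled out and fed to EasyCV's model builder for a quick smoke test, exactly as the unit test at the end of this commit does. A minimal sketch with two stated assumptions: the config path matches the one referenced in the README table, and `_base_` merging is deliberately skipped (the `model` dict above is self-contained):

```python
# Illustrative only: execute the config file as plain Python and build the model.
import torch

from easycv.models import build_model

cfg = {}
exec(
    open('configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py').read(),
    cfg)  # path assumed from the README table; `_base_` is not merged here

model_cfg = dict(cfg['model'])
model_cfg.pop('pretrained', None)  # skip downloading the backbone weights

model = build_model(model_cfg)
img = torch.rand(2, 3, 512, 1024)  # batch size must be >= 2 (see the warning above)
gt = torch.randint(low=0, high=18, size=(2, 1, 512, 1024))
losses = model.forward_train(img, [], gt)
print(sorted(losses))  # decode.loss_ce, aux_0.loss_ce, aux_1.loss_ce, aux_2.loss_ce, ...
```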
@@ -0,0 +1,7 @@
_base_ = ['configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py']

model = dict(
    backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2')),
    pretrained=
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc2_easycv.pth'
)
@@ -15,6 +15,12 @@ Pretrained on **Pascal VOC 2012 + Aug**.
| ---------- | ------ | --------------------------- | --------------------- | --------------------------------- | ---- | -------- |
| upernet_r50 | [upernet_r50_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/upernet/upernet_r50_512x512_8xb4_60e_voc12aug.py) | 23M/66M | 5.5 | 282.9ms | 76.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/20220706_114712.log.json) |

## STDC

Trained on **Cityscapes**.

| Algorithm | Config | Params<br/>(backbone/total) | Train memory<br/>(GB) | Inference time (V100)<br/>(ms/img) | mIoU | Download |
| ---------- | ------ | --------------------------- | --------------------- | ---------------------------------- | ---- | -------- |
| STDC1 | [stdc1_cityscape_8xb6_e1290](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py) | 7.7M/8.5M | 4.5 | 11.9ms | 75.4 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/epoch_1250.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/20230214_173123.log.json) |

## Mask2former

### Instance Segmentation on COCO
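For completeness, the STDC checkpoint listed above can be run through the `SegmentationPredictor` that this commit also touches (see the predictor hunk further down). A minimal sketch, with assumptions stated: the import path follows EasyCV's predictor layout, the checkpoint/config have been downloaded to local placeholder file names, and the standard list-of-images call style of `PredictorV2` is used:

```python
# Illustrative only: single-image inference with a downloaded STDC-1 checkpoint.
# 'stdc1_epoch_1250.pth', 'stdc1_cityscape_8xb6_e1290.py' and 'demo_city.png'
# are placeholder local paths, not files shipped with this commit.
from easycv.predictors.segmentation import SegmentationPredictor

predictor = SegmentationPredictor(
    model_path='stdc1_epoch_1250.pth',
    config_file='stdc1_cityscape_8xb6_e1290.py')
results = predictor(['demo_city.png'])
print(len(results))  # one result dict per input image
```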
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .cityscapes import SegSourceCityscapes
from .coco import SegSourceCoco, SegSourceCoco2017
from .coco_stuff import SegSourceCocoStuff10k, SegSourceCocoStuff164k
from .raw import SegSourceRaw
@@ -7,5 +8,5 @@ from .voc import SegSourceVoc2007, SegSourceVoc2010, SegSourceVoc2012
__all__ = [
    'SegSourceRaw', 'SegSourceVoc2010', 'SegSourceVoc2007', 'SegSourceVoc2012',
    'SegSourceCoco', 'SegSourceCoco2017', 'SegSourceCocoStuff164k',
-   'SegSourceCocoStuff10k'
+   'SegSourceCocoStuff10k', 'SegSourceCityscapes'
]
@@ -0,0 +1,129 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import subprocess

import numpy as np

from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image as _load_img
from .raw import SegSourceRaw

try:
    import cityscapesscripts.helpers.labels as CSLabels
except ModuleNotFoundError:
    res = subprocess.call('pip install cityscapesscripts', shell=True)
    if res != 0:
        info_string = (
            '\n\nAuto install failed! Please install cityscapesscripts with the following command:\n'
            '\t`pip install cityscapesscripts`\n')
        raise ModuleNotFoundError(info_string)


def load_seg_map_cityscape(seg_path, reduce_zero_label):
    gt_semantic_seg = _load_img(seg_path, mode='P')
    gt_semantic_seg_copy = gt_semantic_seg.copy()
    # map Cityscapes annotation ids to train ids (0-18, 255 for ignore)
    for labels in CSLabels.labels:
        gt_semantic_seg_copy[gt_semantic_seg == labels.id] = labels.trainId

    return {'gt_semantic_seg': gt_semantic_seg_copy}


@DATASOURCES.register_module
class SegSourceCityscapes(SegSourceRaw):
    """Cityscapes datasource."""

    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
               'bicycle')

    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
               [0, 80, 100], [0, 0, 230], [119, 11, 32]]

    def __init__(self,
                 img_suffix='_leftImg8bit.png',
                 label_suffix='_gtFine_labelIds.png',
                 **kwargs):
        super(SegSourceCityscapes, self).__init__(
            img_suffix=img_suffix, label_suffix=label_suffix, **kwargs)

    def __getitem__(self, idx):
        result_dict = self.samples_list[idx]
        load_success = True
        try:
            # avoid data cache from taking up too much memory
            if not self.cache_at_init and not self.cache_on_the_fly:
                result_dict = copy.deepcopy(result_dict)

            if not self.cache_at_init:
                if result_dict.get('img', None) is None:
                    img = _load_img(result_dict['filename'], mode='BGR')
                    result = {
                        'img': img.astype(np.float32),
                        'img_shape': img.shape,  # h, w, c
                        'ori_shape': img.shape,
                    }
                    result_dict.update(result)
                if result_dict.get('gt_semantic_seg', None) is None:
                    result_dict.update(
                        load_seg_map_cityscape(
                            result_dict['seg_filename'],
                            reduce_zero_label=self.reduce_zero_label))
                if self.cache_on_the_fly:
                    self.samples_list[idx] = result_dict
            result_dict = self.post_process_fn(copy.deepcopy(result_dict))
            self._retry_count = 0
        except Exception as e:
            logging.warning(e)
            load_success = False

        if not load_success:
            logging.warning(
                'Something is wrong with the current sample %s, trying to load the next sample...'
                % result_dict.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')

            result_dict = self[(idx + 1) % self.num_samples]

        return result_dict

    def get_source_iterator(self):

        self.img_files = [
            os.path.join(self.img_root, i)
            for i in io.listdir(self.img_root, recursive=True)
            if i.endswith(self.img_suffix[0])
        ]

        self.label_files = []
        # iterate over a copy so that samples without labels can be removed safely
        for img_path in list(self.img_files):
            self.img_root = os.path.join(self.img_root, '')
            img_name = img_path.replace(self.img_root,
                                        '')[:-len(self.img_suffix[0])]
            find_label_path = False
            for label_format in self.label_suffix:
                label_path = os.path.join(self.label_root,
                                          img_name + label_format)
                if io.exists(label_path):
                    find_label_path = True
                    self.label_files.append(label_path)
                    break
            if not find_label_path:
                logging.warning(
                    'Could not find label file %s for img: %s, skipping the sample!'
                    % (label_path, img_path))
                self.img_files.remove(img_path)

        assert len(self.img_files) == len(self.label_files)
        assert len(
            self.img_files) > 0, 'No samples found in %s' % self.img_root

        return list(zip(self.img_files, self.label_files))
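The id-to-trainId remapping in `load_seg_map_cityscape` relies entirely on the metadata table shipped with `cityscapesscripts`. A tiny, illustrative look at what that table contains:

```python
# Illustrative only: each entry in CSLabels.labels carries the raw annotation id
# and the 0-18/255 train id that load_seg_map_cityscape writes into the mask.
import cityscapesscripts.helpers.labels as CSLabels

for label in CSLabels.labels[:8]:
    print(label.name, label.id, label.trainId)
# e.g. 'road' has id 7 and trainId 0; 'unlabeled' has trainId 255 (ignored)
```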
@@ -22,6 +22,7 @@ from .resnet import ResNet
from .resnet_jit import ResNetJIT
from .resnext import ResNeXt
from .shuffle_transformer import ShuffleTransformer
from .stdc import STDCContextPathNet, STDCNet
from .swin_transformer import SwinTransformer
from .swin_transformer3d import SwinTransformer3D
from .vision_transformer import VisionTransformer
@@ -0,0 +1,465 @@
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/stdc.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential

from easycv.models import builder
from ..registry import BACKBONES


class AttentionRefinementModule(BaseModule):
    """Attention Refinement Module (ARM) to refine the features of each stage.

    Args:
        in_channels (int): The number of input channels.
        out_channel (int): The number of output channels.
    Returns:
        x_out (torch.Tensor): Feature map of Attention Refinement Module.
    """

    def __init__(self,
                 in_channels,
                 out_channel,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(AttentionRefinementModule, self).__init__(init_cfg=init_cfg)
        self.conv_layer = ConvModule(
            in_channels=in_channels,
            out_channels=out_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.atten_conv_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None), nn.Sigmoid())

    def forward(self, x):
        x = self.conv_layer(x)
        x_atten = self.atten_conv_layer(x)
        x_out = x * x_atten
        return x_out


class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The number of stride for the first conv layer.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layers.
        fusion_type (str): Type of fusion operation. Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super(STDCModule, self).__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = True if self.stride == 2 else False
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module. This module is different from FeatureFusionModule
    in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
    channel number is calculated by given `scale_factor`, while
    FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
    `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The number of channel scale factor.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(FeatureFusionModule, self).__init__(init_cfg=init_cfg)
        channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                out_channels,
                channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=act_cfg),
            ConvModule(
                channels,
                out_channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=None), nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
        x_attn = x * attn
        return x_attn + x


@BACKBONES.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure,
            `STDCNet1` and `STDCNet2` denote the two main backbones in the
            paper, whose FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layers at each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the Module output.
            Default: True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super(STDCNet, self).__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'],\
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5,\
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
        # They are both not used for following modules like Attention
        # Refinement Module and Feature Fusion Module.
        # Thus they would be cut from `outs`. Please refer to Figure 4
        # of original paper for more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)


@BACKBONES.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path. The `outs` below is a list of three feature
    maps from deep to shallow, whose height and width is from small to big,
    respectively. The biggest feature map of `outs` is outputted for
    `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth.
    The other two feature maps are used for Attention Refinement Module,
    respectively. Besides, the biggest feature map of `outs` and the last
    output of Attention Refinement Module are concatenated for Feature Fusion
    Module. Then, this fusion feature map `feat_fuse` would be outputted for
    `decode_head`. For more details, please refer to Figure 4 of the original
    paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (str): align_corners argument of F.interpolate. It
            must be `None` if upsample_mode is ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): The tuple of list of output feature map for
            auxiliary heads and decoder head.
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
        self.backbone = builder.build_backbone(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        outs = list(self.backbone(x))
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = F.interpolate(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = F.interpolate(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # The `outputs` has four feature maps.
        # `outs[0]` is outputted for `STDCHead` auxiliary head.
        # Two feature maps of `arms_out` are outputted for auxiliary head.
        # `feat_fuse` is outputted for decoder head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)
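To see the four-output contract described in the `STDCContextPathNet` docstring, the wrapper can be built with the same settings the config at the top of this commit uses. A minimal, illustrative sketch (shapes assume a 512x1024 input):

```python
# Illustrative only: build the context-path wrapper and inspect its outputs.
import torch

from easycv.models import builder

cfg = dict(
    type='STDCContextPathNet',
    backbone_cfg=dict(
        type='STDCNet',
        stdc_type='STDCNet1',
        in_channels=3,
        channels=(32, 64, 256, 512, 1024),
        bottleneck_type='cat',
        num_convs=4,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'),
        with_final_conv=False),
    last_in_channels=(1024, 512),
    out_channels=128,
    ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4))

net = builder.build_backbone(cfg).eval()
with torch.no_grad():
    outs = net(torch.rand(1, 3, 512, 1024))
for i, o in enumerate(outs):
    # outs[0] feeds STDCHead (in_index=0), outs[1:3] feed the FCN auxiliary
    # heads (in_index=1, 2), outs[3] is feat_fuse for the decode head (in_index=3)
    print(i, tuple(o.shape))
```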
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .cross_entropy_loss import CrossEntropyLoss
from .det_db_loss import DBLoss
from .dice_loss import DiceLoss
from .face_keypoint_loss import FacePoseLoss, WingLossWithPose
from .focal_loss import FocalLoss, VarifocalLoss
from .iou_loss import GIoULoss, IoULoss, YOLOX_IOULoss
@@ -22,5 +23,6 @@ __all__ = [
    'FocalLoss2d', 'DistributeMSELoss', 'CrossEntropyLossWithLabelSmooth',
    'AMSoftmaxLoss', 'ModelParallelSoftmaxLoss', 'ModelParallelAMSoftmaxLoss',
    'SoftTargetCrossEntropy', 'CDNCriterion', 'DNCriterion', 'DBLoss',
-   'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss'
+   'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss',
+   'DiceLoss'
]
@@ -0,0 +1,135 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/losses/dice_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from easycv.models.builder import LOSSES
from .utils import get_class_weight, weighted_loss


@weighted_loss
def dice_loss(pred,
              target,
              valid_mask,
              smooth=1,
              exponent=2,
              class_weight=None,
              ignore_index=255):
    assert pred.shape[0] == target.shape[0]
    total_loss = 0
    num_classes = pred.shape[1]
    for i in range(num_classes):
        if i != ignore_index:
            dice_loss = binary_dice_loss(
                pred[:, i],
                target[..., i],
                valid_mask=valid_mask,
                smooth=smooth,
                exponent=exponent)
            if class_weight is not None:
                dice_loss *= class_weight[i]
            total_loss += dice_loss
    return total_loss / num_classes


@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
    assert pred.shape[0] == target.shape[0]
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
    den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth

    return 1 - num / den


@LOSSES.register_module()
class DiceLoss(nn.Module):
    """DiceLoss.

    This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
    Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

    Args:
        smooth (float): A float number to smooth the loss and avoid NaN errors.
            Default: 1.
        exponent (float): A float number to calculate the denominator
            value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_dice'.
    """

    def __init__(self,
                 smooth=1,
                 exponent=2,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_dice',
                 **kwargs):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.reduction = reduction
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.loss_weight * dice_loss(
            pred,
            one_hot_target,
            valid_mask=valid_mask,
            reduction=reduction,
            avg_factor=avg_factor,
            smooth=self.smooth,
            exponent=self.exponent,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
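A quick, illustrative sanity check of the loss on random logits, using the `build_loss` helper that the decode head in this commit already imports; tensor shapes follow the Cityscapes setup (19 classes):

```python
# Illustrative only: DiceLoss on a random batch.
import torch

from easycv.models.builder import build_loss

criterion = build_loss(
    dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0))

pred = torch.randn(2, 19, 64, 128)            # (N, C, H, W) raw logits
target = torch.randint(0, 19, (2, 64, 128))   # (N, H, W) train ids
print(criterion(pred, target))                # scalar dice loss
```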
@@ -1,12 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import mmcv
import numpy as np
import torch
import torch.nn.functional as F

from easycv.framework.errors import ValueError


def get_class_weight(class_weight):
    """Get class weight for loss function.

    Args:
        class_weight (list[float] | str | None): If class_weight is a str,
            take it as a file name and read from it.
    """
    if isinstance(class_weight, str):
        # take it as a file path
        if class_weight.endswith('.npy'):
            class_weight = np.load(class_weight)
        else:
            # pkl, json or yaml
            class_weight = mmcv.load(class_weight)

    return class_weight


def reduce_loss(loss, reduction):
    """Reduce loss as specified.
@@ -2,3 +2,4 @@
from .encoder_decoder import EncoderDecoder
from .heads import *
from .mask2former import Mask2Former
from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
@@ -157,7 +157,6 @@ class EncoderDecoder(BaseModel):
        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(img)
        losses = dict()
        loss_decode = self._decode_head_forward_train(x, img_metas,
@@ -2,6 +2,9 @@
from .fcn_head import FCNHead
from .mask2former_head import Mask2FormerHead
from .segformer_head import SegformerHead
from .stdc_head import STDCHead
from .uper_head import UPerHead

-__all__ = ['FCNHead', 'UPerHead', 'Mask2FormerHead', 'SegformerHead']
+__all__ = [
+    'FCNHead', 'UPerHead', 'Mask2FormerHead', 'SegformerHead', 'STDCHead'
+]
@@ -11,6 +11,7 @@ from easycv.framework.errors import TypeError
from easycv.models.builder import build_loss
from easycv.models.utils.ops import resize_tensor
from easycv.utils.logger import print_log
from ..sampler import build_pixel_sampler


# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/decode_head.py
@@ -69,6 +70,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 sampler=None,
                 ignore_index=255,
                 align_corners=False,
                 init_cfg=dict(
@@ -97,6 +99,11 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
            raise TypeError(f'loss_decode must be a dict or sequence of dict,\
                but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
@@ -232,7 +239,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
@@ -0,0 +1,85 @@
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/stdc_head.py
import torch
import torch.nn.functional as F

from easycv.models.builder import HEADS
from .fcn_head import FCNHead


@HEADS.register_module()
class STDCHead(FCNHead):
    """This head is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        boundary_threshold (float): The threshold for calculating the boundary.
            Default: 0.1.
    """

    def __init__(self, boundary_threshold=0.1, **kwargs):
        super(STDCHead, self).__init__(**kwargs)
        self.boundary_threshold = boundary_threshold
        # Use register_buffer to keep the laplacian kernel on the same
        # device as `seg_label`.
        self.register_buffer(
            'laplacian_kernel',
            torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32,
                         requires_grad=False).reshape((1, 1, 3, 3)))
        self.fusion_kernel = torch.nn.Parameter(
            torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
                         dtype=torch.float32).reshape(1, 3, 1, 1),
            requires_grad=False)

    def losses(self, seg_logit, seg_label):
        """Compute Detail Aggregation Loss."""
        # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
        # parameter. However, it is a constant in the original repo and other
        # codebases, because it would not be added into the computation graph
        # after the threshold operation.
        seg_label = seg_label.to(self.laplacian_kernel)
        boundary_targets = F.conv2d(
            seg_label, self.laplacian_kernel, padding=1)
        boundary_targets = boundary_targets.clamp(min=0)
        boundary_targets[boundary_targets > self.boundary_threshold] = 1
        boundary_targets[boundary_targets <= self.boundary_threshold] = 0

        boundary_targets_x2 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=2, padding=1)
        boundary_targets_x2 = boundary_targets_x2.clamp(min=0)

        boundary_targets_x4 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=4, padding=1)
        boundary_targets_x4 = boundary_targets_x4.clamp(min=0)

        boundary_targets_x4_up = F.interpolate(
            boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
        boundary_targets_x2_up = F.interpolate(
            boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')

        boundary_targets_x2_up[
            boundary_targets_x2_up > self.boundary_threshold] = 1
        boundary_targets_x2_up[
            boundary_targets_x2_up <= self.boundary_threshold] = 0

        boundary_targets_x4_up[
            boundary_targets_x4_up > self.boundary_threshold] = 1
        boundary_targets_x4_up[
            boundary_targets_x4_up <= self.boundary_threshold] = 0

        boundary_targets_pyramids = torch.stack(
            (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
            dim=1)

        boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
        boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
                                            self.fusion_kernel)

        boundary_targets_pyramid[
            boundary_targets_pyramid > self.boundary_threshold] = 1
        boundary_targets_pyramid[
            boundary_targets_pyramid <= self.boundary_threshold] = 0

        loss = super(STDCHead, self).losses(seg_logit,
                                            boundary_targets_pyramid.long())
        return loss
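The detail targets built above are just thresholded Laplacian edge maps of the label image, fused across three strides. A toy illustration of the first step (standalone, not tied to EasyCV):

```python
# Toy illustration: a 3x3 Laplacian kernel turns a label map into a boundary
# map, which is then binarized with the same 0.1 threshold STDCHead uses.
import torch
import torch.nn.functional as F

laplacian = torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32).reshape(1, 1, 3, 3)

label = torch.zeros(1, 1, 8, 8)
label[:, :, 2:6, 2:6] = 1.            # a square "object" on background 0
edges = F.conv2d(label, laplacian, padding=1).clamp(min=0)
boundary = (edges > 0.1).float()      # 1 only on the object's border pixels
print(boundary[0, 0].int())
```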
@@ -0,0 +1,5 @@
from .base_pixel_sampler import BasePixelSampler
from .builder import build_pixel_sampler
from .ohem_pixel_sampler import OHEMPixelSampler

__all__ = ['BasePixelSampler', 'OHEMPixelSampler', 'build_pixel_sampler']
@@ -0,0 +1,13 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
from abc import ABCMeta, abstractmethod


class BasePixelSampler(metaclass=ABCMeta):
    """Base class of pixel sampler."""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
@@ -0,0 +1,9 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
from mmcv.utils import Registry, build_from_cfg

PIXEL_SAMPLERS = Registry('pixel sampler')


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map."""
    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
@@ -0,0 +1,85 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS


@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
    """Online Hard Example Mining Sampler for segmentation.

    Args:
        context (nn.Module): The context of sampler, subclass of
            :obj:`BaseDecodeHead`.
        thresh (float, optional): The threshold for hard example selection.
            Predictions below it are treated as low-confidence. If not
            specified, the hard examples will be the pixels with the top
            ``min_kept`` loss. Default: None.
        min_kept (int, optional): The minimum number of predictions to keep.
            Default: 100000.
    """

    def __init__(self, context, thresh=None, min_kept=100000):
        super(OHEMPixelSampler, self).__init__()
        self.context = context
        assert min_kept > 1
        self.thresh = thresh
        self.min_kept = min_kept

    def sample(self, seg_logit, seg_label):
        """Sample pixels that have high loss or low prediction confidence.

        Args:
            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

        Returns:
            torch.Tensor: segmentation weight, shape (N, H, W)
        """
        with torch.no_grad():
            assert seg_logit.shape[2:] == seg_label.shape[2:]
            assert seg_label.shape[1] == 1
            seg_label = seg_label.squeeze(1).long()
            batch_kept = self.min_kept * seg_label.size(0)
            valid_mask = seg_label != self.context.ignore_index
            seg_weight = seg_logit.new_zeros(size=seg_label.size())
            valid_seg_weight = seg_weight[valid_mask]
            if self.thresh is not None:
                seg_prob = F.softmax(seg_logit, dim=1)

                tmp_seg_label = seg_label.clone().unsqueeze(1)
                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
                sort_prob, sort_indices = seg_prob[valid_mask].sort()

                if sort_prob.numel() > 0:
                    min_threshold = sort_prob[min(batch_kept,
                                                  sort_prob.numel() - 1)]
                else:
                    min_threshold = 0.0
                threshold = max(min_threshold, self.thresh)
                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
            else:
                if not isinstance(self.context.loss_decode, nn.ModuleList):
                    losses_decode = [self.context.loss_decode]
                else:
                    losses_decode = self.context.loss_decode
                losses = 0.0
                for loss_module in losses_decode:
                    losses += loss_module(
                        seg_logit,
                        seg_label,
                        weight=None,
                        ignore_index=self.context.ignore_index,
                        reduction_override='none')

                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
                _, sort_indices = losses[valid_mask].sort(descending=True)
                valid_seg_weight[sort_indices[:batch_kept]] = 1.

            seg_weight[valid_mask] = valid_seg_weight

            return seg_weight
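In the STDC configs above the sampler is attached to each FCN head as `dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)`. A standalone sketch of what `sample` returns, using a throwaway stand-in for the decode-head context (only `ignore_index` is needed on the thresholded path; the module path is taken from this commit's package layout):

```python
# Illustrative only: OHEM keeps low-confidence pixels by giving them weight 1.
import torch
from types import SimpleNamespace

from easycv.models.segmentation.sampler import OHEMPixelSampler

context = SimpleNamespace(ignore_index=255)   # stand-in for a BaseDecodeHead
sampler = OHEMPixelSampler(context, thresh=0.7, min_kept=16)

seg_logit = torch.randn(2, 19, 32, 64)            # (N, C, H, W)
seg_label = torch.randint(0, 19, (2, 1, 32, 64))  # (N, 1, H, W)
seg_weight = sampler.sample(seg_logit, seg_label)
print(seg_weight.shape, int(seg_weight.sum()))    # (2, 32, 64), number of kept pixels
```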
@@ -29,7 +29,7 @@ class SegmentationPredictor(PredictorV2):

    def __init__(self,
                 model_path,
-                config_file,
+                config_file=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
@@ -54,7 +54,7 @@ class SegmentationPredictor(PredictorV2):
            **kwargs)

        self.CLASSES = self.cfg.CLASSES
-       self.PALETTE = self.cfg.PALETTE
+       self.PALETTE = self.cfg.get('PALETTE', None)

    def show_result(self,
                    img,
@@ -90,7 +90,8 @@ class SegmentationPredictor(PredictorV2):

        img = mmcv.imread(img)
        img = img.copy()
-       seg = result[0]
+       # seg = result[0]
+       seg = result
        if palette is None:
            if self.PALETTE is None:
                # Get random state before set seed,
@@ -1,4 +1,5 @@
albumentations
cityscapesscripts
dataclasses
decord
einops
@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import unittest

from tests.ut_config import SEG_DATA_SAMLL_CITYSCAPES

from easycv.datasets.segmentation.data_sources.cityscapes import \
    SegSourceCityscapes

CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle')


class SegSourceCityscapesTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_cityscapes(self):

        data_source = SegSourceCityscapes(
            img_root=os.path.join(SEG_DATA_SAMLL_CITYSCAPES, 'leftImg8bit'),
            label_root=os.path.join(SEG_DATA_SAMLL_CITYSCAPES, 'gtFine'),
            classes=CLASSES,
        )

        index_list = random.choices(list(range(20)), k=3)
        for idx in index_list:
            data = data_source[idx]
            self.assertIn('img_fields', data)
            self.assertIn('seg_fields', data)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,126 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import torch

from easycv.models import build_model


class StdcTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_stdc(self):

        norm_cfg = dict(type='BN', requires_grad=True)
        model_cfg = dict(
            type='EncoderDecoder',
            backbone=dict(
                type='STDCContextPathNet',
                backbone_cfg=dict(
                    type='STDCNet',
                    stdc_type='STDCNet1',
                    in_channels=3,
                    channels=(32, 64, 256, 512, 1024),
                    bottleneck_type='cat',
                    num_convs=4,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    with_final_conv=False),
                last_in_channels=(1024, 512),
                out_channels=128,
                ffm_cfg=dict(
                    in_channels=384, out_channels=256, scale_factor=4)),
            decode_head=dict(
                type='FCNHead',
                in_channels=256,
                channels=256,
                num_convs=1,
                num_classes=19,
                in_index=3,
                concat_input=False,
                dropout_ratio=0.1,
                norm_cfg=norm_cfg,
                align_corners=True,
                sampler=dict(
                    type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                loss_decode=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0)),
            auxiliary_head=[
                dict(
                    type='FCNHead',
                    in_channels=128,
                    channels=64,
                    num_convs=1,
                    num_classes=19,
                    in_index=2,
                    norm_cfg=norm_cfg,
                    concat_input=False,
                    align_corners=False,
                    sampler=dict(
                        type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                    loss_decode=dict(
                        type='CrossEntropyLoss',
                        use_sigmoid=False,
                        loss_weight=1.0)),
                dict(
                    type='FCNHead',
                    in_channels=128,
                    channels=64,
                    num_convs=1,
                    num_classes=19,
                    in_index=1,
                    norm_cfg=norm_cfg,
                    concat_input=False,
                    align_corners=False,
                    sampler=dict(
                        type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                    loss_decode=dict(
                        type='CrossEntropyLoss',
                        use_sigmoid=False,
                        loss_weight=1.0)),
                dict(
                    type='STDCHead',
                    in_channels=256,
                    channels=64,
                    num_convs=1,
                    num_classes=2,
                    boundary_threshold=0.1,
                    in_index=0,
                    norm_cfg=norm_cfg,
                    concat_input=False,
                    align_corners=True,
                    loss_decode=[
                        dict(
                            type='CrossEntropyLoss',
                            loss_name='loss_ce',
                            use_sigmoid=True,
                            loss_weight=1.0),
                        dict(
                            type='DiceLoss',
                            loss_name='loss_dice',
                            loss_weight=1.0)
                    ]),
            ],
            train_cfg=dict(),
            test_cfg=dict(mode='whole'),
        )

        model = build_model(model_cfg).to('cuda')

        img = torch.rand(2, 3, 512, 1024).to('cuda')
        gt_semantic_seg = torch.randint(
            low=0, high=18, size=(2, 1, 512, 1024)).to('cuda')

        train_output = model.forward_train(img, [], gt_semantic_seg)
        self.assertIn('decode.loss_ce', train_output)
        self.assertIn('aux_0.loss_ce', train_output)
        self.assertIn('aux_1.loss_ce', train_output)
        self.assertIn('aux_2.loss_ce', train_output)


if __name__ == '__main__':
    unittest.main()
@@ -159,6 +159,8 @@ SEG_DATA_SMALL_COCO_STUFF_10K = os.path.join(
    BASE_LOCAL_PATH, 'data/segmentation/small_coco_stuff/small_coco_stuff10k')
SEG_DATA_SAMLL_COCO_STUFF_164K = os.path.join(
    BASE_LOCAL_PATH, 'data/segmentation/small_coco_stuff/small_coco_stuff164k')
SEG_DATA_SAMLL_CITYSCAPES = os.path.join(BASE_LOCAL_PATH,
                                         'data/segmentation/small_cityscapes')

# OCR data
SMALL_OCR_CLS_DATA = os.path.join(BASE_LOCAL_PATH, 'data/ocr/small_ocr_cls')