[features] Support stdc (#284)

* add stdc semantic segmentation algorithm
yhq 2023-02-16 14:00:59 +08:00 committed by GitHub
parent 2fe73eee91
commit 26cd12ab42
24 changed files with 1350 additions and 8 deletions


@@ -0,0 +1,198 @@
_base_ = ['configs/base.py']
# warning: batch_size must be >= 2
# model
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
type='STDCContextPathNet',
backbone_cfg=dict(
type='STDCNet',
stdc_type='STDCNet1',
in_channels=3,
channels=(32, 64, 256, 512, 1024),
bottleneck_type='cat',
num_convs=4,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
with_final_conv=False),
last_in_channels=(1024, 512),
out_channels=128,
ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
decode_head=dict(
type='FCNHead',
in_channels=256,
channels=256,
num_convs=1,
num_classes=19,
in_index=3,
concat_input=False,
dropout_ratio=0.1,
norm_cfg=norm_cfg,
align_corners=True,
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='STDCHead',
in_channels=256,
channels=64,
num_convs=1,
num_classes=2,
boundary_threshold=0.1,
in_index=0,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=True,
loss_decode=[
dict(
type='CrossEntropyLoss',
loss_name='loss_ce',
use_sigmoid=True,
loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
]),
],
train_cfg=dict(),
test_cfg=dict(mode='whole'),
pretrained=
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc1_easycv.pth'
)
# dataset
CLASSES = [
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
dict(type='MMResize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='SegRandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='MMRandomFlip', flip_ratio=0.5),
dict(type='MMPhotoMetricDistortion'),
dict(type='MMNormalize', **img_norm_cfg),
dict(type='MMPad', size=crop_size),
dict(type='DefaultFormatBundle'),
dict(
type='Collect',
keys=['img', 'gt_semantic_seg'],
meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
'img_norm_cfg')),
]
test_pipeline = [
dict(
type='MMMultiScaleFlipAug',
img_scale=(2048, 1024),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='MMResize', keep_ratio=True),
dict(type='MMRandomFlip'),
dict(type='MMNormalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(
type='Collect',
keys=['img'],
meta_keys=('filename', 'ori_filename', 'ori_shape',
'img_shape', 'pad_shape', 'scale_factor', 'flip',
'flip_direction', 'img_norm_cfg')),
])
]
dataset_type = 'SegDataset'
data_root = '../Cityscapes/'
train_img_root = data_root + 'leftImg8bit/train/'
train_label_root = data_root + 'gtFine/train/'
val_img_root = data_root + 'leftImg8bit/val/'
val_label_root = data_root + 'gtFine/val/'
data = dict(
imgs_per_gpu=6,
workers_per_gpu=4,
persistent_workers=True,
train=dict(
type=dataset_type,
ignore_index=255,
data_source=dict(
type='SegSourceCityscapes',
img_root=train_img_root,
label_root=train_label_root,
classes=CLASSES),
pipeline=train_pipeline),
val=dict(
imgs_per_gpu=1,
ignore_index=255,
type=dataset_type,
data_source=dict(
type='SegSourceCityscapes',
img_root=val_img_root,
label_root=val_label_root,
classes=CLASSES),
pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-4,
warmup='linear',
warmup_iters=10,
warmup_ratio=0.0001,
warmup_by_epoch=True,
by_epoch=False)
# runtime settings
total_epochs = 1290
checkpoint_config = dict(interval=10)
eval_config = dict(interval=10, gpu_collect=False)
eval_pipelines = [
dict(
mode='test',
evaluators=[
dict(
type='SegmentationEvaluator',
classes=CLASSES,
metric_names=['mIoU'])
],
)
]
# export config
export = dict(export_neck=True)
checkpoint_sync_export = True
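
For reference, a minimal smoke-test sketch for this config (not part of the diff; it assumes EasyCV's `mmcv_config_fromfile` loader resolves the `_base_` entry, and reuses `build_model` and the dummy-input pattern from the unit test at the end of this commit):

import torch

from easycv.models import build_model
from easycv.utils.config_tools import mmcv_config_fromfile  # assumed loader

cfg = mmcv_config_fromfile(
    'configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py')
cfg.model.pretrained = None  # skip downloading the remote pretrained weights
model = build_model(cfg.model)
img = torch.rand(2, 3, 512, 1024)  # batch_size >= 2, see the warning above
gt_semantic_seg = torch.randint(low=0, high=18, size=(2, 1, 512, 1024))
losses = model.forward_train(img, [], gt_semantic_seg)
print(sorted(losses.keys()))  # decode.loss_ce, aux_0/1.loss_ce, aux_2.loss_*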


@@ -0,0 +1,7 @@
_base_ = ['configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py']
model = dict(
backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2')),
pretrained=
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc2_easycv.pth'
)


@@ -15,6 +15,12 @@ Pretrained on **Pascal VOC 2012 + Aug**.
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| upernet_r50 | [upernet_r50_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/upernet/upernet_r50_512x512_8xb4_60e_voc12aug.py) | 23M/66M | 5.5 | 282.9ms | 76.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/20220706_114712.log.json) |
## STDC
Trained on **Cityscapes**.
| Algorithm | Config | Params<br/>(backbone/total) | Train memory<br/>(GB) | inference time(V100)<br/>(ms/img) | mIoU | Download |
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| STDC1 | [stdc1_cityscape_8xb6_e1290](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py) | 7.7M/8.5M | 4.5 | 11.9ms | 75.4 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/epoch_1250.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/20230214_173123.log.json) |
## Mask2former
### Instance Segmentation on COCO


@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .cityscapes import SegSourceCityscapes
from .coco import SegSourceCoco, SegSourceCoco2017
from .coco_stuff import SegSourceCocoStuff10k, SegSourceCocoStuff164k
from .raw import SegSourceRaw
@@ -7,5 +8,5 @@ from .voc import SegSourceVoc2007, SegSourceVoc2010, SegSourceVoc2012
__all__ = [
'SegSourceRaw', 'SegSourceVoc2010', 'SegSourceVoc2007', 'SegSourceVoc2012',
'SegSourceCoco', 'SegSourceCoco2017', 'SegSourceCocoStuff164k',
'SegSourceCocoStuff10k'
'SegSourceCocoStuff10k', 'SegSourceCityscapes'
]


@@ -0,0 +1,129 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import subprocess
import numpy as np
from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image as _load_img
from .raw import SegSourceRaw
try:
import cityscapesscripts.helpers.labels as CSLabels
except ModuleNotFoundError as e:
res = subprocess.call('pip install cityscapesscripts', shell=True)
if res != 0:
info_string = (
'\n\nAuto install failed! Please install cityscapesscripts with the following commands :\n'
'\t`pip install cityscapesscripts`\n')
raise ModuleNotFoundError(info_string)
def load_seg_map_cityscape(seg_path, reduce_zero_label):
gt_semantic_seg = _load_img(seg_path, mode='P')
gt_semantic_seg_copy = gt_semantic_seg.copy()
for labels in CSLabels.labels:
gt_semantic_seg_copy[gt_semantic_seg == labels.id] = labels.trainId
return {'gt_semantic_seg': gt_semantic_seg_copy}
@DATASOURCES.register_module
class SegSourceCityscapes(SegSourceRaw):
"""Cityscapes datasource
"""
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle')
PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
[0, 80, 100], [0, 0, 230], [119, 11, 32]]
def __init__(self,
img_suffix='_leftImg8bit.png',
label_suffix='_gtFine_labelIds.png',
**kwargs):
super(SegSourceCityscapes, self).__init__(
img_suffix=img_suffix, label_suffix=label_suffix, **kwargs)
def __getitem__(self, idx):
result_dict = self.samples_list[idx]
load_success = True
try:
# avoid data cache from taking up too much memory
if not self.cache_at_init and not self.cache_on_the_fly:
result_dict = copy.deepcopy(result_dict)
if not self.cache_at_init:
if result_dict.get('img', None) is None:
img = _load_img(result_dict['filename'], mode='BGR')
result = {
'img': img.astype(np.float32),
'img_shape': img.shape, # h, w, c
'ori_shape': img.shape,
}
result_dict.update(result)
if result_dict.get('gt_semantic_seg', None) is None:
result_dict.update(
load_seg_map_cityscape(
result_dict['seg_filename'],
reduce_zero_label=self.reduce_zero_label))
if self.cache_on_the_fly:
self.samples_list[idx] = result_dict
result_dict = self.post_process_fn(copy.deepcopy(result_dict))
self._retry_count = 0
except Exception as e:
logging.warning(e)
load_success = False
if not load_success:
            logging.warning(
                'Something is wrong with the current sample %s, '
                'try to load the next sample...' %
                result_dict.get('filename', ''))
self._retry_count += 1
if self._retry_count >= self._max_retry_num:
raise ValueError('All samples failed to load!')
result_dict = self[(idx + 1) % self.num_samples]
return result_dict
    def get_source_iterator(self):
        # ensure img_root ends with a separator so the replace below yields
        # a relative image name
        self.img_root = os.path.join(self.img_root, '')
        self.img_files = [
            os.path.join(self.img_root, i)
            for i in io.listdir(self.img_root, recursive=True)
            if i.endswith(self.img_suffix[0])
        ]
        # collect matched image/label pairs into new lists instead of
        # removing from `self.img_files` while iterating over it, which
        # would silently skip elements
        valid_img_files = []
        self.label_files = []
        for img_path in self.img_files:
            img_name = img_path.replace(self.img_root,
                                        '')[:-len(self.img_suffix[0])]
            label_path = None
            for label_format in self.label_suffix:
                candidate_path = os.path.join(self.label_root,
                                              img_name + label_format)
                if io.exists(candidate_path):
                    label_path = candidate_path
                    break
            if label_path is None:
                logging.warning(
                    'Label file not found for img: %s, skip the sample!' %
                    img_path)
            else:
                valid_img_files.append(img_path)
                self.label_files.append(label_path)
        self.img_files = valid_img_files
        assert len(self.img_files) == len(self.label_files)
        assert len(
            self.img_files) > 0, 'No samples found in %s' % self.img_root
        return list(zip(self.img_files, self.label_files))
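
A tiny standalone sketch (not part of the diff) of what `load_seg_map_cityscape` above does with the `cityscapesscripts` label table: raw `labelIds` annotations are remapped to the 19 `trainIds`, with everything else mapped to the ignore index 255:

import numpy as np
import cityscapesscripts.helpers.labels as CSLabels

seg = np.array([[0, 7, 26]], dtype=np.uint8)  # unlabeled, road, car
mapped = seg.copy()
for label in CSLabels.labels:
    mapped[seg == label.id] = label.trainId
print(mapped)  # [[255 0 13]]: unlabeled -> ignore(255), road -> 0, car -> 13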


@@ -22,6 +22,7 @@ from .resnet import ResNet
from .resnet_jit import ResNetJIT
from .resnext import ResNeXt
from .shuffle_transformer import ShuffleTransformer
from .stdc import STDCContextPathNet, STDCNet
from .swin_transformer import SwinTransformer
from .swin_transformer3d import SwinTransformer3D
from .vision_transformer import VisionTransformer


@@ -0,0 +1,465 @@
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/stdc.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from easycv.models import builder
from ..registry import BACKBONES
class AttentionRefinementModule(BaseModule):
"""Attention Refinement Module (ARM) to refine the features of each stage.
Args:
in_channels (int): The number of input channels.
        out_channel (int): The number of output channels.
Returns:
x_out (torch.Tensor): Feature map of Attention Refinement Module.
"""
def __init__(self,
in_channels,
out_channel,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(AttentionRefinementModule, self).__init__(init_cfg=init_cfg)
self.conv_layer = ConvModule(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.atten_conv_layer = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
ConvModule(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None), nn.Sigmoid())
def forward(self, x):
x = self.conv_layer(x)
x_atten = self.atten_conv_layer(x)
x_out = x * x_atten
return x_out
class STDCModule(BaseModule):
"""STDCModule.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels before scaling.
        stride (int): The stride of the first conv layer.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers. Default: 4.
fusion_type (str): Type of fusion operation. Default: 'add'.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
stride,
norm_cfg=None,
act_cfg=None,
num_convs=4,
fusion_type='add',
init_cfg=None):
super(STDCModule, self).__init__(init_cfg=init_cfg)
assert num_convs > 1
assert fusion_type in ['add', 'cat']
self.stride = stride
        self.with_downsample = self.stride == 2
self.fusion_type = fusion_type
self.layers = ModuleList()
conv_0 = ConvModule(
in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)
if self.with_downsample:
self.downsample = ConvModule(
out_channels // 2,
out_channels // 2,
kernel_size=3,
stride=2,
padding=1,
groups=out_channels // 2,
norm_cfg=norm_cfg,
act_cfg=None)
if self.fusion_type == 'add':
self.layers.append(nn.Sequential(conv_0, self.downsample))
self.skip = Sequential(
ConvModule(
in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=1,
groups=in_channels,
norm_cfg=norm_cfg,
act_cfg=None),
ConvModule(
in_channels,
out_channels,
1,
norm_cfg=norm_cfg,
act_cfg=None))
else:
self.layers.append(conv_0)
self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
else:
self.layers.append(conv_0)
for i in range(1, num_convs):
out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
self.layers.append(
ConvModule(
out_channels // 2**i,
out_channels // out_factor,
kernel_size=3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def forward(self, inputs):
if self.fusion_type == 'add':
out = self.forward_add(inputs)
else:
out = self.forward_cat(inputs)
return out
def forward_add(self, inputs):
layer_outputs = []
x = inputs.clone()
for layer in self.layers:
x = layer(x)
layer_outputs.append(x)
if self.with_downsample:
inputs = self.skip(inputs)
return torch.cat(layer_outputs, dim=1) + inputs
def forward_cat(self, inputs):
x0 = self.layers[0](inputs)
layer_outputs = [x0]
for i, layer in enumerate(self.layers[1:]):
if i == 0:
if self.with_downsample:
x = layer(self.downsample(x0))
else:
x = layer(x0)
else:
x = layer(x)
layer_outputs.append(x)
if self.with_downsample:
layer_outputs[0] = self.skip(x0)
return torch.cat(layer_outputs, dim=1)
class FeatureFusionModule(BaseModule):
"""Feature Fusion Module. This module is different from FeatureFusionModule
in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
channel number is calculated by given `scale_factor`, while
FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
`self.conv_atten`.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
scale_factor (int): The number of channel scale factor.
Default: 4.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): The activation config for conv layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
scale_factor=4,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(FeatureFusionModule, self).__init__(init_cfg=init_cfg)
channels = out_channels // scale_factor
self.conv0 = ConvModule(
in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
self.attention = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
ConvModule(
out_channels,
channels,
1,
norm_cfg=None,
bias=False,
act_cfg=act_cfg),
ConvModule(
channels,
out_channels,
1,
norm_cfg=None,
bias=False,
act_cfg=None), nn.Sigmoid())
def forward(self, spatial_inputs, context_inputs):
inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
x = self.conv0(inputs)
attn = self.attention(x)
x_attn = x * attn
return x_attn + x
@BACKBONES.register_module()
class STDCNet(BaseModule):
"""This backbone is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
Args:
        stdc_type (str): The type of backbone structure,
            `STDCNet1` and `STDCNet2` denote the two main backbones in the
            paper, whose FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
norm_cfg (dict): Config dict for normalization layer.
act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers in each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the module
            output. Default: False.
pretrained (str, optional): Model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Example:
>>> import torch
>>> stdc_type = 'STDCNet1'
>>> in_channels = 3
>>> channels = (32, 64, 256, 512, 1024)
>>> bottleneck_type = 'cat'
>>> inputs = torch.rand(1, 3, 1024, 2048)
>>> self = STDCNet(stdc_type, in_channels,
... channels, bottleneck_type).eval()
>>> outputs = self.forward(inputs)
>>> for i in range(len(outputs)):
... print(f'outputs[{i}].shape = {outputs[i].shape}')
outputs[0].shape = torch.Size([1, 256, 128, 256])
outputs[1].shape = torch.Size([1, 512, 64, 128])
outputs[2].shape = torch.Size([1, 1024, 32, 64])
"""
arch_settings = {
'STDCNet1': [(2, 1), (2, 1), (2, 1)],
'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
}
def __init__(self,
stdc_type,
in_channels,
channels,
bottleneck_type,
norm_cfg,
act_cfg,
num_convs=4,
with_final_conv=False,
pretrained=None,
init_cfg=None):
super(STDCNet, self).__init__(init_cfg=init_cfg)
assert stdc_type in self.arch_settings, \
f'invalid structure {stdc_type} for STDCNet.'
assert bottleneck_type in ['add', 'cat'],\
f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
assert len(channels) == 5,\
f'invalid channels length {len(channels)} for STDCNet.'
self.in_channels = in_channels
self.channels = channels
self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
self.num_convs = num_convs
self.with_final_conv = with_final_conv
self.stages = ModuleList([
ConvModule(
self.in_channels,
self.channels[0],
kernel_size=3,
stride=2,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
self.channels[0],
self.channels[1],
kernel_size=3,
stride=2,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
])
# `self.num_shallow_features` is the number of shallow modules in
# `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
        # Neither of them is used by following modules like the Attention
        # Refinement Module and the Feature Fusion Module.
# Thus they would be cut from `outs`. Please refer to Figure 4
# of original paper for more details.
self.num_shallow_features = len(self.stages)
for strides in self.stage_strides:
idx = len(self.stages) - 1
self.stages.append(
self._make_stage(self.channels[idx], self.channels[idx + 1],
strides, norm_cfg, act_cfg, bottleneck_type))
# After appending, `self.stages` is a ModuleList including several
# shallow modules and STDCModules.
# (len(self.stages) ==
# self.num_shallow_features + len(self.stage_strides))
if self.with_final_conv:
self.final_conv = ConvModule(
self.channels[-1],
max(1024, self.channels[-1]),
1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
act_cfg, bottleneck_type):
layers = []
for i, stride in enumerate(strides):
layers.append(
STDCModule(
in_channels if i == 0 else out_channels,
out_channels,
stride,
norm_cfg,
act_cfg,
num_convs=self.num_convs,
fusion_type=bottleneck_type))
return Sequential(*layers)
def forward(self, x):
outs = []
for stage in self.stages:
x = stage(x)
outs.append(x)
if self.with_final_conv:
outs[-1] = self.final_conv(outs[-1])
outs = outs[self.num_shallow_features:]
return tuple(outs)
@BACKBONES.register_module()
class STDCContextPathNet(BaseModule):
"""STDCNet with Context Path. The `outs` below is a list of three feature
maps from deep to shallow, whose height and width is from small to big,
respectively. The biggest feature map of `outs` is outputted for
`STDCHead`, where Detail Loss would be calculated by Detail Ground-truth.
The other two feature maps are used for Attention Refinement Module,
respectively. Besides, the biggest feature map of `outs` and the last
output of Attention Refinement Module are concatenated for Feature Fusion
Module. Then, this fusion feature map `feat_fuse` would be outputted for
`decode_head`. More details please refer to Figure 4 of original paper.
Args:
backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from the stdc backbone. Default: (1024, 512).
out_channels (int): The channels of output feature maps.
Default: 128.
ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
`dict(in_channels=512, out_channels=256, scale_factor=4)`.
upsample_mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
    Returns:
        outputs (tuple): Tuple of output feature maps for the auxiliary
            heads and the decode head.
"""
def __init__(self,
backbone_cfg,
last_in_channels=(1024, 512),
out_channels=128,
ffm_cfg=dict(
in_channels=512, out_channels=256, scale_factor=4),
upsample_mode='nearest',
align_corners=None,
norm_cfg=dict(type='BN'),
init_cfg=None):
super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
self.backbone = builder.build_backbone(backbone_cfg)
self.arms = ModuleList()
self.convs = ModuleList()
for channels in last_in_channels:
self.arms.append(AttentionRefinementModule(channels, out_channels))
self.convs.append(
ConvModule(
out_channels,
out_channels,
3,
padding=1,
norm_cfg=norm_cfg))
self.conv_avg = ConvModule(
last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)
self.ffm = FeatureFusionModule(**ffm_cfg)
self.upsample_mode = upsample_mode
self.align_corners = align_corners
def forward(self, x):
outs = list(self.backbone(x))
avg = F.adaptive_avg_pool2d(outs[-1], 1)
avg_feat = self.conv_avg(avg)
feature_up = F.interpolate(
avg_feat,
size=outs[-1].shape[2:],
mode=self.upsample_mode,
align_corners=self.align_corners)
arms_out = []
for i in range(len(self.arms)):
x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
feature_up = F.interpolate(
x_arm,
size=outs[len(outs) - 1 - i - 1].shape[2:],
mode=self.upsample_mode,
align_corners=self.align_corners)
feature_up = self.convs[i](feature_up)
arms_out.append(feature_up)
feat_fuse = self.ffm(outs[0], arms_out[1])
# The `outputs` has four feature maps.
# `outs[0]` is outputted for `STDCHead` auxiliary head.
# Two feature maps of `arms_out` are outputted for auxiliary head.
# `feat_fuse` is outputted for decoder head.
outputs = [outs[0]] + list(arms_out) + [feat_fuse]
return tuple(outputs)
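
To make the `in_index` wiring in the config above concrete, here is a hedged sketch (not part of the diff) of the four feature maps `STDCContextPathNet` returns for a 512x1024 input; the shapes follow from the stride-2 stages and the context path described above:

import torch

from easycv.models.backbones.stdc import STDCContextPathNet

net = STDCContextPathNet(
    backbone_cfg=dict(
        type='STDCNet',
        stdc_type='STDCNet1',
        in_channels=3,
        channels=(32, 64, 256, 512, 1024),
        bottleneck_type='cat',
        num_convs=4,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'),
        with_final_conv=False),
    last_in_channels=(1024, 512),
    out_channels=128,
    ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)).eval()
with torch.no_grad():
    outs = net(torch.rand(1, 3, 512, 1024))
# outs[0]: (1, 256, 64, 128)  stride-8 feature   -> STDCHead    (in_index=0)
# outs[1]: (1, 128, 32, 64)   first ARM output   -> aux head    (in_index=1)
# outs[2]: (1, 128, 64, 128)  second ARM output  -> aux head    (in_index=2)
# outs[3]: (1, 256, 64, 128)  FFM fused feature  -> decode head (in_index=3)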


@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .cross_entropy_loss import CrossEntropyLoss
from .det_db_loss import DBLoss
from .dice_loss import DiceLoss
from .face_keypoint_loss import FacePoseLoss, WingLossWithPose
from .focal_loss import FocalLoss, VarifocalLoss
from .iou_loss import GIoULoss, IoULoss, YOLOX_IOULoss
@@ -22,5 +23,6 @@ __all__ = [
'FocalLoss2d', 'DistributeMSELoss', 'CrossEntropyLossWithLabelSmooth',
'AMSoftmaxLoss', 'ModelParallelSoftmaxLoss', 'ModelParallelAMSoftmaxLoss',
'SoftTargetCrossEntropy', 'CDNCriterion', 'DNCriterion', 'DBLoss',
'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss'
'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss',
'DiceLoss'
]


@@ -0,0 +1,135 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/losses/dice_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from easycv.models.builder import LOSSES
from .utils import get_class_weight, weighted_loss
@weighted_loss
def dice_loss(pred,
target,
valid_mask,
smooth=1,
exponent=2,
class_weight=None,
ignore_index=255):
assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
for i in range(num_classes):
if i != ignore_index:
dice_loss = binary_dice_loss(
pred[:, i],
target[..., i],
valid_mask=valid_mask,
smooth=smooth,
exponent=exponent)
if class_weight is not None:
dice_loss *= class_weight[i]
total_loss += dice_loss
return total_loss / num_classes
@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
assert pred.shape[0] == target.shape[0]
pred = pred.reshape(pred.shape[0], -1)
target = target.reshape(target.shape[0], -1)
valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
return 1 - num / den
@LOSSES.register_module()
class DiceLoss(nn.Module):
"""DiceLoss.
This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
Args:
smooth (float): A float number to smooth loss, and avoid NaN error.
Default: 1
        exponent (float): A float number used to calculate the denominator
            value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". Default: 'mean'.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_dice'.
"""
def __init__(self,
smooth=1,
exponent=2,
reduction='mean',
class_weight=None,
loss_weight=1.0,
ignore_index=255,
loss_name='loss_dice',
**kwargs):
super(DiceLoss, self).__init__()
self.smooth = smooth
self.exponent = exponent
self.reduction = reduction
self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight
self.ignore_index = ignore_index
self._loss_name = loss_name
def forward(self,
pred,
target,
avg_factor=None,
reduction_override=None,
**kwargs):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.class_weight is not None:
class_weight = pred.new_tensor(self.class_weight)
else:
class_weight = None
pred = F.softmax(pred, dim=1)
num_classes = pred.shape[1]
one_hot_target = F.one_hot(
torch.clamp(target.long(), 0, num_classes - 1),
num_classes=num_classes)
valid_mask = (target != self.ignore_index).long()
loss = self.loss_weight * dice_loss(
pred,
one_hot_target,
valid_mask=valid_mask,
reduction=reduction,
avg_factor=avg_factor,
smooth=self.smooth,
exponent=self.exponent,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
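
A short usage sketch (not part of the diff); `STDCHead` combines this loss with a sigmoid cross-entropy on the binary boundary maps, but it can be exercised standalone:

import torch

from easycv.models.losses.dice_loss import DiceLoss

criterion = DiceLoss(smooth=1, exponent=2, loss_weight=1.0, ignore_index=255)
seg_logit = torch.randn(4, 2, 32, 32)  # (N, C, H, W) raw scores
seg_label = torch.randint(0, 2, (4, 32, 32))  # (N, H, W) class indices
seg_label[0, :4, :4] = 255  # ignored pixels are masked out by valid_mask
loss = criterion(seg_logit, seg_label)
print(criterion.loss_name, loss)  # 'loss_dice' and a scalar loss tensor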


@@ -1,12 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
from easycv.framework.errors import ValueError
def get_class_weight(class_weight):
"""Get class weight for loss function.
Args:
class_weight (list[float] | str | None): If class_weight is a str,
take it as a file name and read from it.
"""
if isinstance(class_weight, str):
# take it as a file path
if class_weight.endswith('.npy'):
class_weight = np.load(class_weight)
else:
# pkl, json or yaml
class_weight = mmcv.load(class_weight)
return class_weight
def reduce_loss(loss, reduction):
"""Reduce loss as specified.


@@ -2,3 +2,4 @@
from .encoder_decoder import EncoderDecoder
from .heads import *
from .mask2former import Mask2Former
from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler


@@ -157,7 +157,6 @@ class EncoderDecoder(BaseModel):
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)
losses = dict()
loss_decode = self._decode_head_forward_train(x, img_metas,


@@ -2,6 +2,9 @@
from .fcn_head import FCNHead
from .mask2former_head import Mask2FormerHead
from .segformer_head import SegformerHead
from .stdc_head import STDCHead
from .uper_head import UPerHead
__all__ = ['FCNHead', 'UPerHead', 'Mask2FormerHead', 'SegformerHead']
__all__ = [
'FCNHead', 'UPerHead', 'Mask2FormerHead', 'SegformerHead', 'STDCHead'
]


@@ -11,6 +11,7 @@ from easycv.framework.errors import TypeError
from easycv.models.builder import build_loss
from easycv.models.utils.ops import resize_tensor
from easycv.utils.logger import print_log
from ..sampler import build_pixel_sampler
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/decode_head.py
@@ -69,6 +70,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
sampler=None,
ignore_index=255,
align_corners=False,
init_cfg=dict(
@@ -97,6 +99,11 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
@@ -232,7 +239,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logit, seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
if not isinstance(self.loss_decode, nn.ModuleList):


@@ -0,0 +1,85 @@
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/stdc_head.py
import torch
import torch.nn.functional as F
from easycv.models.builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class STDCHead(FCNHead):
"""This head is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
Args:
        boundary_threshold (float): The threshold for binarizing the boundary
            maps. Default: 0.1.
"""
def __init__(self, boundary_threshold=0.1, **kwargs):
super(STDCHead, self).__init__(**kwargs)
self.boundary_threshold = boundary_threshold
# Using register buffer to make laplacian kernel on the same
# device of `seg_label`.
self.register_buffer(
'laplacian_kernel',
torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
dtype=torch.float32,
requires_grad=False).reshape((1, 1, 3, 3)))
self.fusion_kernel = torch.nn.Parameter(
torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
dtype=torch.float32).reshape(1, 3, 1, 1),
requires_grad=False)
def losses(self, seg_logit, seg_label):
"""Compute Detail Aggregation Loss."""
        # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
        # parameter. However, it is a constant in the original repo and other
        # codebases, because it would not be added to the computation graph
        # after the threshold operation.
seg_label = seg_label.to(self.laplacian_kernel)
boundary_targets = F.conv2d(
seg_label, self.laplacian_kernel, padding=1)
boundary_targets = boundary_targets.clamp(min=0)
boundary_targets[boundary_targets > self.boundary_threshold] = 1
boundary_targets[boundary_targets <= self.boundary_threshold] = 0
boundary_targets_x2 = F.conv2d(
seg_label, self.laplacian_kernel, stride=2, padding=1)
boundary_targets_x2 = boundary_targets_x2.clamp(min=0)
boundary_targets_x4 = F.conv2d(
seg_label, self.laplacian_kernel, stride=4, padding=1)
boundary_targets_x4 = boundary_targets_x4.clamp(min=0)
boundary_targets_x4_up = F.interpolate(
boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
boundary_targets_x2_up = F.interpolate(
boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')
boundary_targets_x2_up[
boundary_targets_x2_up > self.boundary_threshold] = 1
boundary_targets_x2_up[
boundary_targets_x2_up <= self.boundary_threshold] = 0
boundary_targets_x4_up[
boundary_targets_x4_up > self.boundary_threshold] = 1
boundary_targets_x4_up[
boundary_targets_x4_up <= self.boundary_threshold] = 0
boundary_targets_pyramids = torch.stack(
(boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
dim=1)
boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
self.fusion_kernel)
boundary_targets_pyramid[
boundary_targets_pyramid > self.boundary_threshold] = 1
boundary_targets_pyramid[
boundary_targets_pyramid <= self.boundary_threshold] = 0
loss = super(STDCHead, self).losses(seg_logit,
boundary_targets_pyramid.long())
return loss
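
A self-contained sketch (not part of the diff) of the boundary-target step above: the fixed Laplacian kernel responds only where the label value changes, and thresholding turns that response into the binary detail ground-truth:

import torch
import torch.nn.functional as F

laplacian_kernel = torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                                dtype=torch.float32).reshape(1, 1, 3, 3)
seg_label = torch.zeros(1, 1, 8, 8)
seg_label[..., 4:] = 1.  # vertical class edge at x=4
boundary = F.conv2d(seg_label, laplacian_kernel, padding=1).clamp(min=0)
boundary = (boundary > 0.1).float()  # binarize, as in losses() above
# boundary is 1 along the x=4 edge (plus a zero-padding artifact at the
# right image border), 0 elsewhere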


@@ -0,0 +1,5 @@
from .base_pixel_sampler import BasePixelSampler
from .builder import build_pixel_sampler
from .ohem_pixel_sampler import OHEMPixelSampler
__all__ = ['BasePixelSampler', 'OHEMPixelSampler', 'build_pixel_sampler']


@@ -0,0 +1,13 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
from abc import ABCMeta, abstractmethod
class BasePixelSampler(metaclass=ABCMeta):
"""Base class of pixel sampler."""
def __init__(self, **kwargs):
pass
@abstractmethod
def sample(self, seg_logit, seg_label):
"""Placeholder for sample function."""


@@ -0,0 +1,9 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
from mmcv.utils import Registry, build_from_cfg
PIXEL_SAMPLERS = Registry('pixel sampler')
def build_pixel_sampler(cfg, **default_args):
"""Build pixel sampler for segmentation map."""
return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)


@@ -0,0 +1,85 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS
@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
"""Online Hard Example Mining Sampler for segmentation.
Args:
context (nn.Module): The context of sampler, subclass of
:obj:`BaseDecodeHead`.
        thresh (float, optional): The threshold for hard example selection:
            predictions below it are treated as low-confidence. If not
            specified, the hard examples will be the pixels with the top
            ``min_kept`` losses. Default: None.
min_kept (int, optional): The minimum number of predictions to keep.
Default: 100000.
"""
def __init__(self, context, thresh=None, min_kept=100000):
super(OHEMPixelSampler, self).__init__()
self.context = context
assert min_kept > 1
self.thresh = thresh
self.min_kept = min_kept
def sample(self, seg_logit, seg_label):
"""Sample pixels that have high loss or with low prediction confidence.
Args:
seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
Returns:
torch.Tensor: segmentation weight, shape (N, H, W)
"""
with torch.no_grad():
assert seg_logit.shape[2:] == seg_label.shape[2:]
assert seg_label.shape[1] == 1
seg_label = seg_label.squeeze(1).long()
batch_kept = self.min_kept * seg_label.size(0)
valid_mask = seg_label != self.context.ignore_index
seg_weight = seg_logit.new_zeros(size=seg_label.size())
valid_seg_weight = seg_weight[valid_mask]
if self.thresh is not None:
seg_prob = F.softmax(seg_logit, dim=1)
tmp_seg_label = seg_label.clone().unsqueeze(1)
tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
sort_prob, sort_indices = seg_prob[valid_mask].sort()
if sort_prob.numel() > 0:
min_threshold = sort_prob[min(batch_kept,
sort_prob.numel() - 1)]
else:
min_threshold = 0.0
threshold = max(min_threshold, self.thresh)
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
else:
if not isinstance(self.context.loss_decode, nn.ModuleList):
losses_decode = [self.context.loss_decode]
else:
losses_decode = self.context.loss_decode
losses = 0.0
for loss_module in losses_decode:
losses += loss_module(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.
seg_weight[valid_mask] = valid_seg_weight
return seg_weight
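
A hedged standalone sketch (not part of the diff); in training the `context` is the decode head that built the sampler, but a stand-in exposing `ignore_index` is enough to exercise the `thresh`-based branch:

import torch
from types import SimpleNamespace

from easycv.models.segmentation.sampler import OHEMPixelSampler

context = SimpleNamespace(ignore_index=255)  # stand-in for a BaseDecodeHead
sampler = OHEMPixelSampler(context, thresh=0.7, min_kept=32)
seg_logit = torch.randn(2, 19, 16, 16)
seg_label = torch.randint(0, 19, (2, 1, 16, 16))
seg_weight = sampler.sample(seg_logit, seg_label)  # (2, 16, 16) of 0./1.
print(seg_weight.shape, int(seg_weight.sum()))  # count of selected pixels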


@@ -29,7 +29,7 @@ class SegmentationPredictor(PredictorV2):
def __init__(self,
model_path,
config_file,
config_file=None,
batch_size=1,
device=None,
save_results=False,
@@ -54,7 +54,7 @@ class SegmentationPredictor(PredictorV2):
**kwargs)
self.CLASSES = self.cfg.CLASSES
self.PALETTE = self.cfg.PALETTE
self.PALETTE = self.cfg.get('PALETTE', None)
def show_result(self,
img,
@@ -90,7 +90,8 @@ class SegmentationPredictor(PredictorV2):
img = mmcv.imread(img)
img = img.copy()
seg = result[0]
# seg = result[0]
seg = result
if palette is None:
if self.PALETTE is None:
# Get random state before set seed,


@@ -1,4 +1,5 @@
albumentations
cityscapesscripts
dataclasses
decord
einops


@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import unittest
from tests.ut_config import SEG_DATA_SAMLL_CITYSCAPES
from easycv.datasets.segmentation.data_sources.cityscapes import \
SegSourceCityscapes
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle')
class SegSourceCityscapesTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_cityscapes(self):
data_source = SegSourceCityscapes(
img_root=os.path.join(SEG_DATA_SAMLL_CITYSCAPES, 'leftImg8bit'),
label_root=os.path.join(SEG_DATA_SAMLL_CITYSCAPES, 'gtFine'),
classes=CLASSES,
)
index_list = random.choices(list(range(20)), k=3)
for idx in index_list:
data = data_source[idx]
self.assertIn('img_fields', data)
self.assertIn('seg_fields', data)
if __name__ == '__main__':
unittest.main()


@@ -0,0 +1,126 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models import build_model
class StdcTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_stdc(self):
norm_cfg = dict(type='BN', requires_grad=True)
model_cfg = dict(
type='EncoderDecoder',
backbone=dict(
type='STDCContextPathNet',
backbone_cfg=dict(
type='STDCNet',
stdc_type='STDCNet1',
in_channels=3,
channels=(32, 64, 256, 512, 1024),
bottleneck_type='cat',
num_convs=4,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
with_final_conv=False),
last_in_channels=(1024, 512),
out_channels=128,
ffm_cfg=dict(
in_channels=384, out_channels=256, scale_factor=4)),
decode_head=dict(
type='FCNHead',
in_channels=256,
channels=256,
num_convs=1,
num_classes=19,
in_index=3,
concat_input=False,
dropout_ratio=0.1,
norm_cfg=norm_cfg,
align_corners=True,
sampler=dict(
type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0)),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
sampler=dict(
type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
sampler=dict(
type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0)),
dict(
type='STDCHead',
in_channels=256,
channels=64,
num_convs=1,
num_classes=2,
boundary_threshold=0.1,
in_index=0,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=True,
loss_decode=[
dict(
type='CrossEntropyLoss',
loss_name='loss_ce',
use_sigmoid=True,
loss_weight=1.0),
dict(
type='DiceLoss',
loss_name='loss_dice',
loss_weight=1.0)
]),
],
train_cfg=dict(),
test_cfg=dict(mode='whole'),
)
model = build_model(model_cfg).to('cuda')
img = torch.rand(2, 3, 512, 1024).to('cuda')
gt_semantic_seg = torch.randint(
low=0, high=18, size=(2, 1, 512, 1024)).to('cuda')
train_output = model.forward_train(img, [], gt_semantic_seg)
self.assertIn('decode.loss_ce', train_output)
self.assertIn('aux_0.loss_ce', train_output)
self.assertIn('aux_1.loss_ce', train_output)
self.assertIn('aux_2.loss_ce', train_output)
if __name__ == '__main__':
unittest.main()


@@ -159,6 +159,8 @@ SEG_DATA_SMALL_COCO_STUFF_10K = os.path.join(
BASE_LOCAL_PATH, 'data/segmentation/small_coco_stuff/small_coco_stuff10k')
SEG_DATA_SAMLL_COCO_STUFF_164K = os.path.join(
BASE_LOCAL_PATH, 'data/segmentation/small_coco_stuff/small_coco_stuff164k')
SEG_DATA_SAMLL_CITYSCAPES = os.path.join(BASE_LOCAL_PATH,
'data/segmentation/small_cityscapes')
# OCR data
SMALL_OCR_CLS_DATA = os.path.join(BASE_LOCAL_PATH, 'data/ocr/small_ocr_cls')