mirror of https://github.com/alibaba/EasyCV.git
parent 2fe73eee91
commit 26cd12ab42
@ -0,0 +1,198 @@
_base_ = ['configs/base.py']

# Warning: batch_size must be >= 2

# model
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='STDCContextPathNet',
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            with_final_conv=False),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
    decode_head=dict(
        type='FCNHead',
        in_channels=256,
        channels=256,
        num_convs=1,
        num_classes=19,
        in_index=3,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=True,
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=2,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=1,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='STDCHead',
            in_channels=256,
            channels=64,
            num_convs=1,
            num_classes=2,
            boundary_threshold=0.1,
            in_index=0,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=True,
            loss_decode=[
                dict(
                    type='CrossEntropyLoss',
                    loss_name='loss_ce',
                    use_sigmoid=True,
                    loss_weight=1.0),
                dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
            ]),
    ],
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
    pretrained=
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc1_easycv.pth'
)

# dataset
CLASSES = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)

train_pipeline = [
    dict(type='MMResize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='SegRandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(type='MMPhotoMetricDistortion'),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='MMPad', size=crop_size),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_semantic_seg'],
        meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'img_norm_cfg')),
]

test_pipeline = [
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=True),
            dict(type='MMRandomFlip'),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=('filename', 'ori_filename', 'ori_shape',
                           'img_shape', 'pad_shape', 'scale_factor', 'flip',
                           'flip_direction', 'img_norm_cfg')),
        ])
]
dataset_type = 'SegDataset'
data_root = '../Cityscapes/'

train_img_root = data_root + 'leftImg8bit/train/'
train_label_root = data_root + 'gtFine/train/'

val_img_root = data_root + 'leftImg8bit/val/'
val_label_root = data_root + 'gtFine/val/'
data = dict(
    imgs_per_gpu=6,
    workers_per_gpu=4,
    persistent_workers=True,
    train=dict(
        type=dataset_type,
        ignore_index=255,
        data_source=dict(
            type='SegSourceCityscapes',
            img_root=train_img_root,
            label_root=train_label_root,
            classes=CLASSES),
        pipeline=train_pipeline),
    val=dict(
        imgs_per_gpu=1,
        ignore_index=255,
        type=dataset_type,
        data_source=dict(
            type='SegSourceCityscapes',
            img_root=val_img_root,
            label_root=val_label_root,
            classes=CLASSES),
        pipeline=test_pipeline))

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=1e-4,
    warmup='linear',
    warmup_iters=10,
    warmup_ratio=0.0001,
    warmup_by_epoch=True,
    by_epoch=False)

# runtime settings
total_epochs = 1290
checkpoint_config = dict(interval=10)
eval_config = dict(interval=10, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        evaluators=[
            dict(
                type='SegmentationEvaluator',
                classes=CLASSES,
                metric_names=['mIoU'])
        ],
    )
]

# export config
export = dict(export_neck=True)
checkpoint_sync_export = True
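
For orientation, a minimal sketch (not part of the commit) of how a config like the one above is consumed. The config path and the use of `mmcv_config_fromfile` are assumptions; training itself is normally launched through EasyCV's standard train tool rather than by hand.

```python
# Sketch, assuming EasyCV is installed and the file above is saved as
# configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py.
# mmcv_config_fromfile resolves the repo-relative `_base_` entry.
from easycv.models import build_model
from easycv.utils.config_tools import mmcv_config_fromfile

cfg = mmcv_config_fromfile(
    'configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py')
model = build_model(cfg.model)  # the EncoderDecoder defined above
print(type(model).__name__)
```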
@ -0,0 +1,7 @@
_base_ = ['configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py']

model = dict(
    backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2')),
    pretrained=
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc2_easycv.pth'
)
@ -15,6 +15,12 @@ Pretrained on **Pascal VOC 2012 + Aug**.
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| upernet_r50 | [upernet_r50_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/upernet/upernet_r50_512x512_8xb4_60e_voc12aug.py) | 23M/66M | 5.5 | 282.9ms | 76.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/20220706_114712.log.json) |

+## STDC
+
+Trained on **Cityscapes**.
+
+| Algorithm | Config | Params<br/>(backbone/total) | Train memory<br/>(GB) | inference time(V100)<br/>(ms/img) | mIoU | Download |
+| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
+| STDC1 | [stdc1_cityscape_8xb6_e1290](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py) | 7.7M/8.5M | 4.5 | 11.9ms | 75.4 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/epoch_1250.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/20230214_173123.log.json) |

## Mask2former

### Instance Segmentation on COCO
@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+from .cityscapes import SegSourceCityscapes
from .coco import SegSourceCoco, SegSourceCoco2017
from .coco_stuff import SegSourceCocoStuff10k, SegSourceCocoStuff164k
from .raw import SegSourceRaw
@ -7,5 +8,5 @@ from .voc import SegSourceVoc2007, SegSourceVoc2010, SegSourceVoc2012
__all__ = [
    'SegSourceRaw', 'SegSourceVoc2010', 'SegSourceVoc2007', 'SegSourceVoc2012',
    'SegSourceCoco', 'SegSourceCoco2017', 'SegSourceCocoStuff164k',
-    'SegSourceCocoStuff10k'
+    'SegSourceCocoStuff10k', 'SegSourceCityscapes'
]
@ -0,0 +1,129 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import subprocess

import numpy as np

from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image as _load_img
from .raw import SegSourceRaw

try:
    import cityscapesscripts.helpers.labels as CSLabels
except ModuleNotFoundError as e:
    res = subprocess.call('pip install cityscapesscripts', shell=True)
    if res != 0:
        info_string = (
            '\n\nAuto install failed! Please install cityscapesscripts with the following command:\n'
            '\t`pip install cityscapesscripts`\n')
        raise ModuleNotFoundError(info_string)


def load_seg_map_cityscape(seg_path, reduce_zero_label):
    gt_semantic_seg = _load_img(seg_path, mode='P')
    gt_semantic_seg_copy = gt_semantic_seg.copy()
    # Remap the raw Cityscapes label ids to train ids (255 means ignore).
    for labels in CSLabels.labels:
        gt_semantic_seg_copy[gt_semantic_seg == labels.id] = labels.trainId

    return {'gt_semantic_seg': gt_semantic_seg_copy}
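
A short illustration (not part of the commit) of what the id-to-trainId remapping above does; it relies only on the `cityscapesscripts` package imported above.

```python
# Sketch: cityscapesscripts defines 30+ raw label ids, but only 19 classes
# get a train id in [0, 18]; the rest map to 255 (ignore) or -1.
import cityscapesscripts.helpers.labels as CSLabels

for label in CSLabels.labels:
    if label.trainId == 0:
        # 'road' has raw id 7 and trainId 0, which is why the CLASSES tuple
        # below starts with 'road'.
        print(label.name, label.id, '->', label.trainId)
```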
|
||||||
|
|
||||||
|
|
||||||
|
@DATASOURCES.register_module
|
||||||
|
class SegSourceCityscapes(SegSourceRaw):
|
||||||
|
"""Cityscapes datasource
|
||||||
|
"""
|
||||||
|
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
||||||
|
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
||||||
|
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
||||||
|
'bicycle')
|
||||||
|
|
||||||
|
PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
||||||
|
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
||||||
|
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
||||||
|
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
|
||||||
|
[0, 80, 100], [0, 0, 230], [119, 11, 32]]
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
img_suffix='_leftImg8bit.png',
|
||||||
|
label_suffix='_gtFine_labelIds.png',
|
||||||
|
**kwargs):
|
||||||
|
super(SegSourceCityscapes, self).__init__(
|
||||||
|
img_suffix=img_suffix, label_suffix=label_suffix, **kwargs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
result_dict = self.samples_list[idx]
|
||||||
|
load_success = True
|
||||||
|
try:
|
||||||
|
# avoid data cache from taking up too much memory
|
||||||
|
if not self.cache_at_init and not self.cache_on_the_fly:
|
||||||
|
result_dict = copy.deepcopy(result_dict)
|
||||||
|
|
||||||
|
if not self.cache_at_init:
|
||||||
|
if result_dict.get('img', None) is None:
|
||||||
|
img = _load_img(result_dict['filename'], mode='BGR')
|
||||||
|
result = {
|
||||||
|
'img': img.astype(np.float32),
|
||||||
|
'img_shape': img.shape, # h, w, c
|
||||||
|
'ori_shape': img.shape,
|
||||||
|
}
|
||||||
|
result_dict.update(result)
|
||||||
|
if result_dict.get('gt_semantic_seg', None) is None:
|
||||||
|
result_dict.update(
|
||||||
|
load_seg_map_cityscape(
|
||||||
|
result_dict['seg_filename'],
|
||||||
|
reduce_zero_label=self.reduce_zero_label))
|
||||||
|
if self.cache_on_the_fly:
|
||||||
|
self.samples_list[idx] = result_dict
|
||||||
|
result_dict = self.post_process_fn(copy.deepcopy(result_dict))
|
||||||
|
self._retry_count = 0
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(e)
|
||||||
|
load_success = False
|
||||||
|
|
||||||
|
if not load_success:
|
||||||
|
logging.warning(
|
||||||
|
'Something wrong with current sample %s,Try load next sample...'
|
||||||
|
% result_dict.get('filename', ''))
|
||||||
|
self._retry_count += 1
|
||||||
|
if self._retry_count >= self._max_retry_num:
|
||||||
|
raise ValueError('All samples failed to load!')
|
||||||
|
|
||||||
|
result_dict = self[(idx + 1) % self.num_samples]
|
||||||
|
|
||||||
|
return result_dict
|
||||||
|
|
||||||
|
def get_source_iterator(self):
|
||||||
|
|
||||||
|
self.img_files = [
|
||||||
|
os.path.join(self.img_root, i)
|
||||||
|
for i in io.listdir(self.img_root, recursive=True)
|
||||||
|
if i.endswith(self.img_suffix[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
self.label_files = []
|
||||||
|
for img_path in self.img_files:
|
||||||
|
self.img_root = os.path.join(self.img_root, '')
|
||||||
|
img_name = img_path.replace(self.img_root,
|
||||||
|
'')[:-len(self.img_suffix[0])]
|
||||||
|
find_label_path = False
|
||||||
|
for label_format in self.label_suffix:
|
||||||
|
lable_path = os.path.join(self.label_root,
|
||||||
|
img_name + label_format)
|
||||||
|
if io.exists(lable_path):
|
||||||
|
find_label_path = True
|
||||||
|
self.label_files.append(lable_path)
|
||||||
|
break
|
||||||
|
if not find_label_path:
|
||||||
|
logging.warning(
|
||||||
|
'Not find label file %s for img: %s, skip the sample!' %
|
||||||
|
(lable_path, img_path))
|
||||||
|
self.img_files.remove(img_path)
|
||||||
|
|
||||||
|
assert len(self.img_files) == len(self.label_files)
|
||||||
|
assert len(
|
||||||
|
self.img_files) > 0, 'No samples found in %s' % self.img_root
|
||||||
|
|
||||||
|
return list(zip(self.img_files, self.label_files))
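
For reference, a sketch (not part of the commit) of the directory layout this data source expects and how it is constructed; the `../Cityscapes` root mirrors the config earlier in this commit.

```python
# Sketch: SegSourceCityscapes expects the standard Cityscapes layout, e.g.
#   <root>/leftImg8bit/train/<city>/<name>_leftImg8bit.png
#   <root>/gtFine/train/<city>/<name>_gtFine_labelIds.png
from easycv.datasets.segmentation.data_sources.cityscapes import \
    SegSourceCityscapes

source = SegSourceCityscapes(
    img_root='../Cityscapes/leftImg8bit/train/',
    label_root='../Cityscapes/gtFine/train/',
    classes=SegSourceCityscapes.CLASSES)
sample = source[0]
print(sample['img'].shape, sample['gt_semantic_seg'].shape)
```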
@ -22,6 +22,7 @@ from .resnet import ResNet
from .resnet_jit import ResNetJIT
from .resnext import ResNeXt
from .shuffle_transformer import ShuffleTransformer
+from .stdc import STDCContextPathNet, STDCNet
from .swin_transformer import SwinTransformer
from .swin_transformer3d import SwinTransformer3D
from .vision_transformer import VisionTransformer
@ -0,0 +1,465 @@
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/stdc.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential

from easycv.models import builder
from ..registry import BACKBONES


class AttentionRefinementModule(BaseModule):
    """Attention Refinement Module (ARM) to refine the features of each stage.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
    Returns:
        x_out (torch.Tensor): Feature map of Attention Refinement Module.
    """

    def __init__(self,
                 in_channels,
                 out_channel,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(AttentionRefinementModule, self).__init__(init_cfg=init_cfg)
        self.conv_layer = ConvModule(
            in_channels=in_channels,
            out_channels=out_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.atten_conv_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None), nn.Sigmoid())

    def forward(self, x):
        x = self.conv_layer(x)
        x_atten = self.atten_conv_layer(x)
        x_out = x * x_atten
        return x_out


class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The number of stride for the first conv layer.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layers.
        fusion_type (str): Type of fusion operation. Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super(STDCModule, self).__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = True if self.stride == 2 else False
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module. This module is different from FeatureFusionModule
    in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
    channel number is calculated by the given `scale_factor`, while
    FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
    `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The number of channel scale factor.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(FeatureFusionModule, self).__init__(init_cfg=init_cfg)
        channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                out_channels,
                channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=act_cfg),
            ConvModule(
                channels,
                out_channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=None), nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
        x_attn = x * attn
        return x_attn + x


@BACKBONES.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure,
            'STDCNet1' and 'STDCNet2' denote the two main backbones in the
            paper, whose FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layers at each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the Module
            output. Default: True.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super(STDCNet, self).__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'],\
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5,\
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which are noted as `Stage1` and `Stage2` in the original
        # paper. Neither is used by the following modules (Attention
        # Refinement Module and Feature Fusion Module), so they are cut from
        # `outs`. Please refer to Figure 4 of the original paper for more
        # details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)


@BACKBONES.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path. The `outs` below is a list of three feature
    maps from deep to shallow, whose heights and widths go from small to big,
    respectively. The biggest feature map of `outs` is output to the
    `STDCHead`, where the Detail Loss is calculated against the detail
    ground truth. The other two feature maps each go through an Attention
    Refinement Module. Besides, the biggest feature map of `outs` and the
    last output of the Attention Refinement Modules are concatenated by the
    Feature Fusion Module. This fused feature map `feat_fuse` is then output
    to the `decode_head`. For more details, please refer to Figure 4 of the
    original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple[int]): The number of channels of the last
            two feature maps from the stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        outputs (tuple): The tuple of list of output feature map for
            auxiliary heads and decoder head.
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
        self.backbone = builder.build_backbone(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        outs = list(self.backbone(x))
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = F.interpolate(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = F.interpolate(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # `outputs` has four feature maps.
        # `outs[0]` is output for the `STDCHead` auxiliary head.
        # The two feature maps of `arms_out` are output for auxiliary heads.
        # `feat_fuse` is output for the decoder head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)
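
As a sanity check, a sketch (not part of the commit) of the four maps `forward` returns for the STDCNet1 settings used in the config at the top of this commit; the printed shapes follow from the stage strides above.

```python
# Sketch: shapes of the four returned maps for a 512x1024 input.
import torch

from easycv.models import builder

net = builder.build_backbone(
    dict(
        type='STDCContextPathNet',
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'),
            with_final_conv=False),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256,
                     scale_factor=4))).eval()
with torch.no_grad():
    outs = net(torch.rand(1, 3, 512, 1024))
for o in outs:
    print(o.shape)
# Expected:
#   (1, 256, 64, 128)  outs[0]      -> STDCHead (in_index=0)
#   (1, 128, 32, 64)   arms_out[0]  -> aux FCNHead (in_index=1)
#   (1, 128, 64, 128)  arms_out[1]  -> aux FCNHead (in_index=2)
#   (1, 256, 64, 128)  feat_fuse    -> decode FCNHead (in_index=3)
```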
@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .cross_entropy_loss import CrossEntropyLoss
from .det_db_loss import DBLoss
+from .dice_loss import DiceLoss
from .face_keypoint_loss import FacePoseLoss, WingLossWithPose
from .focal_loss import FocalLoss, VarifocalLoss
from .iou_loss import GIoULoss, IoULoss, YOLOX_IOULoss
@ -22,5 +23,6 @@ __all__ = [
    'FocalLoss2d', 'DistributeMSELoss', 'CrossEntropyLossWithLabelSmooth',
    'AMSoftmaxLoss', 'ModelParallelSoftmaxLoss', 'ModelParallelAMSoftmaxLoss',
    'SoftTargetCrossEntropy', 'CDNCriterion', 'DNCriterion', 'DBLoss',
-    'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss'
+    'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss',
+    'DiceLoss'
]
@ -0,0 +1,135 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/losses/dice_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from easycv.models.builder import LOSSES
from .utils import get_class_weight, weighted_loss


@weighted_loss
def dice_loss(pred,
              target,
              valid_mask,
              smooth=1,
              exponent=2,
              class_weight=None,
              ignore_index=255):
    assert pred.shape[0] == target.shape[0]
    total_loss = 0
    num_classes = pred.shape[1]
    for i in range(num_classes):
        if i != ignore_index:
            dice_loss = binary_dice_loss(
                pred[:, i],
                target[..., i],
                valid_mask=valid_mask,
                smooth=smooth,
                exponent=exponent)
            if class_weight is not None:
                dice_loss *= class_weight[i]
            total_loss += dice_loss
    return total_loss / num_classes


@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
    assert pred.shape[0] == target.shape[0]
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
    den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth

    return 1 - num / den
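
To make the formula concrete, a worked sketch (not part of the commit) computing the same quantity as `binary_dice_loss` by hand for one flattened sample with `smooth=1` and `exponent=2`:

```python
# Worked sketch of the soft Dice term above.
import torch

pred = torch.tensor([[0.9, 0.8, 0.1, 0.2]])  # probabilities, shape (N, H*W)
target = torch.tensor([[1., 1., 0., 0.]])    # binary ground truth
valid_mask = torch.ones_like(target)

num = torch.sum(pred * target * valid_mask, dim=1) * 2 + 1  # 2*1.7 + 1 = 4.4
den = torch.sum(pred.pow(2) + target.pow(2), dim=1) + 1     # 1.5 + 2 + 1 = 4.5
print(1 - num / den)  # ~0.0222: near-perfect prediction, near-zero loss
```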

@LOSSES.register_module()
class DiceLoss(nn.Module):
    """DiceLoss.

    This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
    Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

    Args:
        smooth (float): A float number to smooth the loss and avoid NaN
            errors. Default: 1.
        exponent (float): A float number to calculate the denominator
            value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored. Default: 255.
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_dice'.
    """

    def __init__(self,
                 smooth=1,
                 exponent=2,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_dice',
                 **kwargs):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.reduction = reduction
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        loss = self.loss_weight * dice_loss(
            pred,
            one_hot_target,
            valid_mask=valid_mask,
            reduction=reduction,
            avg_factor=avg_factor,
            smooth=self.smooth,
            exponent=self.exponent,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.
        Returns:
            str: The name of this loss item.
        """
        return self._loss_name
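
A usage sketch (not part of the commit): building the registered loss from a config dict, the same way the `loss_decode` entries in the STDC config consume it.

```python
# Sketch: using the registered DiceLoss standalone.
import torch

from easycv.models.builder import build_loss

criterion = build_loss(dict(type='DiceLoss', loss_weight=1.0))
logits = torch.randn(2, 19, 8, 8)           # (N, C, H, W) raw scores
labels = torch.randint(0, 19, (2, 8, 8))    # (N, H, W) class indices
print(criterion(logits, labels))            # scalar loss tensor
```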
@ -1,12 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
+
+import mmcv
+import numpy as np
import torch
import torch.nn.functional as F
+
from easycv.framework.errors import ValueError
+
+
+def get_class_weight(class_weight):
+    """Get class weight for loss function.
+
+    Args:
+        class_weight (list[float] | str | None): If class_weight is a str,
+            take it as a file name and read from it.
+    """
+    if isinstance(class_weight, str):
+        # take it as a file path
+        if class_weight.endswith('.npy'):
+            class_weight = np.load(class_weight)
+        else:
+            # pkl, json or yaml
+            class_weight = mmcv.load(class_weight)
+
+    return class_weight
+
+
def reduce_loss(loss, reduction):
    """Reduce loss as specified.
@ -2,3 +2,4 @@
from .encoder_decoder import EncoderDecoder
from .heads import *
from .mask2former import Mask2Former
+from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler
@ -157,7 +157,6 @@ class EncoderDecoder(BaseModel):
        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
-
        x = self.extract_feat(img)
        losses = dict()
        loss_decode = self._decode_head_forward_train(x, img_metas,
@ -2,6 +2,9 @@
from .fcn_head import FCNHead
from .mask2former_head import Mask2FormerHead
from .segformer_head import SegformerHead
+from .stdc_head import STDCHead
from .uper_head import UPerHead

-__all__ = ['FCNHead', 'UPerHead', 'Mask2FormerHead', 'SegformerHead']
+__all__ = [
+    'FCNHead', 'UPerHead', 'Mask2FormerHead', 'SegformerHead', 'STDCHead'
+]
@ -11,6 +11,7 @@ from easycv.framework.errors import TypeError
from easycv.models.builder import build_loss
from easycv.models.utils.ops import resize_tensor
from easycv.utils.logger import print_log
+from ..sampler import build_pixel_sampler


# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/decode_head.py
@ -69,6 +70,7 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
+                 sampler=None,
                 ignore_index=255,
                 align_corners=False,
                 init_cfg=dict(
@ -97,6 +99,11 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
            raise TypeError(f'loss_decode must be a dict or sequence of dict,\
                but got {type(loss_decode)}')

+        if sampler is not None:
+            self.sampler = build_pixel_sampler(sampler, context=self)
+        else:
+            self.sampler = None
+
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
@ -232,7 +239,10 @@ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
+        if self.sampler is not None:
+            seg_weight = self.sampler.sample(seg_logit, seg_label)
+        else:
+            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
@ -0,0 +1,85 @@
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/stdc_head.py
import torch
import torch.nn.functional as F

from easycv.models.builder import HEADS
from .fcn_head import FCNHead


@HEADS.register_module()
class STDCHead(FCNHead):
    """This head is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        boundary_threshold (float): The threshold for calculating the
            boundary. Default: 0.1.
    """

    def __init__(self, boundary_threshold=0.1, **kwargs):
        super(STDCHead, self).__init__(**kwargs)
        self.boundary_threshold = boundary_threshold
        # Use register_buffer to keep the laplacian kernel on the same
        # device as `seg_label`.
        self.register_buffer(
            'laplacian_kernel',
            torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32,
                         requires_grad=False).reshape((1, 1, 3, 3)))
        self.fusion_kernel = torch.nn.Parameter(
            torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
                         dtype=torch.float32).reshape(1, 3, 1, 1),
            requires_grad=False)

    def losses(self, seg_logit, seg_label):
        """Compute Detail Aggregation Loss."""
        # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
        # parameter. However, it is a constant in the original repo and in
        # other codebases, because it would not be added to the computation
        # graph after the threshold operation.
        seg_label = seg_label.to(self.laplacian_kernel)
        boundary_targets = F.conv2d(
            seg_label, self.laplacian_kernel, padding=1)
        boundary_targets = boundary_targets.clamp(min=0)
        boundary_targets[boundary_targets > self.boundary_threshold] = 1
        boundary_targets[boundary_targets <= self.boundary_threshold] = 0

        boundary_targets_x2 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=2, padding=1)
        boundary_targets_x2 = boundary_targets_x2.clamp(min=0)

        boundary_targets_x4 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=4, padding=1)
        boundary_targets_x4 = boundary_targets_x4.clamp(min=0)

        boundary_targets_x4_up = F.interpolate(
            boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
        boundary_targets_x2_up = F.interpolate(
            boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')

        boundary_targets_x2_up[
            boundary_targets_x2_up > self.boundary_threshold] = 1
        boundary_targets_x2_up[
            boundary_targets_x2_up <= self.boundary_threshold] = 0

        boundary_targets_x4_up[
            boundary_targets_x4_up > self.boundary_threshold] = 1
        boundary_targets_x4_up[
            boundary_targets_x4_up <= self.boundary_threshold] = 0

        boundary_targets_pyramids = torch.stack(
            (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
            dim=1)

        boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)
        boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
                                            self.fusion_kernel)

        boundary_targets_pyramid[
            boundary_targets_pyramid > self.boundary_threshold] = 1
        boundary_targets_pyramid[
            boundary_targets_pyramid <= self.boundary_threshold] = 0

        loss = super(STDCHead, self).losses(seg_logit,
                                            boundary_targets_pyramid.long())
        return loss
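
A toy sketch (not part of the commit) of the boundary extraction that `losses` performs with the Laplacian kernel, on a 6x6 label map:

```python
# Toy sketch: boundary targets from the Laplacian kernel.
import torch
import torch.nn.functional as F

laplacian = torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32).reshape(1, 1, 3, 3)
label = torch.zeros(1, 1, 6, 6)
label[:, :, 1:5, 1:5] = 1.  # a 4x4 foreground square

edge = F.conv2d(label, laplacian, padding=1).clamp(min=0)
print((edge > 0.1).float().squeeze())
# 1s form the one-pixel ring of the square; its interior and the background
# stay 0, i.e. exactly the binary boundary map the STDCHead trains against.
```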
@ -0,0 +1,5 @@
from .base_pixel_sampler import BasePixelSampler
from .builder import build_pixel_sampler
from .ohem_pixel_sampler import OHEMPixelSampler

__all__ = ['BasePixelSampler', 'OHEMPixelSampler', 'build_pixel_sampler']
@ -0,0 +1,13 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
from abc import ABCMeta, abstractmethod


class BasePixelSampler(metaclass=ABCMeta):
    """Base class of pixel sampler."""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
@ -0,0 +1,9 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
from mmcv.utils import Registry, build_from_cfg

PIXEL_SAMPLERS = Registry('pixel sampler')


def build_pixel_sampler(cfg, **default_args):
    """Build pixel sampler for segmentation map."""
    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
@ -0,0 +1,85 @@
# Borrowed from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_pixel_sampler import BasePixelSampler
from .builder import PIXEL_SAMPLERS


@PIXEL_SAMPLERS.register_module()
class OHEMPixelSampler(BasePixelSampler):
    """Online Hard Example Mining Sampler for segmentation.

    Args:
        context (nn.Module): The context of sampler, subclass of
            :obj:`BaseDecodeHead`.
        thresh (float, optional): The threshold for hard example selection;
            predictions below it are treated as low-confidence. If not
            specified, the hard examples are the pixels with the top
            ``min_kept`` losses. Default: None.
        min_kept (int, optional): The minimum number of predictions to keep.
            Default: 100000.
    """

    def __init__(self, context, thresh=None, min_kept=100000):
        super(OHEMPixelSampler, self).__init__()
        self.context = context
        assert min_kept > 1
        self.thresh = thresh
        self.min_kept = min_kept

    def sample(self, seg_logit, seg_label):
        """Sample pixels that have high loss or with low prediction confidence.

        Args:
            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)

        Returns:
            torch.Tensor: segmentation weight, shape (N, H, W)
        """
        with torch.no_grad():
            assert seg_logit.shape[2:] == seg_label.shape[2:]
            assert seg_label.shape[1] == 1
            seg_label = seg_label.squeeze(1).long()
            batch_kept = self.min_kept * seg_label.size(0)
            valid_mask = seg_label != self.context.ignore_index
            seg_weight = seg_logit.new_zeros(size=seg_label.size())
            valid_seg_weight = seg_weight[valid_mask]
            if self.thresh is not None:
                seg_prob = F.softmax(seg_logit, dim=1)

                tmp_seg_label = seg_label.clone().unsqueeze(1)
                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
                sort_prob, sort_indices = seg_prob[valid_mask].sort()

                if sort_prob.numel() > 0:
                    min_threshold = sort_prob[min(batch_kept,
                                                  sort_prob.numel() - 1)]
                else:
                    min_threshold = 0.0
                threshold = max(min_threshold, self.thresh)
                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
            else:
                if not isinstance(self.context.loss_decode, nn.ModuleList):
                    losses_decode = [self.context.loss_decode]
                else:
                    losses_decode = self.context.loss_decode
                losses = 0.0
                for loss_module in losses_decode:
                    losses += loss_module(
                        seg_logit,
                        seg_label,
                        weight=None,
                        ignore_index=self.context.ignore_index,
                        reduction_override='none')

                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
                _, sort_indices = losses[valid_mask].sort(descending=True)
                valid_seg_weight[sort_indices[:batch_kept]] = 1.

            seg_weight[valid_mask] = valid_seg_weight

            return seg_weight
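
A usage sketch (not part of the commit): building the sampler through `build_pixel_sampler` with a minimal stand-in context; only `ignore_index` is consulted when `thresh` is given.

```python
# Sketch: standalone OHEM sampling with a hypothetical stand-in context.
import torch

from easycv.models.segmentation.sampler import build_pixel_sampler


class _Ctx:
    ignore_index = 255  # the only attribute used when `thresh` is set


sampler = build_pixel_sampler(
    dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10), context=_Ctx())
seg_logit = torch.randn(2, 19, 16, 16)
seg_label = torch.randint(0, 19, (2, 1, 16, 16))
weight = sampler.sample(seg_logit, seg_label)
print(weight.shape)  # (2, 16, 16), 1.0 on low-confidence pixels, else 0.0
```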
@ -29,7 +29,7 @@ class SegmentationPredictor(PredictorV2):

    def __init__(self,
                 model_path,
-                 config_file,
+                 config_file=None,
                 batch_size=1,
                 device=None,
                 save_results=False,
@ -54,7 +54,7 @@ class SegmentationPredictor(PredictorV2):
            **kwargs)

        self.CLASSES = self.cfg.CLASSES
-        self.PALETTE = self.cfg.PALETTE
+        self.PALETTE = self.cfg.get('PALETTE', None)

    def show_result(self,
                    img,
@ -90,7 +90,8 @@ class SegmentationPredictor(PredictorV2):

        img = mmcv.imread(img)
        img = img.copy()
-        seg = result[0]
+        # seg = result[0]
+        seg = result
        if palette is None:
            if self.PALETTE is None:
                # Get random state before set seed,
@ -1,4 +1,5 @@
albumentations
+cityscapesscripts
dataclasses
decord
einops
@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import unittest

from tests.ut_config import SEG_DATA_SAMLL_CITYSCAPES

from easycv.datasets.segmentation.data_sources.cityscapes import \
    SegSourceCityscapes

CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle')


class SegSourceCityscapesTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_cityscapes(self):

        data_source = SegSourceCityscapes(
            img_root=os.path.join(SEG_DATA_SAMLL_CITYSCAPES, 'leftImg8bit'),
            label_root=os.path.join(SEG_DATA_SAMLL_CITYSCAPES, 'gtFine'),
            classes=CLASSES,
        )

        index_list = random.choices(list(range(20)), k=3)
        for idx in index_list:
            data = data_source[idx]
            self.assertIn('img_fields', data)
            self.assertIn('seg_fields', data)


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,126 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import torch

from easycv.models import build_model


class StdcTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_stdc(self):

        norm_cfg = dict(type='BN', requires_grad=True)
        model_cfg = dict(
            type='EncoderDecoder',
            backbone=dict(
                type='STDCContextPathNet',
                backbone_cfg=dict(
                    type='STDCNet',
                    stdc_type='STDCNet1',
                    in_channels=3,
                    channels=(32, 64, 256, 512, 1024),
                    bottleneck_type='cat',
                    num_convs=4,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    with_final_conv=False),
                last_in_channels=(1024, 512),
                out_channels=128,
                ffm_cfg=dict(
                    in_channels=384, out_channels=256, scale_factor=4)),
            decode_head=dict(
                type='FCNHead',
                in_channels=256,
                channels=256,
                num_convs=1,
                num_classes=19,
                in_index=3,
                concat_input=False,
                dropout_ratio=0.1,
                norm_cfg=norm_cfg,
                align_corners=True,
                sampler=dict(
                    type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                loss_decode=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0)),
            auxiliary_head=[
                dict(
                    type='FCNHead',
                    in_channels=128,
                    channels=64,
                    num_convs=1,
                    num_classes=19,
                    in_index=2,
                    norm_cfg=norm_cfg,
                    concat_input=False,
                    align_corners=False,
                    sampler=dict(
                        type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                    loss_decode=dict(
                        type='CrossEntropyLoss',
                        use_sigmoid=False,
                        loss_weight=1.0)),
                dict(
                    type='FCNHead',
                    in_channels=128,
                    channels=64,
                    num_convs=1,
                    num_classes=19,
                    in_index=1,
                    norm_cfg=norm_cfg,
                    concat_input=False,
                    align_corners=False,
                    sampler=dict(
                        type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
                    loss_decode=dict(
                        type='CrossEntropyLoss',
                        use_sigmoid=False,
                        loss_weight=1.0)),
                dict(
                    type='STDCHead',
                    in_channels=256,
                    channels=64,
                    num_convs=1,
                    num_classes=2,
                    boundary_threshold=0.1,
                    in_index=0,
                    norm_cfg=norm_cfg,
                    concat_input=False,
                    align_corners=True,
                    loss_decode=[
                        dict(
                            type='CrossEntropyLoss',
                            loss_name='loss_ce',
                            use_sigmoid=True,
                            loss_weight=1.0),
                        dict(
                            type='DiceLoss',
                            loss_name='loss_dice',
                            loss_weight=1.0)
                    ]),
            ],
            train_cfg=dict(),
            test_cfg=dict(mode='whole'),
        )

        model = build_model(model_cfg).to('cuda')

        img = torch.rand(2, 3, 512, 1024).to('cuda')
        gt_semantic_seg = torch.randint(
            low=0, high=18, size=(2, 1, 512, 1024)).to('cuda')

        train_output = model.forward_train(img, [], gt_semantic_seg)
        self.assertIn('decode.loss_ce', train_output)
        self.assertIn('aux_0.loss_ce', train_output)
        self.assertIn('aux_1.loss_ce', train_output)
        self.assertIn('aux_2.loss_ce', train_output)


if __name__ == '__main__':
    unittest.main()
@ -159,6 +159,8 @@ SEG_DATA_SMALL_COCO_STUFF_10K = os.path.join(
    BASE_LOCAL_PATH, 'data/segmentation/small_coco_stuff/small_coco_stuff10k')
SEG_DATA_SAMLL_COCO_STUFF_164K = os.path.join(
    BASE_LOCAL_PATH, 'data/segmentation/small_coco_stuff/small_coco_stuff164k')
+SEG_DATA_SAMLL_CITYSCAPES = os.path.join(BASE_LOCAL_PATH,
+                                         'data/segmentation/small_cityscapes')

# OCR data
SMALL_OCR_CLS_DATA = os.path.join(BASE_LOCAL_PATH, 'data/ocr/small_ocr_cls')