mirror of https://github.com/alibaba/EasyCV.git
parent
e4722c754f
commit
6b8b04db72
|
@ -0,0 +1,163 @@
|
||||||
|
_base_ = ['configs/base.py']
|
||||||
|
|
||||||
|
CLASSES = [
|
||||||
|
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
|
||||||
|
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
|
||||||
|
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
|
||||||
|
]
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
num_classes = 21
|
||||||
|
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||||
|
model = dict(
|
||||||
|
type='EncoderDecoder',
|
||||||
|
pretrained='open-mmlab://resnet50_v1c',
|
||||||
|
backbone=dict(
|
||||||
|
type='ResNetV1c',
|
||||||
|
depth=50,
|
||||||
|
num_stages=4,
|
||||||
|
out_indices=(1, 2, 3, 4),
|
||||||
|
dilations=(1, 1, 1, 1),
|
||||||
|
strides=(1, 2, 2, 2),
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
norm_eval=False,
|
||||||
|
style='pytorch',
|
||||||
|
contract_dilation=True,
|
||||||
|
),
|
||||||
|
decode_head=dict(
|
||||||
|
type='UPerHead',
|
||||||
|
in_channels=[256, 512, 1024, 2048],
|
||||||
|
in_index=[0, 1, 2, 3],
|
||||||
|
pool_scales=(1, 2, 3, 6),
|
||||||
|
channels=512,
|
||||||
|
dropout_ratio=0.1,
|
||||||
|
num_classes=num_classes,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||||
|
auxiliary_head=dict(
|
||||||
|
type='FCNHead',
|
||||||
|
in_channels=1024,
|
||||||
|
in_index=2,
|
||||||
|
channels=256,
|
||||||
|
num_convs=1,
|
||||||
|
concat_input=False,
|
||||||
|
dropout_ratio=0.1,
|
||||||
|
num_classes=21,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||||
|
# model training and testing settings
|
||||||
|
train_cfg=dict(),
|
||||||
|
test_cfg=dict(mode='whole'))
|
||||||
|
|
||||||
|
# dataset settings
|
||||||
|
dataset_type = 'SegDataset'
|
||||||
|
data_root = 'data/VOCdevkit/VOC2012/'
|
||||||
|
img_norm_cfg = dict(
|
||||||
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
|
crop_size = (512, 512)
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='MMResize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
|
||||||
|
dict(type='SegRandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||||
|
dict(type='MMRandomFlip', flip_ratio=0.5),
|
||||||
|
dict(type='MMPhotoMetricDistortion'),
|
||||||
|
dict(type='MMNormalize', **img_norm_cfg),
|
||||||
|
dict(type='MMPad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||||
|
dict(type='DefaultFormatBundle'),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['img', 'gt_semantic_seg'],
|
||||||
|
meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
|
||||||
|
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
|
||||||
|
'img_norm_cfg')),
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(
|
||||||
|
type='MMMultiScaleFlipAug',
|
||||||
|
img_scale=(2048, 512),
|
||||||
|
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||||
|
flip=False,
|
||||||
|
transforms=[
|
||||||
|
dict(type='MMResize', keep_ratio=True),
|
||||||
|
dict(type='MMRandomFlip'),
|
||||||
|
dict(type='MMNormalize', **img_norm_cfg),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['img'],
|
||||||
|
meta_keys=('filename', 'ori_filename', 'ori_shape',
|
||||||
|
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
||||||
|
'flip_direction', 'img_norm_cfg')),
|
||||||
|
])
|
||||||
|
]
|
||||||
|
data = dict(
|
||||||
|
imgs_per_gpu=4,
|
||||||
|
workers_per_gpu=4,
|
||||||
|
train=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
ignore_index=255,
|
||||||
|
data_source=dict(
|
||||||
|
type='SourceConcat',
|
||||||
|
data_source_list=[
|
||||||
|
dict(
|
||||||
|
type='SegSourceRaw',
|
||||||
|
img_root=data_root + 'JPEGImages',
|
||||||
|
label_root=data_root + 'SegmentationClass',
|
||||||
|
split=data_root + 'ImageSets/Segmentation/train.txt',
|
||||||
|
classes=CLASSES),
|
||||||
|
dict(
|
||||||
|
type='SegSourceRaw',
|
||||||
|
img_root=data_root + 'JPEGImages',
|
||||||
|
label_root=data_root + 'SegmentationClassAug',
|
||||||
|
split=data_root + 'ImageSets/Segmentation/aug.txt',
|
||||||
|
classes=CLASSES),
|
||||||
|
]),
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
val=dict(
|
||||||
|
imgs_per_gpu=1,
|
||||||
|
ignore_index=255,
|
||||||
|
type=dataset_type,
|
||||||
|
data_source=dict(
|
||||||
|
type='SegSourceRaw',
|
||||||
|
img_root=data_root + 'JPEGImages',
|
||||||
|
label_root=data_root + 'SegmentationClass',
|
||||||
|
split=data_root + 'ImageSets/Segmentation/val.txt',
|
||||||
|
classes=CLASSES,
|
||||||
|
),
|
||||||
|
pipeline=test_pipeline),
|
||||||
|
test=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_source=dict(
|
||||||
|
type='SegSourceRaw',
|
||||||
|
img_root=data_root + 'JPEGImages',
|
||||||
|
label_root=data_root + 'SegmentationClass',
|
||||||
|
split=data_root + 'ImageSets/Segmentation/test.txt',
|
||||||
|
classes=CLASSES,
|
||||||
|
),
|
||||||
|
pipeline=test_pipeline))
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||||
|
optimizer_config = dict()
|
||||||
|
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
||||||
|
|
||||||
|
# runtime settings
|
||||||
|
total_epochs = 60
|
||||||
|
checkpoint_config = dict(interval=5)
|
||||||
|
eval_config = dict(interval=1, gpu_collect=False)
|
||||||
|
eval_pipelines = [
|
||||||
|
dict(
|
||||||
|
mode='test',
|
||||||
|
evaluators=[
|
||||||
|
dict(
|
||||||
|
type='SegmentationEvaluator',
|
||||||
|
classes=CLASSES,
|
||||||
|
metric_names=['mIoU'])
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
|
@ -7,3 +7,10 @@ Pretrained on **Pascal VOC 2012 + Aug**.
|
||||||
| Algorithm | Config | mIoU | Download |
|
| Algorithm | Config | mIoU | Download |
|
||||||
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ |
|
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ |
|
||||||
| fcn_r50_d8 | [fcn_r50-d8_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/fcn/fcn_r50-d8_512x512_8xb4_60e_voc12aug.py) | 69.01 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/fcn_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/fcn_r50/20220525_203606.log.json) |
|
| fcn_r50_d8 | [fcn_r50-d8_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/fcn/fcn_r50-d8_512x512_8xb4_60e_voc12aug.py) | 69.01 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/fcn_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/fcn_r50/20220525_203606.log.json) |
|
||||||
|
|
||||||
|
## UperNet
|
||||||
|
|
||||||
|
Pretrained on **Pascal VOC 2012 + Aug**.
|
||||||
|
| Algorithm | Config | mIoU | Download |
|
||||||
|
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ |
|
||||||
|
| upernet_r50 | [upernet_r50_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/upernet/upernet_r50_512x512_8xb4_60e_voc12aug.py) | 76.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/20220706_114712.log.json) |
|
||||||
|
|
|
@ -7,7 +7,8 @@ from easycv.models import builder
|
||||||
from easycv.models.base import BaseModel
|
from easycv.models.base import BaseModel
|
||||||
from easycv.models.builder import MODELS
|
from easycv.models.builder import MODELS
|
||||||
from easycv.models.utils.ops import resize_tensor
|
from easycv.models.utils.ops import resize_tensor
|
||||||
from easycv.utils.logger import print_log
|
from easycv.utils.checkpoint import load_checkpoint
|
||||||
|
from easycv.utils.logger import get_root_logger, print_log
|
||||||
from easycv.utils.misc import add_prefix
|
from easycv.utils.misc import add_prefix
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ class EncoderDecoder(BaseModel):
|
||||||
|
|
||||||
self.neck = neck
|
self.neck = neck
|
||||||
self.auxiliary_head = auxiliary_head
|
self.auxiliary_head = auxiliary_head
|
||||||
|
self.pretrained = pretrained
|
||||||
if self.neck is not None:
|
if self.neck is not None:
|
||||||
self.neck = builder.build_neck(self.neck)
|
self.neck = builder.build_neck(self.neck)
|
||||||
|
|
||||||
|
@ -55,12 +56,32 @@ class EncoderDecoder(BaseModel):
|
||||||
self.train_cfg = train_cfg
|
self.train_cfg = train_cfg
|
||||||
self.test_cfg = test_cfg
|
self.test_cfg = test_cfg
|
||||||
|
|
||||||
self.init_weights(pretrained=pretrained)
|
self.init_weights()
|
||||||
|
|
||||||
def init_weights(self, pretrained=None):
|
def init_weights(self):
|
||||||
if pretrained is not None:
|
logger = get_root_logger()
|
||||||
print_log('load model from: {}'.format(pretrained), logger='root')
|
if isinstance(self.pretrained, str):
|
||||||
self.backbone.init_weights(pretrained=pretrained)
|
load_checkpoint(
|
||||||
|
self.backbone, self.pretrained, strict=False, logger=logger)
|
||||||
|
elif self.pretrained:
|
||||||
|
if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
|
||||||
|
self.backbone.init_weights(pretrained=self.pretrained)
|
||||||
|
elif hasattr(self.backbone, 'default_pretrained_model_path'
|
||||||
|
) and self.backbone.default_pretrained_model_path:
|
||||||
|
print_log(
|
||||||
|
'load model from default path: {}'.format(
|
||||||
|
self.backbone.default_pretrained_model_path), logger)
|
||||||
|
load_checkpoint(
|
||||||
|
self.backbone,
|
||||||
|
self.backbone.default_pretrained_model_path,
|
||||||
|
strict=False,
|
||||||
|
logger=logger)
|
||||||
|
else:
|
||||||
|
print_log('load model from init weights')
|
||||||
|
self.backbone.init_weights()
|
||||||
|
else:
|
||||||
|
print_log('load model from init weights')
|
||||||
|
self.backbone.init_weights()
|
||||||
|
|
||||||
if hasattr(self.decode_head, 'init_weights'):
|
if hasattr(self.decode_head, 'init_weights'):
|
||||||
self.decode_head.init_weights()
|
self.decode_head.init_weights()
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from .fcn_head import FCNHead
|
from .fcn_head import FCNHead
|
||||||
|
from .uper_head import UPerHead
|
||||||
|
|
||||||
__all__ = ['FCNHead']
|
__all__ = ['FCNHead', 'UPerHead']
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from mmcv.cnn import ConvModule
|
||||||
|
|
||||||
|
from easycv.models.builder import HEADS
|
||||||
|
from easycv.models.utils.ops import resize_tensor
|
||||||
|
from .base import BaseDecodeHead
|
||||||
|
|
||||||
|
|
||||||
|
# Modified from https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/uper_head.py
|
||||||
|
@HEADS.register_module()
|
||||||
|
class UPerHead(BaseDecodeHead):
|
||||||
|
"""Unified Perceptual Parsing for Scene Understanding.
|
||||||
|
|
||||||
|
This head is the implementation of `UPerNet
|
||||||
|
<https://arxiv.org/abs/1807.10221>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||||
|
Module applied on the last feature. Default: (1, 2, 3, 6).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
||||||
|
super(UPerHead, self).__init__(
|
||||||
|
input_transform='multiple_select', **kwargs)
|
||||||
|
# PSP Module
|
||||||
|
self.psp_modules = PPM(
|
||||||
|
pool_scales,
|
||||||
|
self.in_channels[-1],
|
||||||
|
self.channels,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
self.bottleneck = ConvModule(
|
||||||
|
self.in_channels[-1] + len(pool_scales) * self.channels,
|
||||||
|
self.channels,
|
||||||
|
3,
|
||||||
|
padding=1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
# FPN Module
|
||||||
|
self.lateral_convs = nn.ModuleList()
|
||||||
|
self.fpn_convs = nn.ModuleList()
|
||||||
|
for in_channels in self.in_channels[:-1]: # skip the top layer
|
||||||
|
l_conv = ConvModule(
|
||||||
|
in_channels,
|
||||||
|
self.channels,
|
||||||
|
1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
inplace=False)
|
||||||
|
fpn_conv = ConvModule(
|
||||||
|
self.channels,
|
||||||
|
self.channels,
|
||||||
|
3,
|
||||||
|
padding=1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
inplace=False)
|
||||||
|
self.lateral_convs.append(l_conv)
|
||||||
|
self.fpn_convs.append(fpn_conv)
|
||||||
|
|
||||||
|
self.fpn_bottleneck = ConvModule(
|
||||||
|
len(self.in_channels) * self.channels,
|
||||||
|
self.channels,
|
||||||
|
3,
|
||||||
|
padding=1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
|
||||||
|
def psp_forward(self, inputs):
|
||||||
|
"""Forward function of PSP module."""
|
||||||
|
x = inputs[-1]
|
||||||
|
psp_outs = [x]
|
||||||
|
psp_outs.extend(self.psp_modules(x))
|
||||||
|
psp_outs = torch.cat(psp_outs, dim=1)
|
||||||
|
output = self.bottleneck(psp_outs)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _forward_feature(self, inputs):
|
||||||
|
"""Forward function for feature maps before classifying each pixel with
|
||||||
|
``self.cls_seg`` fc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (list[Tensor]): List of multi-level img features.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||||
|
H, W) which is feature map for last layer of decoder head.
|
||||||
|
"""
|
||||||
|
inputs = self._transform_inputs(inputs)
|
||||||
|
|
||||||
|
# build laterals
|
||||||
|
laterals = [
|
||||||
|
lateral_conv(inputs[i])
|
||||||
|
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||||
|
]
|
||||||
|
|
||||||
|
laterals.append(self.psp_forward(inputs))
|
||||||
|
|
||||||
|
# build top-down path
|
||||||
|
used_backbone_levels = len(laterals)
|
||||||
|
for i in range(used_backbone_levels - 1, 0, -1):
|
||||||
|
prev_shape = laterals[i - 1].shape[2:]
|
||||||
|
laterals[i - 1] = laterals[i - 1] + resize_tensor(
|
||||||
|
laterals[i],
|
||||||
|
size=prev_shape,
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
|
||||||
|
# build outputs
|
||||||
|
fpn_outs = [
|
||||||
|
self.fpn_convs[i](laterals[i])
|
||||||
|
for i in range(used_backbone_levels - 1)
|
||||||
|
]
|
||||||
|
# append psp feature
|
||||||
|
fpn_outs.append(laterals[-1])
|
||||||
|
|
||||||
|
for i in range(used_backbone_levels - 1, 0, -1):
|
||||||
|
fpn_outs[i] = resize_tensor(
|
||||||
|
fpn_outs[i],
|
||||||
|
size=fpn_outs[0].shape[2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
fpn_outs = torch.cat(fpn_outs, dim=1)
|
||||||
|
feats = self.fpn_bottleneck(fpn_outs)
|
||||||
|
return feats
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
"""Forward function."""
|
||||||
|
output = self._forward_feature(inputs)
|
||||||
|
output = self.cls_seg(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PPM(nn.ModuleList):
|
||||||
|
"""Pooling Pyramid Module used in PSPNet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||||
|
Module.
|
||||||
|
in_channels (int): Input channels.
|
||||||
|
channels (int): Channels after modules, before conv_seg.
|
||||||
|
conv_cfg (dict|None): Config of conv layers.
|
||||||
|
norm_cfg (dict|None): Config of norm layers.
|
||||||
|
act_cfg (dict): Config of activation layers.
|
||||||
|
align_corners (bool): align_corners argument of F.interpolate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
|
||||||
|
act_cfg, align_corners, **kwargs):
|
||||||
|
super(PPM, self).__init__()
|
||||||
|
self.pool_scales = pool_scales
|
||||||
|
self.align_corners = align_corners
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.channels = channels
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
for pool_scale in pool_scales:
|
||||||
|
self.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.AdaptiveAvgPool2d(pool_scale),
|
||||||
|
ConvModule(
|
||||||
|
self.in_channels,
|
||||||
|
self.channels,
|
||||||
|
1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
**kwargs)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward function."""
|
||||||
|
ppm_outs = []
|
||||||
|
for ppm in self:
|
||||||
|
ppm_out = ppm(x)
|
||||||
|
upsampled_ppm_out = resize_tensor(
|
||||||
|
ppm_out,
|
||||||
|
size=x.size()[2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
ppm_outs.append(upsampled_ppm_out)
|
||||||
|
return ppm_outs
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from easycv.models.builder import build_head
|
||||||
|
|
||||||
|
|
||||||
|
class UperHeadTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||||
|
|
||||||
|
def test_forward_train(self):
|
||||||
|
norm_cfg = dict(type='BN', requires_grad=True)
|
||||||
|
uper_head_config = dict(
|
||||||
|
type='UPerHead',
|
||||||
|
in_channels=[256, 512, 1024, 2048],
|
||||||
|
in_index=[0, 1, 2, 3],
|
||||||
|
pool_scales=(1, 2, 3, 6),
|
||||||
|
channels=512,
|
||||||
|
dropout_ratio=0.1,
|
||||||
|
num_classes=19,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
|
||||||
|
head = build_head(uper_head_config)
|
||||||
|
head = head.to('cuda')
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
dummy_inputs = [
|
||||||
|
torch.rand(batch_size, 256, 128, 128).to('cuda'),
|
||||||
|
torch.rand(batch_size, 512, 64, 64).to('cuda'),
|
||||||
|
torch.rand(batch_size, 1024, 32, 32).to('cuda'),
|
||||||
|
torch.rand(batch_size, 2048, 16, 16).to('cuda'),
|
||||||
|
]
|
||||||
|
|
||||||
|
gt_semantic_seg = torch.randint(
|
||||||
|
low=0, high=19, size=(batch_size, 1, 512, 512)).to('cuda')
|
||||||
|
train_output = head.forward_train(
|
||||||
|
dummy_inputs,
|
||||||
|
img_metas=None,
|
||||||
|
gt_semantic_seg=gt_semantic_seg,
|
||||||
|
train_cfg=None)
|
||||||
|
self.assertIn('loss_ce', train_output)
|
||||||
|
self.assertIn('acc_seg', train_output)
|
||||||
|
self.assertEqual(train_output['acc_seg'].shape, torch.Size([1]))
|
||||||
|
|
||||||
|
def test_forward_test(self):
|
||||||
|
norm_cfg = dict(type='BN', requires_grad=True)
|
||||||
|
uper_head_config = dict(
|
||||||
|
type='UPerHead',
|
||||||
|
in_channels=[256, 512, 1024, 2048],
|
||||||
|
in_index=[0, 1, 2, 3],
|
||||||
|
pool_scales=(1, 2, 3, 6),
|
||||||
|
channels=512,
|
||||||
|
dropout_ratio=0.1,
|
||||||
|
num_classes=19,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
|
||||||
|
|
||||||
|
head = build_head(uper_head_config)
|
||||||
|
head = head.to('cuda')
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
dummy_inputs = [
|
||||||
|
torch.rand(batch_size, 256, 128, 128).to('cuda'),
|
||||||
|
torch.rand(batch_size, 512, 64, 64).to('cuda'),
|
||||||
|
torch.rand(batch_size, 1024, 32, 32).to('cuda'),
|
||||||
|
torch.rand(batch_size, 2048, 16, 16).to('cuda'),
|
||||||
|
]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
test_output = head.forward_test(
|
||||||
|
dummy_inputs, img_metas=None, test_cfg=None)
|
||||||
|
self.assertEqual(test_output.shape, torch.Size([2, 19, 128, 128]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
|
@ -26,7 +26,6 @@ from easycv.datasets.utils import is_dali_dataset_type
|
||||||
from easycv.file import io
|
from easycv.file import io
|
||||||
from easycv.models import build_model
|
from easycv.models import build_model
|
||||||
from easycv.utils.collect_env import collect_env
|
from easycv.utils.collect_env import collect_env
|
||||||
from easycv.utils.flops_counter import get_model_info
|
|
||||||
from easycv.utils.logger import get_root_logger
|
from easycv.utils.logger import get_root_logger
|
||||||
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
|
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
|
||||||
from easycv.utils.config_tools import traverse_replace
|
from easycv.utils.config_tools import traverse_replace
|
||||||
|
@ -214,6 +213,7 @@ def main():
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
if 'stage' in cfg.model and cfg.model['stage'] == 'EDGE':
|
if 'stage' in cfg.model and cfg.model['stage'] == 'EDGE':
|
||||||
|
from easycv.utils.flops_counter import get_model_info
|
||||||
get_model_info(model, cfg.img_scale, cfg.model, logger)
|
get_model_info(model, cfg.img_scale, cfg.model, logger)
|
||||||
|
|
||||||
assert len(cfg.workflow) == 1, 'Validation is called by hook.'
|
assert len(cfg.workflow) == 1, 'Validation is called by hook.'
|
||||||
|
|
Loading…
Reference in New Issue