# Mask2Former with a ResNet-50 backbone on ADE20K, trained on 512x512 crops
# for 160k iterations. Runtime defaults and the ADE20K dataset settings are
# inherited from the base configs listed below and overridden further down.
_base_ = ['../_base_/default_runtime.py', '../_base_/datasets/ade20k.py']

# The pixel decoder, assigner, match costs, sampler and losses of the head are
# mmdet components (referenced with the ``mmdet.`` scope prefix), so mmdet's
# model registry must be imported. allow_failed_imports=False makes a missing
# mmdet installation an error instead of a silent skip.
custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)

crop_size = (512, 512)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    # ImageNet RGB mean/std on the 0-255 scale; applied after BGR->RGB.
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    # Padded label pixels get 255 (the ignore label used by mmseg).
    seg_pad_val=255,
    # Training batches are padded to the crop size ...
    size=crop_size,
    # ... while test images are only padded to a multiple of 32.
    test_cfg=dict(size_divisor=32))
num_classes = 150
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='ResNet',
        depth=50,
        deep_stem=False,
        num_stages=4,
        # Emit all four stages; the head consumes every one of them.
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        # BN affine parameters are not trained (requires_grad=False).
        norm_cfg=dict(type='SyncBN', requires_grad=False),
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    decode_head=dict(
        type='Mask2FormerHead',
        # Must match the channels/strides of the four ResNet-50 stage
        # outputs selected by backbone.out_indices.
        in_channels=[256, 512, 1024, 2048],
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_classes=num_classes,
        num_queries=100,
        # Keep in sync with pixel_decoder.num_outs and the encoder's
        # self_attn_cfg.num_levels below (all 3).
        num_transformer_feat_level=3,
        align_corners=False,
        pixel_decoder=dict(
            type='mmdet.MSDeformAttnPixelDecoder',
            num_outs=3,
            norm_cfg=dict(type='GN', num_groups=32),
            act_cfg=dict(type='ReLU'),
            encoder=dict(  # DeformableDetrTransformerEncoder
                num_layers=6,
                layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
                    self_attn_cfg=dict(  # MultiScaleDeformableAttention
                        embed_dims=256,
                        num_heads=8,
                        num_levels=3,
                        num_points=4,
                        im2col_step=64,
                        dropout=0.0,
                        batch_first=True,
                        norm_cfg=None,
                        init_cfg=None),
                    ffn_cfg=dict(
                        embed_dims=256,
                        feedforward_channels=1024,
                        num_fcs=2,
                        ffn_drop=0.0,
                        act_cfg=dict(type='ReLU', inplace=True))),
                init_cfg=None),
            positional_encoding=dict(  # SinePositionalEncoding
                num_feats=128, normalize=True),
            init_cfg=None),
        enforce_decoder_input_project=False,
        positional_encoding=dict(  # SinePositionalEncoding
            num_feats=128, normalize=True),
        transformer_decoder=dict(  # Mask2FormerTransformerDecoder
            return_intermediate=True,
            num_layers=9,
            layer_cfg=dict(  # Mask2FormerTransformerDecoderLayer
                self_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.0,
                    proj_drop=0.0,
                    dropout_layer=None,
                    batch_first=True),
                cross_attn_cfg=dict(  # MultiheadAttention
                    embed_dims=256,
                    num_heads=8,
                    attn_drop=0.0,
                    proj_drop=0.0,
                    dropout_layer=None,
                    batch_first=True),
                ffn_cfg=dict(
                    embed_dims=256,
                    feedforward_channels=2048,
                    num_fcs=2,
                    act_cfg=dict(type='ReLU', inplace=True),
                    ffn_drop=0.0,
                    dropout_layer=None,
                    add_identity=True)),
            init_cfg=None),
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            # One weight per real class plus a final entry for the extra
            # "no object" slot, which is down-weighted to 0.1.
            class_weight=[1.0] * num_classes + [0.1]),
        loss_mask=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='mmdet.DiceLoss',
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0),
        train_cfg=dict(
            # 12544 = 112 * 112 points sampled per mask for the mask losses
            # (with oversampling / importance sampling ratios below).
            num_points=12544,
            oversample_ratio=3.0,
            importance_sample_ratio=0.75,
            assigner=dict(
                type='mmdet.HungarianAssigner',
                # Matching cost weights mirror the loss weights above
                # (cls 2.0, mask 5.0, dice 5.0).
                match_costs=[
                    dict(type='mmdet.ClassificationCost', weight=2.0),
                    dict(
                        type='mmdet.CrossEntropyLossCost',
                        weight=5.0,
                        use_sigmoid=True),
                    dict(
                        type='mmdet.DiceCost',
                        weight=5.0,
                        pred_act=True,
                        eps=1.0)
                ]),
            sampler=dict(type='mmdet.MaskPseudoSampler'))),
    train_cfg=dict(),
    # Inference runs on the whole (padded) image, not sliding windows.
    test_cfg=dict(mode='whole'))

# dataset config
# Overrides the training pipeline and batch size of the inherited ADE20K
# dataset config; other dataloader settings come from the base config.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # Maps label 0 to the ignore index and shifts the remaining labels down
    # by one, so the ADE20K classes become 0..num_classes-1.
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(
        type='RandomChoiceResize',
        # Short edge picked from 0.5x..2.0x of 512 (256..1024) in 0.1 steps.
        scales=[int(512 * x * 0.1) for x in range(5, 21)],
        resize_type='ResizeShortestEdge',
        max_size=2048),
    # Re-draws the crop when a single category covers more than 75% of it.
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs')
]
# 2 samples per GPU (see base_batch_size at the end of this file).
train_dataloader = dict(batch_size=2, dataset=dict(pipeline=train_pipeline))

# optimizer
# Normal LR but no weight decay; applied to parameters whose names contain
# query_embed / query_feat / level_embed (presumably the head's learnable
# query and level embeddings).
embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
optimizer = dict(
    type='AdamW', lr=0.0001, weight_decay=0.05, eps=1e-8, betas=(0.9, 0.999))
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=optimizer,
    # Clip the L2 norm of the gradients to 0.01.
    clip_grad=dict(max_norm=0.01, norm_type=2),
    paramwise_cfg=dict(
        custom_keys={
            # Backbone parameters train at 0.1x the base learning rate.
            'backbone': dict(lr_mult=0.1, decay_mult=1.0),
            'query_embed': embed_multi,
            'query_feat': embed_multi,
            'level_embed': embed_multi,
        },
        # No weight decay on normalization-layer parameters.
        norm_decay_mult=0.0))

# learning policy
# Polynomial decay (power 0.9) from the base LR down to 0, stepped per
# iteration. Keep `end` equal to train_cfg['max_iters'].
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=160000,
        by_epoch=False)
]

# training schedule for 160k iterations
# Iteration-based loop; max_iters must match the PolyLR scheduler's `end`.
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=160000, val_interval=5000)
val_cfg = dict(type='ValLoop')
# Runner test loop; distinct from model['test_cfg'] (inference mode).
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    # Save every 5000 iterations (same cadence as validation) and keep the
    # best checkpoint by 'mIoU'; assumes the val evaluator from the base
    # dataset config reports that metric.
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=5000,
        save_best='mIoU'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

# Default setting for scaling the LR automatically:
#   - `enable`: whether automatic LR scaling is turned on by default.
#   - `base_batch_size` = (8 GPUs) x (2 samples per GPU); the reference total
#     batch size the learning rate is scaled against when scaling is enabled.
auto_scale_lr = dict(enable=False, base_batch_size=16)