85 lines
2.4 KiB
Python
85 lines
2.4 KiB
Python
_base_ = [
|
|
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py',
|
|
'../_base_/datasets/ade20k.py'
|
|
]
|
|
# model settings
|
|
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth' # noqa
|
|
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
|
|
crop_size = (512, 512)
|
|
data_preprocessor = dict(
|
|
type='SegDataPreProcessor',
|
|
mean=[123.675, 116.28, 103.53],
|
|
std=[58.395, 57.12, 57.375],
|
|
bgr_to_rgb=True,
|
|
pad_val=0,
|
|
seg_pad_val=255,
|
|
size=(512, 512),
|
|
test_cfg=dict(size_divisor=32))
|
|
model = dict(
|
|
type='EncoderDecoder',
|
|
data_preprocessor=data_preprocessor,
|
|
pretrained=None,
|
|
backbone=dict(
|
|
type='MSCAN',
|
|
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
|
|
embed_dims=[32, 64, 160, 256],
|
|
mlp_ratios=[8, 8, 4, 4],
|
|
drop_rate=0.0,
|
|
drop_path_rate=0.1,
|
|
depths=[3, 3, 5, 2],
|
|
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
|
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
|
|
act_cfg=dict(type='GELU'),
|
|
norm_cfg=dict(type='BN', requires_grad=True)),
|
|
decode_head=dict(
|
|
type='LightHamHead',
|
|
in_channels=[64, 160, 256],
|
|
in_index=[1, 2, 3],
|
|
channels=256,
|
|
ham_channels=256,
|
|
dropout_ratio=0.1,
|
|
num_classes=150,
|
|
norm_cfg=ham_norm_cfg,
|
|
align_corners=False,
|
|
loss_decode=dict(
|
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
|
ham_kwargs=dict(
|
|
MD_S=1,
|
|
MD_R=16,
|
|
train_steps=6,
|
|
eval_steps=7,
|
|
inv_t=100,
|
|
rand_init=True)),
|
|
# model training and testing settings
|
|
train_cfg=dict(),
|
|
test_cfg=dict(mode='whole'))
|
|
|
|
# dataset settings
|
|
train_dataloader = dict(batch_size=16)
|
|
|
|
# optimizer
|
|
optim_wrapper = dict(
|
|
_delete_=True,
|
|
type='OptimWrapper',
|
|
optimizer=dict(
|
|
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
|
|
paramwise_cfg=dict(
|
|
custom_keys={
|
|
'pos_block': dict(decay_mult=0.),
|
|
'norm': dict(decay_mult=0.),
|
|
'head': dict(lr_mult=10.)
|
|
}))
|
|
|
|
param_scheduler = [
|
|
dict(
|
|
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
|
|
dict(
|
|
type='PolyLR',
|
|
power=1.0,
|
|
begin=1500,
|
|
end=160000,
|
|
eta_min=0.0,
|
|
by_epoch=False,
|
|
)
|
|
]
|