# Mirror of https://github.com/open-mmlab/mmsegmentation.git
# Synced 2025-06-03 22:03:48 +08:00
# 36 lines, 998 B, Python
# model settings
#
# Declarative configuration only: every value below is plain data (dicts,
# tuples, numbers, strings). The `type` strings are presumably resolved to
# classes by the consuming framework's registry -- confirm against the loader.

# Normalization settings shared with the backbone through `norm_cfg` below.
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)

model = dict(
    type='EncoderDecoder',
    # Relative checkpoint path; how it is resolved depends on the loader
    # (presumably relative to the working directory -- TODO confirm).
    pretrained='pretrain/vit_base_p16_384.pth',
    backbone=dict(
        type='VisionTransformer',
        img_size=(512, 512),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        drop_path_rate=0.1,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        final_norm=True,
        # Same dict object as the module-level `backbone_norm_cfg`.
        norm_cfg=backbone_norm_cfg,
        with_cls_token=True,
        interpolate_mode='bicubic',
    ),
    decode_head=dict(
        type='SegmenterMaskTransformerHead',
        # 768 matches the backbone's `embed_dims`; keep the two in sync.
        in_channels=768,
        channels=768,
        num_classes=150,
        num_layers=2,
        num_heads=12,
        embed_dims=768,
        dropout_ratio=0.0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    ),
    # crop_size equals the backbone's img_size; the stride (480) is smaller
    # than the crop (512), so adjacent windows presumably overlap in 'slide'
    # mode -- confirm against the segmentor's inference code.
    test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(480, 480)),
)