45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
# URL of the pretrained weights; handed to the segmentor below through its
# `pretrained` key. The file name suggests a ViT-Base / patch-16 / 384px
# checkpoint -- TODO confirm against the model zoo entry.
checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_base_p16_384_20220308-96dfe169.pth'  # noqa

# model settings

# Normalization layer config used by the backbone (see `norm_cfg` below).
# 'LN' is presumably the registry name for LayerNorm -- verify in the registry.
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)

# Input pre-processing settings shared with the model via `data_preprocessor`.
data_preprocessor = dict(
    type='SegDataPreProcessor',
    # Same mean and std (127.5) for all three channels. Assuming the
    # preprocessor computes (x - mean) / std, 8-bit inputs land in [-1, 1].
    mean=[127.5, 127.5, 127.5],
    std=[127.5, 127.5, 127.5],
    bgr_to_rgb=True,
    # Padding values for the image and for the segmentation map respectively.
    pad_val=0,
    seg_pad_val=255)

model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=checkpoint,
    backbone=dict(
        type='VisionTransformer',
        img_size=(512, 512),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        drop_path_rate=0.1,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        final_norm=True,
        norm_cfg=backbone_norm_cfg,
        with_cls_token=True,
        interpolate_mode='bicubic',
    ),
    decode_head=dict(
        type='SegmenterMaskTransformerHead',
        # in_channels / channels / embed_dims are all 768, the same value as
        # the backbone's embed_dims above; keep them in sync when editing.
        in_channels=768,
        channels=768,
        # NOTE(review): 150 classes presumably targets ADE20K -- confirm
        # against the dataset config this model is paired with.
        num_classes=150,
        num_layers=2,
        num_heads=12,
        embed_dims=768,
        dropout_ratio=0.0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    ),
    # 'slide' mode with 512x512 crops and a 480 stride: assuming standard
    # sliding-window inference, neighbouring windows overlap by 32 pixels.
    test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(480, 480)),
)