[Fix] Fix Coco-stuff164k on BiSeNetV1 config error (#1893)

pull/1913/head^2
MengzhangLI 2022-08-09 22:34:11 +08:00 committed by GitHub
parent 4eaa8e6919
commit 5b2f19aae4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 6 deletions

View File

@ -3,6 +3,7 @@ _base_ = [
'../_base_/datasets/coco-stuff164k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
backbone=dict(
context_channels=(512, 1024, 2048),
@ -11,8 +12,30 @@ model = dict(
backbone_cfg=dict(type='ResNet', depth=101)),
decode_head=dict(in_channels=1024, channels=1024, num_classes=171),
auxiliary_head=[
dict(in_channels=512, channels=256, num_classes=171),
dict(in_channels=512, channels=256, num_classes=171),
dict(
type='FCNHead',
in_channels=512,
channels=256,
num_convs=1,
num_classes=171,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=512,
channels=256,
num_convs=1,
num_classes=171,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
])
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.005)

View File

@ -3,11 +3,34 @@ _base_ = [
'../_base_/datasets/coco-stuff164k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
decode_head=dict(num_classes=171),
auxiliary_head=[
dict(num_classes=171),
dict(num_classes=171),
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=171,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=171,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
])
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.005)

View File

@ -3,6 +3,7 @@ _base_ = [
'../_base_/datasets/coco-stuff164k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_160k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
backbone=dict(
context_channels=(512, 1024, 2048),
@ -11,8 +12,30 @@ model = dict(
backbone_cfg=dict(type='ResNet', depth=50)),
decode_head=dict(in_channels=1024, channels=1024, num_classes=171),
auxiliary_head=[
dict(in_channels=512, channels=256, num_classes=171),
dict(in_channels=512, channels=256, num_classes=171),
dict(
type='FCNHead',
in_channels=512,
channels=256,
num_convs=1,
num_classes=171,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=512,
channels=256,
num_convs=1,
num_classes=171,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
])
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.005)