[Refactor] KNet head
parent 8c68540271
commit ffa0616a68
@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3
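
The hunk above replaces the one-line `data_preprocessor = dict(size=crop_size)` stub with a fully specified `SegDataPreProcessor`. As a minimal sketch of how such a block is consumed, assuming the MMSegmentation 1.x registry entry points (`register_all_modules`, `MODELS.build`):

# Hedged sketch, not part of this commit: build the preprocessor
# from the config dict via the registry (assumes mmsegmentation 1.x).
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

register_all_modules()  # populate MODELS with SegDataPreProcessor etc.

preprocessor = MODELS.build(
    dict(
        type='SegDataPreProcessor',
        mean=[123.675, 116.28, 103.53],  # ImageNet RGB mean
        std=[58.395, 57.12, 57.375],  # ImageNet RGB std
        bgr_to_rgb=True,  # loaders emit BGR, the model expects RGB
        pad_val=0,  # image padding value
        size=(512, 512),  # mirrors crop_size above
        seg_pad_val=255))  # padded label pixels map to the ignore index
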
@@ -80,11 +87,10 @@ model = dict(
     test_cfg=dict(mode='whole'))
 
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 # learning policy
 param_scheduler = [
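
The optimizer hunk folds the standalone `optimizer` dict into the new MMEngine `optim_wrapper` config; `_delete_=True` is a config-merge directive that discards the inherited value instead of merging into it. A hedged sketch of how MMEngine turns this dict into a working wrapper, assuming `build_optim_wrapper` as the entry point (the `_delete_` key is consumed during config merging, so it is omitted here):

import torch.nn as nn
from mmengine.optim import build_optim_wrapper

model = nn.Conv2d(3, 8, 3)  # stand-in module for the sketch
optim_wrapper = build_optim_wrapper(
    model,
    dict(
        type='OptimWrapper',
        optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
        clip_grad=dict(max_norm=1, norm_type=2)))  # clip grads to L2 norm 1

# optim_wrapper.update_params(loss) then replaces the usual
# loss.backward(); optimizer.step(); optimizer.zero_grad() sequence.
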
@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3

@@ -80,11 +87,10 @@ model = dict(
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 
 # learning policy
@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3

@@ -79,11 +86,10 @@ model = dict(
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 # learning policy
 param_scheduler = [
@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3

@@ -80,11 +87,10 @@ model = dict(
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 # learning policy
 param_scheduler = [
@@ -3,7 +3,14 @@ _base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
 checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth'  # noqa
 # model settings
 crop_size = (640, 640)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 model = dict(
     data_preprocessor=data_preprocessor,
     pretrained=checkpoint_file,
@@ -30,18 +30,12 @@ model = dict(
         kernel_generate_head=dict(in_channels=[96, 192, 384, 768])),
     auxiliary_head=dict(in_channels=384))
 
-# modify learning rate following the official implementation of Swin Transformer # noqa
-optimizer = dict(
-    _delete_=True,
-    type='AdamW',
-    lr=0.00006,
-    betas=(0.9, 0.999),
-    weight_decay=0.0005)
-
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    # modify learning rate following the official implementation of Swin Transformer # noqa
+    optimizer=dict(
+        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.0005),
     paramwise_cfg=dict(
         custom_keys={
             'absolute_pos_embed': dict(decay_mult=0.),
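
The Swin variant moves its AdamW settings into the same wrapper and keeps a `paramwise_cfg` whose `custom_keys` zero out weight decay for `absolute_pos_embed`. The helper below is a hypothetical illustration of that matching rule (substring match on parameter names), not MMEngine code:

base_weight_decay = 0.0005
custom_keys = {'absolute_pos_embed': dict(decay_mult=0.)}

def effective_weight_decay(param_name: str) -> float:
    # a parameter whose name contains a custom key gets wd * decay_mult
    for key, cfg in custom_keys.items():
        if key in param_name:
            return base_weight_decay * cfg['decay_mult']
    return base_weight_decay

assert effective_weight_decay('backbone.absolute_pos_embed') == 0.0
assert effective_weight_decay('decode_head.conv_seg.weight') == 0.0005
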
@@ -260,7 +260,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         ]
         return torch.stack(gt_semantic_segs, dim=0)
 
-    def loss_by_feat(self, seg_logit: Tensor, batch_data_samples: SampleList,
+    def loss_by_feat(self, seg_logits: Tensor, batch_data_samples: SampleList,
                      **kwargs) -> dict:
         """Compute segmentation loss.
 
@@ -276,13 +276,13 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
 
         seg_label = self._stack_batch_gt(batch_data_samples)
         loss = dict()
-        seg_logit = resize(
-            input=seg_logit,
+        seg_logits = resize(
+            input=seg_logits,
             size=seg_label.shape[2:],
             mode='bilinear',
             align_corners=self.align_corners)
         if self.sampler is not None:
-            seg_weight = self.sampler.sample(seg_logit, seg_label)
+            seg_weight = self.sampler.sample(seg_logits, seg_label)
         else:
             seg_weight = None
         seg_label = seg_label.squeeze(1)
@@ -294,24 +294,24 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         for loss_decode in losses_decode:
             if loss_decode.loss_name not in loss:
                 loss[loss_decode.loss_name] = loss_decode(
-                    seg_logit,
+                    seg_logits,
                     seg_label,
                     weight=seg_weight,
                     ignore_index=self.ignore_index)
             else:
                 loss[loss_decode.loss_name] += loss_decode(
-                    seg_logit,
+                    seg_logits,
                     seg_label,
                     weight=seg_weight,
                     ignore_index=self.ignore_index)
 
         loss['acc_seg'] = accuracy(
-            seg_logit, seg_label, ignore_index=self.ignore_index)
+            seg_logits, seg_label, ignore_index=self.ignore_index)
         return loss
 
     def predict_by_feat(self, seg_logits: Tensor, batch_img_metas: List[dict],
                         **kwargs) -> List[Tensor]:
-        """Trnasform a batch of output seg_logits to the input shape.
+        """Transform a batch of output seg_logits to the input shape.
 
         Args:
             seg_logits (Tensor): The output from decode head forward function.
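
The three `BaseDecodeHead` hunks rename `seg_logit` to `seg_logits` throughout `loss_by_feat` (the tensor is batched, so the plural is the accurate name) and fix the "Trnasform" docstring typo. The resize-then-loss pattern the method implements, as a self-contained sketch with plain `F.cross_entropy` standing in for the configured `loss_decode` modules (shapes are assumptions):

import torch
import torch.nn.functional as F

seg_logits = torch.randn(2, 19, 64, 64)  # (N, C, h, w) head output
seg_label = torch.randint(0, 19, (2, 1, 512, 512))  # (N, 1, H, W) labels

# upsample logits to the label resolution, as resize() does above
seg_logits = F.interpolate(
    seg_logits, size=seg_label.shape[2:], mode='bilinear',
    align_corners=False)
seg_label = seg_label.squeeze(1)  # drop the channel dimension
loss = F.cross_entropy(seg_logits, seg_label, ignore_index=255)
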
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,7 +8,9 @@ from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
                                          MultiheadAttention,
                                          build_transformer_layer)
+from torch import Tensor
 
+from mmseg.core.utils import SampleList
 from mmseg.models.decode_heads.decode_head import BaseDecodeHead
 from mmseg.registry import MODELS
 from mmseg.utils import get_root_logger
@@ -443,10 +447,12 @@ class IterativeDecodeHead(BaseDecodeHead):
         # only return the prediction of the last stage during testing
         return stage_segs[-1]
 
-    def losses(self, seg_logit, seg_label):
+    def loss_by_feat(self, seg_logits: List[Tensor],
+                     batch_data_samples: SampleList, **kwargs) -> dict:
         losses = dict()
-        for i, logit in enumerate(seg_logit):
-            loss = self.kernel_generate_head.losses(logit, seg_label)
+        for i, logit in enumerate(seg_logits):
+            loss = self.kernel_generate_head.loss_by_feat(
+                logit, batch_data_samples)
             for k, v in loss.items():
                 losses[f'{k}.s{i}'] = v
 
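
`IterativeDecodeHead.losses` becomes `loss_by_feat`, matching the `BaseDecodeHead` interface above: it now takes the stage-wise logits plus data samples and delegates to the kernel-generate head's `loss_by_feat`, suffixing each stage's keys with `.s{i}`. A toy illustration of the resulting key layout (loss values invented):

stage_losses = [{'loss_ce': 0.9}, {'loss_ce': 0.7}, {'loss_ce': 0.5}]

losses = {}
for i, stage_loss in enumerate(stage_losses):
    for k, v in stage_loss.items():
        losses[f'{k}.s{i}'] = v  # keep each stage distinguishable in logs

assert losses == {'loss_ce.s0': 0.9, 'loss_ce.s1': 0.7, 'loss_ce.s2': 0.5}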