[Refactor] KNet head

pull/1801/head
zhengmiao 2022-06-20 04:04:48 +00:00
parent 8c68540271
commit ffa0616a68
8 changed files with 64 additions and 33 deletions

View File

@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3
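
The one-line data_preprocessor override relied on defaults picked up elsewhere; spelling the dict out makes the normalization and padding behavior explicit in each config. Roughly, these settings amount to the following at runtime (a minimal sketch for one float CHW image, not the actual SegDataPreProcessor implementation):

    import torch
    import torch.nn.functional as F

    def preprocess(img_bgr: torch.Tensor, size=(512, 512)) -> torch.Tensor:
        """Sketch of the image path implied by the config above."""
        mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
        std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
        img = img_bgr[[2, 1, 0], ...]      # bgr_to_rgb=True
        img = (img - mean) / std           # per-channel normalization
        pad_h = max(size[0] - img.shape[1], 0)
        pad_w = max(size[1] - img.shape[2], 0)
        # images are padded with pad_val=0; seg maps use seg_pad_val=255,
        # which matches the ignore index so padded pixels carry no loss
        return F.pad(img, (0, pad_w, 0, pad_h), value=0)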
@@ -80,11 +87,10 @@ model = dict(
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 # learning policy
 param_scheduler = [
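
The separate module-level optimizer dict is gone; the optimizer settings now live inside optim_wrapper, and _delete_=True keeps MMEngine's config merging from mixing in whatever optim_wrapper the base config defines. In plain PyTorch terms the wrapper above is roughly equivalent to (a sketch; the real OptimWrapper also handles AMP and gradient accumulation):

    import torch

    model = torch.nn.Linear(8, 2)  # stand-in model
    optim = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0005)

    loss = model(torch.randn(4, 8)).sum()
    optim.zero_grad()
    loss.backward()
    # clip_grad=dict(max_norm=1, norm_type=2) maps to:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1, norm_type=2)
    optim.step()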

View File

@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3
@@ -80,11 +87,10 @@ model = dict(
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 # learning policy
 param_scheduler = [

View File

@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3
@@ -79,11 +86,10 @@ model = dict(
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 # learning policy
 param_scheduler = [

View File

@@ -3,7 +3,14 @@ _base_ = [
     '../_base_/schedules/schedule_80k.py'
 ]
 crop_size = (512, 512)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True)
 num_stages = 3
@@ -80,11 +87,10 @@ model = dict(
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
 # optimizer
-optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
     clip_grad=dict(max_norm=1, norm_type=2))
 # learning policy
 param_scheduler = [

View File

@@ -3,7 +3,14 @@ _base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
 checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth'  # noqa
 # model settings
 crop_size = (640, 640)
-data_preprocessor = dict(size=crop_size)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    size=crop_size,
+    seg_pad_val=255)
 model = dict(
     data_preprocessor=data_preprocessor,
     pretrained=checkpoint_file,

View File

@@ -30,18 +30,12 @@ model = dict(
         kernel_generate_head=dict(in_channels=[96, 192, 384, 768])),
     auxiliary_head=dict(in_channels=384))
-# modify learning rate following the official implementation of Swin Transformer # noqa
-optimizer = dict(
-    _delete_=True,
-    type='AdamW',
-    lr=0.00006,
-    betas=(0.9, 0.999),
-    weight_decay=0.0005)
 optim_wrapper = dict(
     _delete_=True,
     type='OptimWrapper',
-    optimizer=optimizer,
+    # modify learning rate following the official implementation of Swin Transformer # noqa
+    optimizer=dict(
+        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.0005),
     paramwise_cfg=dict(
         custom_keys={
             'absolute_pos_embed': dict(decay_mult=0.),
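
Per the Swin Transformer recipe, decay_mult=0. turns weight decay off for the position embeddings and norm parameters. A sketch of the effect (hypothetical helper, not MMEngine's actual param-group builder):

    import torch

    def build_param_groups(model, base_wd=0.0005,
                           no_decay_keys=('absolute_pos_embed', 'norm')):
        """Split parameters into decay / no-decay groups by name."""
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if any(key in name for key in no_decay_keys):
                no_decay.append(param)   # decay_mult=0.
            else:
                decay.append(param)
        return [dict(params=decay, weight_decay=base_wd),
                dict(params=no_decay, weight_decay=0.)]

    # optim = torch.optim.AdamW(build_param_groups(model), lr=0.00006)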

View File

@@ -260,7 +260,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         ]
         return torch.stack(gt_semantic_segs, dim=0)
 
-    def loss_by_feat(self, seg_logit: Tensor, batch_data_samples: SampleList,
+    def loss_by_feat(self, seg_logits: Tensor, batch_data_samples: SampleList,
                      **kwargs) -> dict:
         """Compute segmentation loss.
@@ -276,13 +276,13 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         seg_label = self._stack_batch_gt(batch_data_samples)
         loss = dict()
-        seg_logit = resize(
-            input=seg_logit,
+        seg_logits = resize(
+            input=seg_logits,
             size=seg_label.shape[2:],
             mode='bilinear',
             align_corners=self.align_corners)
         if self.sampler is not None:
-            seg_weight = self.sampler.sample(seg_logit, seg_label)
+            seg_weight = self.sampler.sample(seg_logits, seg_label)
         else:
             seg_weight = None
         seg_label = seg_label.squeeze(1)
@@ -294,24 +294,24 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         for loss_decode in losses_decode:
             if loss_decode.loss_name not in loss:
                 loss[loss_decode.loss_name] = loss_decode(
-                    seg_logit,
+                    seg_logits,
                     seg_label,
                     weight=seg_weight,
                     ignore_index=self.ignore_index)
             else:
                 loss[loss_decode.loss_name] += loss_decode(
-                    seg_logit,
+                    seg_logits,
                     seg_label,
                     weight=seg_weight,
                     ignore_index=self.ignore_index)
         loss['acc_seg'] = accuracy(
-            seg_logit, seg_label, ignore_index=self.ignore_index)
+            seg_logits, seg_label, ignore_index=self.ignore_index)
         return loss
 
     def predict_by_feat(self, seg_logits: Tensor, batch_img_metas: List[dict],
                         **kwargs) -> List[Tensor]:
-        """Trnasform a batch of output seg_logits to the input shape.
+        """Transform a batch of output seg_logits to the input shape.
 
         Args:
             seg_logits (Tensor): The output from decode head forward function.
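
The rename from seg_logit to seg_logits aligns the loss path with the seg_logits name already used by predict_by_feat, and fixes the "Trnasform" docstring typo. Note the accumulation rule in the loop above: entries in loss_decode that share a loss_name are summed into a single logged value. A toy illustration of that rule (hypothetical numbers):

    losses_by_name = [('loss_ce', 1.0), ('loss_ce', 0.4), ('loss_dice', 3.0)]

    loss = {}
    for name, value in losses_by_name:
        if name not in loss:
            loss[name] = value
        else:
            loss[name] += value

    print(loss)  # {'loss_ce': 1.4, 'loss_dice': 3.0}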

View File

@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,7 +8,9 @@ from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
 from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
                                          MultiheadAttention,
                                          build_transformer_layer)
+from torch import Tensor
 
+from mmseg.core.utils import SampleList
 from mmseg.models.decode_heads.decode_head import BaseDecodeHead
 from mmseg.registry import MODELS
 from mmseg.utils import get_root_logger
@@ -443,10 +447,12 @@ class IterativeDecodeHead(BaseDecodeHead):
         # only return the prediction of the last stage during testing
         return stage_segs[-1]
 
-    def losses(self, seg_logit, seg_label):
+    def loss_by_feat(self, seg_logits: List[Tensor],
+                     batch_data_samples: SampleList, **kwargs) -> dict:
         losses = dict()
-        for i, logit in enumerate(seg_logit):
-            loss = self.kernel_generate_head.losses(logit, seg_label)
+        for i, logit in enumerate(seg_logits):
+            loss = self.kernel_generate_head.loss_by_feat(
+                logit, batch_data_samples)
             for k, v in loss.items():
                 losses[f'{k}.s{i}'] = v
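
KNet supervises every kernel-update stage, so the old losses() method becomes a loss_by_feat override: each stage's logits are delegated to kernel_generate_head.loss_by_feat and the resulting dict is re-keyed with a stage suffix. With num_stages = 3 the training log therefore shows entries like loss_ce.s0 through loss_ce.s2. A toy run of the re-keying (hypothetical values):

    stage_losses = [{'loss_ce': 1.2, 'acc_seg': 41.0},
                    {'loss_ce': 0.9, 'acc_seg': 55.0},
                    {'loss_ce': 0.7, 'acc_seg': 62.0}]

    losses = {}
    for i, loss in enumerate(stage_losses):
        for k, v in loss.items():
            losses[f'{k}.s{i}'] = v

    print(sorted(losses))  # ['acc_seg.s0', ..., 'loss_ce.s2']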