[Feature] Support K-Net (#1289)
* knet first commit
* fix import error in knet
* remove kernel update head from decoder head
* [Feature] Add kernel update for some decoder heads.
* directly use forward_feature && modify other 3 decoder heads
* remove kernel_update attr
* delete unnecessary variables in forward function
* delete kernel update function
* delete kernel_generate_head
* add unit test & comments in knet.py
* add copyright to fix lint error
* modify config names of knet
* rename swin-l 640
* upload models&logs and refactor knet_head.py
* modify docstrings and add some ut
* add url, modify docstring and add loss ut
* modify docstrings
parent
da6bb2c8c5
commit
054dc66145
@ -121,6 +121,7 @@ Supported methods:

- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)

Supported datasets:
@ -120,6 +120,7 @@ MMSegmentation is an open source semantic segmentation toolbox based on PyTorch. It is a part of the O

- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)

Supported datasets:
@ -0,0 +1,49 @@

# K-Net

[K-Net: Towards Unified Image Segmentation](https://arxiv.org/abs/2106.14855)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/ZwwWayne/K-Net/">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/decode_heads/knet_head.py#L392">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

Semantic, instance, and panoptic segmentations have been addressed using different and specialized frameworks despite their underlying connections. This paper presents a unified, simple, and effective framework for these essentially similar tasks. The framework, named K-Net, segments both instances and semantic categories consistently by a group of learnable kernels, where each kernel is responsible for generating a mask for either a potential instance or a stuff class. To remedy the difficulties of distinguishing various instances, we propose a kernel update strategy that enables each kernel dynamic and conditional on its meaningful group in the input image. K-Net can be trained in an end-to-end manner with bipartite matching, and its training and inference are naturally NMS-free and box-free. Without bells and whistles, K-Net surpasses all previous published state-of-the-art single-model results of panoptic segmentation on MS COCO test-dev split and semantic segmentation on ADE20K val split with 55.2% PQ and 54.3% mIoU, respectively. Its instance segmentation performance is also on par with Cascade Mask R-CNN on MS COCO with 60%-90% faster inference speeds. Code and models will be released at [this https URL](https://github.com/ZwwWayne/K-Net/).

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/157008300-9f40905c-b8e8-4a2a-9593-c1177fa35b2c.png" width="90%"/>
</div>

```bibtex
@inproceedings{zhang2021knet,
    title={{K-Net: Towards} Unified Image Segmentation},
    author={Wenwei Zhang and Jiangmiao Pang and Kai Chen and Chen Change Loy},
    year={2021},
    booktitle={NeurIPS},
}
```

## Results and models

### ADE20K

| Method           | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU  | mIoU(ms+flip) | config | download |
| ---------------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ------ | -------- |
| KNet + FCN       | R-50-D8  | 512x512   | 80000   | 7.01     | 19.24          | 43.60 | 45.12         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751.log.json) |
| KNet + PSPNet    | R-50-D8  | 512x512   | 80000   | 6.98     | 20.04          | 44.18 | 45.58         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634.log.json) |
| KNet + DeepLabV3 | R-50-D8  | 512x512   | 80000   | 7.42     | 12.10          | 45.06 | 46.11         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642.log.json) |
| KNet + UperNet   | R-50-D8  | 512x512   | 80000   | 7.34     | 17.11          | 43.45 | 44.07         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657.log.json) |
| KNet + UperNet   | Swin-T   | 512x512   | 80000   | 7.57     | 15.56          | 45.84 | 46.27         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059.log.json) |
| KNet + UperNet   | Swin-L   | 512x512   | 80000   | 13.5     | 8.29           | 52.05 | 53.24         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559.log.json) |
| KNet + UperNet   | Swin-L   | 640x640   | 80000   | 13.54    | 8.29           | 52.21 | 53.34         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747.log.json) |

Note:

- All experiments of K-Net are implemented with 8 V100 (32G) GPUs with 2 samples per GPU.
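As a quick consistency check, the `Inf time (fps)` column in the table above and the `inference time (ms/im)` values recorded in the accompanying knet.yml metafile are reciprocals of each other. A short sketch (the short model keys here are illustrative labels, not config names):

```python
# Sanity check: fps values from the README table convert to the ms/im values
# stored in configs/knet/knet.yml via ms_per_im = 1000 / fps.
fps = {
    'fcn': 19.24, 'pspnet': 20.04, 'deeplabv3': 12.10,
    'upernet_r50': 17.11, 'upernet_swin-t': 15.56, 'upernet_swin-l': 8.29,
}
ms_per_im = {name: round(1000 / v, 2) for name, v in fps.items()}
print(ms_per_im['fcn'])  # 51.98, matching the metafile entry
```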
@ -0,0 +1,169 @@
|
|||
Collections:
|
||||
- Name: KNet
|
||||
Metadata:
|
||||
Training Data:
|
||||
- ADE20K
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2106.14855
|
||||
Title: 'K-Net: Towards Unified Image Segmentation'
|
||||
README: configs/knet/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/decode_heads/knet_head.py#L392
|
||||
Version: v0.23.0
|
||||
Converted From:
|
||||
Code: https://github.com/ZwwWayne/K-Net/
|
||||
Models:
|
||||
- Name: knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k
|
||||
In Collection: KNet
|
||||
Metadata:
|
||||
backbone: R-50-D8
|
||||
crop size: (512,512)
|
||||
lr schd: 80000
|
||||
inference time (ms/im):
|
||||
- value: 51.98
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (512,512)
|
||||
Training Memory (GB): 7.01
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 43.6
|
||||
mIoU(ms+flip): 45.12
|
||||
Config: configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth
|
||||
- Name: knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k
|
||||
In Collection: KNet
|
||||
Metadata:
|
||||
backbone: R-50-D8
|
||||
crop size: (512,512)
|
||||
lr schd: 80000
|
||||
inference time (ms/im):
|
||||
- value: 49.9
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (512,512)
|
||||
Training Memory (GB): 6.98
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 44.18
|
||||
mIoU(ms+flip): 45.58
|
||||
Config: configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth
|
||||
- Name: knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k
|
||||
In Collection: KNet
|
||||
Metadata:
|
||||
backbone: R-50-D8
|
||||
crop size: (512,512)
|
||||
lr schd: 80000
|
||||
inference time (ms/im):
|
||||
- value: 82.64
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (512,512)
|
||||
Training Memory (GB): 7.42
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 45.06
|
||||
mIoU(ms+flip): 46.11
|
||||
Config: configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth
|
||||
- Name: knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k
|
||||
In Collection: KNet
|
||||
Metadata:
|
||||
backbone: R-50-D8
|
||||
crop size: (512,512)
|
||||
lr schd: 80000
|
||||
inference time (ms/im):
|
||||
- value: 58.45
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (512,512)
|
||||
Training Memory (GB): 7.34
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 43.45
|
||||
mIoU(ms+flip): 44.07
|
||||
Config: configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth
|
||||
- Name: knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k
|
||||
In Collection: KNet
|
||||
Metadata:
|
||||
backbone: Swin-T
|
||||
crop size: (512,512)
|
||||
lr schd: 80000
|
||||
inference time (ms/im):
|
||||
- value: 64.27
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (512,512)
|
||||
Training Memory (GB): 7.57
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 45.84
|
||||
mIoU(ms+flip): 46.27
|
||||
Config: configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth
|
||||
- Name: knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k
|
||||
In Collection: KNet
|
||||
Metadata:
|
||||
backbone: Swin-L
|
||||
crop size: (512,512)
|
||||
lr schd: 80000
|
||||
inference time (ms/im):
|
||||
- value: 120.63
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (512,512)
|
||||
Training Memory (GB): 13.5
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 52.05
|
||||
mIoU(ms+flip): 53.24
|
||||
Config: configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth
|
||||
- Name: knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k
|
||||
In Collection: KNet
|
||||
Metadata:
|
||||
backbone: Swin-L
|
||||
crop size: (640,640)
|
||||
lr schd: 80000
|
||||
inference time (ms/im):
|
||||
- value: 120.63
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (640,640)
|
||||
Training Memory (GB): 13.54
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: ADE20K
|
||||
Metrics:
|
||||
mIoU: 52.21
|
||||
mIoU(ms+flip): 53.34
|
||||
Config: configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth
|
|
@ -0,0 +1,93 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
num_stages = 3
|
||||
conv_kernel_size = 1
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='IterativeDecodeHead',
|
||||
num_stages=num_stages,
|
||||
kernel_update_head=[
|
||||
dict(
|
||||
type='KernelUpdateHead',
|
||||
num_classes=150,
|
||||
num_ffn_fcs=2,
|
||||
num_heads=8,
|
||||
num_mask_fcs=1,
|
||||
feedforward_channels=2048,
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
dropout=0.0,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
ffn_act_cfg=dict(type='ReLU', inplace=True),
|
||||
with_ffn=True,
|
||||
feat_transform_cfg=dict(
|
||||
conv_cfg=dict(type='Conv2d'), act_cfg=None),
|
||||
kernel_updator_cfg=dict(
|
||||
type='KernelUpdator',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
|
||||
],
|
||||
kernel_generate_head=dict(
|
||||
type='ASPPHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
dilations=(1, 12, 24, 36),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole'))
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
|
||||
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=1000,
|
||||
warmup_ratio=0.001,
|
||||
step=[60000, 72000],
|
||||
by_epoch=False)
|
||||
# In K-Net implementation we use batch size 2 per GPU as default
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
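The `lr_config` above uses mmcv's step policy with linear warmup. As a plain-Python sketch of the resulting schedule (the decay factor of 0.1 per milestone is an assumption here, matching mmcv's default `gamma` for the step policy, which the config does not override):

```python
def lr_at(it, base_lr=0.0001, warmup_iters=1000, warmup_ratio=0.001,
          steps=(60000, 72000), gamma=0.1):
    """Approximate learning rate at iteration `it` for the config's schedule."""
    # Step decay: multiply by gamma for each milestone already passed.
    lr = base_lr * gamma ** sum(it >= s for s in steps)
    if it < warmup_iters:
        # mmcv-style linear warmup: ramp from base_lr * warmup_ratio up to lr.
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        lr *= 1 - k
    return lr

start = lr_at(0)        # base_lr * warmup_ratio at the first iteration
plateau = lr_at(30000)  # the configured lr after warmup
decayed = lr_at(65000)  # reduced after the 60000-iteration milestone
```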
@ -0,0 +1,93 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
num_stages = 3
|
||||
conv_kernel_size = 1
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='IterativeDecodeHead',
|
||||
num_stages=num_stages,
|
||||
kernel_update_head=[
|
||||
dict(
|
||||
type='KernelUpdateHead',
|
||||
num_classes=150,
|
||||
num_ffn_fcs=2,
|
||||
num_heads=8,
|
||||
num_mask_fcs=1,
|
||||
feedforward_channels=2048,
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
dropout=0.0,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
ffn_act_cfg=dict(type='ReLU', inplace=True),
|
||||
with_ffn=True,
|
||||
feat_transform_cfg=dict(
|
||||
conv_cfg=dict(type='Conv2d'), act_cfg=None),
|
||||
kernel_updator_cfg=dict(
|
||||
type='KernelUpdator',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
|
||||
],
|
||||
kernel_generate_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
num_convs=2,
|
||||
concat_input=True,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole'))
|
||||
# optimizer
|
||||
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
|
||||
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=1000,
|
||||
warmup_ratio=0.001,
|
||||
step=[60000, 72000],
|
||||
by_epoch=False)
|
||||
# In K-Net implementation we use batch size 2 per GPU as default
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
@ -0,0 +1,92 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
num_stages = 3
|
||||
conv_kernel_size = 1
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='IterativeDecodeHead',
|
||||
num_stages=num_stages,
|
||||
kernel_update_head=[
|
||||
dict(
|
||||
type='KernelUpdateHead',
|
||||
num_classes=150,
|
||||
num_ffn_fcs=2,
|
||||
num_heads=8,
|
||||
num_mask_fcs=1,
|
||||
feedforward_channels=2048,
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
dropout=0.0,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
ffn_act_cfg=dict(type='ReLU', inplace=True),
|
||||
with_ffn=True,
|
||||
feat_transform_cfg=dict(
|
||||
conv_cfg=dict(type='Conv2d'), act_cfg=None),
|
||||
kernel_updator_cfg=dict(
|
||||
type='KernelUpdator',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
|
||||
],
|
||||
kernel_generate_head=dict(
|
||||
type='PSPHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole'))
|
||||
# optimizer
|
||||
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
|
||||
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=1000,
|
||||
warmup_ratio=0.001,
|
||||
step=[60000, 72000],
|
||||
by_epoch=False)
|
||||
# In K-Net implementation we use batch size 2 per GPU as default
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
@ -0,0 +1,93 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
|
||||
'../_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
num_stages = 3
|
||||
conv_kernel_size = 1
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 1, 1),
|
||||
strides=(1, 2, 2, 2),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='IterativeDecodeHead',
|
||||
num_stages=num_stages,
|
||||
kernel_update_head=[
|
||||
dict(
|
||||
type='KernelUpdateHead',
|
||||
num_classes=150,
|
||||
num_ffn_fcs=2,
|
||||
num_heads=8,
|
||||
num_mask_fcs=1,
|
||||
feedforward_channels=2048,
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
dropout=0.0,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
ffn_act_cfg=dict(type='ReLU', inplace=True),
|
||||
with_ffn=True,
|
||||
feat_transform_cfg=dict(
|
||||
conv_cfg=dict(type='Conv2d'), act_cfg=None),
|
||||
kernel_updator_cfg=dict(
|
||||
type='KernelUpdator',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
|
||||
],
|
||||
kernel_generate_head=dict(
|
||||
type='UPerHead',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
in_index=[0, 1, 2, 3],
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=150,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole'))
|
||||
# optimizer
|
||||
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
|
||||
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=1000,
|
||||
warmup_ratio=0.001,
|
||||
step=[60000, 72000],
|
||||
by_epoch=False)
|
||||
# In K-Net implementation we use batch size 2 per GPU as default
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
@ -0,0 +1,19 @@
|
|||
_base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
|
||||
|
||||
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth' # noqa
|
||||
# model settings
|
||||
model = dict(
|
||||
pretrained=checkpoint_file,
|
||||
backbone=dict(
|
||||
embed_dims=192,
|
||||
depths=[2, 2, 18, 2],
|
||||
num_heads=[6, 12, 24, 48],
|
||||
window_size=7,
|
||||
use_abs_pos_embed=False,
|
||||
drop_path_rate=0.3,
|
||||
patch_norm=True),
|
||||
decode_head=dict(
|
||||
kernel_generate_head=dict(in_channels=[192, 384, 768, 1536])),
|
||||
auxiliary_head=dict(in_channels=768))
|
||||
# In K-Net implementation we use batch size 2 per GPU as default
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
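The `in_channels=[192, 384, 768, 1536]` override above is not arbitrary: Swin Transformer doubles the channel width at each of its four stages, so the per-stage channels follow directly from `embed_dims`. A small sketch:

```python
# Swin stage widths: embed_dims doubles at every stage, which is why moving
# from Swin-T (embed_dims=96) to Swin-L (embed_dims=192) also requires
# updating kernel_generate_head's in_channels and the auxiliary head.
def stage_channels(embed_dims, num_stages=4):
    return [embed_dims * 2 ** i for i in range(num_stages)]

print(stage_channels(96))   # Swin-T: [96, 192, 384, 768]
print(stage_channels(192))  # Swin-L: [192, 384, 768, 1536]
```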
@ -0,0 +1,54 @@
|
|||
_base_ = 'knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py'
|
||||
|
||||
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_large_patch4_window7_224_22k_20220308-d5bdebaf.pth' # noqa
|
||||
# model settings
|
||||
model = dict(
|
||||
pretrained=checkpoint_file,
|
||||
backbone=dict(
|
||||
embed_dims=192,
|
||||
depths=[2, 2, 18, 2],
|
||||
num_heads=[6, 12, 24, 48],
|
||||
window_size=7,
|
||||
use_abs_pos_embed=False,
|
||||
drop_path_rate=0.4,
|
||||
patch_norm=True),
|
||||
decode_head=dict(
|
||||
kernel_generate_head=dict(in_channels=[192, 384, 768, 1536])),
|
||||
auxiliary_head=dict(in_channels=768))
|
||||
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (640, 640)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations', reduce_zero_label=True),
|
||||
dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 640),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline),
|
||||
test=dict(pipeline=test_pipeline))
|
||||
# In K-Net implementation we use batch size 2 per GPU as default
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
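A point worth remembering when editing configs like the one above: config files are executed as plain Python, so assigning to `data` twice does not merge the two dicts — the second assignment simply rebinds the name and the first dict is lost. A minimal illustration:

```python
# Assignment rebinds rather than merges, so batch-size settings and pipeline
# overrides must live in a single `data` dict within one config file.
data = dict(train=dict(pipeline=['LoadImageFromFile']))  # pipeline override
data = dict(samples_per_gpu=2, workers_per_gpu=2)        # rebinds `data`
print('train' in data)  # False: the pipeline override above was discarded
```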
@ -0,0 +1,57 @@
|
|||
_base_ = 'knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py'
|
||||
|
||||
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_tiny_patch4_window7_224_20220308-f41b89d3.pth' # noqa
|
||||
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
num_stages = 3
|
||||
conv_kernel_size = 1
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=checkpoint_file,
|
||||
backbone=dict(
|
||||
_delete_=True,
|
||||
type='SwinTransformer',
|
||||
embed_dims=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.3,
|
||||
use_abs_pos_embed=False,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3)),
|
||||
decode_head=dict(
|
||||
kernel_generate_head=dict(in_channels=[96, 192, 384, 768])),
|
||||
auxiliary_head=dict(in_channels=384))
|
||||
|
||||
# modify learning rate following the official implementation of Swin Transformer # noqa
|
||||
optimizer = dict(
|
||||
_delete_=True,
|
||||
type='AdamW',
|
||||
lr=0.00006,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.0005,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'absolute_pos_embed': dict(decay_mult=0.),
|
||||
'relative_position_bias_table': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.)
|
||||
}))
|
||||
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=1000,
|
||||
warmup_ratio=0.001,
|
||||
step=[60000, 72000],
|
||||
by_epoch=False)
|
||||
# In K-Net implementation we use batch size 2 per GPU as default
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
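The `paramwise_cfg` above disables weight decay for position-embedding and normalization parameters, following the usual Swin training recipe. A sketch of the effective per-parameter decay, assuming mmcv's substring matching of `custom_keys` against parameter names (the parameter names below are illustrative):

```python
# Effective weight decay per parameter: decay_mult=0 for any parameter whose
# name contains one of the custom keys, the base weight_decay otherwise.
base_wd = 0.0005
custom_keys = {
    'absolute_pos_embed': 0.0,
    'relative_position_bias_table': 0.0,
    'norm': 0.0,
}

def weight_decay_for(param_name):
    for key, decay_mult in custom_keys.items():
        if key in param_name:
            return base_wd * decay_mult
    return base_wd

print(weight_decay_for('backbone.stages.0.blocks.0.attn.qkv.weight'))  # 0.0005
print(weight_decay_for('backbone.stages.0.blocks.0.norm1.weight'))     # 0.0
```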
@ -13,6 +13,7 @@ from .fcn_head import FCNHead

from .fpn_head import FPNHead
from .gc_head import GCHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
from .nl_head import NLHead
from .ocr_head import OCRHead

@ -34,5 +35,6 @@ __all__ = [

    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
    'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
    'SegformerHead', 'ISAHead', 'STDCHead'
    'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
    'KernelUpdateHead', 'KernelUpdator'
]
@@ -0,0 +1,453 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
                                         MultiheadAttention,
                                         build_transformer_layer)

from mmseg.models.builder import HEADS, build_head
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.utils import get_root_logger


@TRANSFORMER_LAYER.register_module()
class KernelUpdator(nn.Module):
    """Dynamic Kernel Updator in Kernel Update Head.

    Args:
        in_channels (int): The number of channels of the input feature map.
            Default: 256.
        feat_channels (int): The number of middle-stage channels in
            the kernel updator. Default: 64.
        out_channels (int): The number of output channels. Default: None.
        gate_sigmoid (bool): Whether to use a sigmoid function in the gate
            mechanism. Default: True.
        gate_norm_act (bool): Whether to add a normalization and activation
            layer in the gate mechanism. Default: False.
        activate_out (bool): Whether to add an activation after the gate
            mechanism. Default: False.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='LN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
    """

    def __init__(
        self,
        in_channels=256,
        feat_channels=64,
        out_channels=None,
        gate_sigmoid=True,
        gate_norm_act=False,
        activate_out=False,
        norm_cfg=dict(type='LN'),
        act_cfg=dict(type='ReLU', inplace=True),
    ):
        super(KernelUpdator, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.gate_sigmoid = gate_sigmoid
        self.gate_norm_act = gate_norm_act
        self.activate_out = activate_out
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.feat_channels
        self.num_params_out = self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)
        self.input_layer = nn.Linear(self.in_channels,
                                     self.num_params_in + self.num_params_out,
                                     1)
        self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        if self.gate_norm_act:
            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, update_feature, input_feature):
        """Forward function of KernelUpdator.

        Args:
            update_feature (torch.Tensor): Feature map assembled from
                each group. It is reshaped so that its last dimension is
                `self.in_channels`.
            input_feature (torch.Tensor): Intermediate feature
                with shape (N, num_classes, conv_kernel_size**2, channels).

        Returns:
            Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
                the number of classes, C1 and C2 are the feature map channels
                of KernelUpdateHead and KernelUpdator, respectively.
        """

        update_feature = update_feature.reshape(-1, self.in_channels)
        num_proposals = update_feature.size(0)
        # dynamic_layer works for
        # phi_1 and psi_3 in Eq.(4) and (5) of the K-Net paper
        parameters = self.dynamic_layer(update_feature)
        param_in = parameters[:, :self.num_params_in].view(
            -1, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels)

        # input_layer works for
        # phi_2 and psi_4 in Eq.(4) and (5) of the K-Net paper
        input_feats = self.input_layer(
            input_feature.reshape(num_proposals, -1, self.feat_channels))
        input_in = input_feats[..., :self.num_params_in]
        input_out = input_feats[..., -self.num_params_out:]

        # `gate_feats` is F^G in the K-Net paper
        gate_feats = input_in * param_in.unsqueeze(-2)
        if self.gate_norm_act:
            gate_feats = self.activation(self.gate_norm(gate_feats))

        input_gate = self.input_norm_in(self.input_gate(gate_feats))
        update_gate = self.norm_in(self.update_gate(gate_feats))
        if self.gate_sigmoid:
            input_gate = input_gate.sigmoid()
            update_gate = update_gate.sigmoid()
        param_out = self.norm_out(param_out)
        input_out = self.input_norm_out(input_out)

        if self.activate_out:
            param_out = self.activation(param_out)
            input_out = self.activation(input_out)

        # Gate mechanism. Eq.(5) in the original paper.
        # param_out has shape (batch_size, feat_channels, out_channels)
        features = update_gate * param_out.unsqueeze(
            -2) + input_gate * input_out

        features = self.fc_layer(features)
        features = self.fc_norm(features)
        features = self.activation(features)

        return features


@HEADS.register_module()
class KernelUpdateHead(nn.Module):
    """Kernel Update Head in K-Net.

    Args:
        num_classes (int): Number of classes. Default: 150.
        num_ffn_fcs (int): The number of fully-connected layers in
            FFNs. Default: 2.
        num_heads (int): The number of parallel attention heads.
            Default: 8.
        num_mask_fcs (int): The number of fully-connected layers for
            mask prediction. Default: 3.
        feedforward_channels (int): The hidden dimension of FFNs.
            Default: 2048.
        in_channels (int): The number of channels of the input feature map.
            Default: 256.
        out_channels (int): The number of output channels.
            Default: 256.
        dropout (float): The probability of an element to be
            zeroed in MultiheadAttention and FFN. Default: 0.0.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        ffn_act_cfg (dict): Config of activation layers in FFN.
            Default: dict(type='ReLU').
        conv_kernel_size (int): The kernel size of convolution in
            Kernel Update Head for the dynamic kernel update.
            Default: 1.
        feat_transform_cfg (dict | None): Config of feature transform.
            Default: None.
        kernel_init (bool): Whether to initialize the mask kernel in the
            mask head. Default: False.
        with_ffn (bool): Whether to add an FFN in the kernel update head.
            Default: True.
        feat_gather_stride (int): Stride of convolution in feature transform.
            Default: 1.
        mask_transform_stride (int): Stride of mask transform.
            Default: 1.
        kernel_updator_cfg (dict): Config of kernel updator.
            Default: dict(
                type='DynamicConv',
                in_channels=256,
                feat_channels=64,
                out_channels=256,
                act_cfg=dict(type='ReLU', inplace=True),
                norm_cfg=dict(type='LN')).
    """

    def __init__(self,
                 num_classes=150,
                 num_ffn_fcs=2,
                 num_heads=8,
                 num_mask_fcs=3,
                 feedforward_channels=2048,
                 in_channels=256,
                 out_channels=256,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_act_cfg=dict(type='ReLU', inplace=True),
                 conv_kernel_size=1,
                 feat_transform_cfg=None,
                 kernel_init=False,
                 with_ffn=True,
                 feat_gather_stride=1,
                 mask_transform_stride=1,
                 kernel_updator_cfg=dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN'))):
        super(KernelUpdateHead, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.fp16_enabled = False
        self.dropout = dropout
        self.num_heads = num_heads
        self.kernel_init = kernel_init
        self.with_ffn = with_ffn
        self.conv_kernel_size = conv_kernel_size
        self.feat_gather_stride = feat_gather_stride
        self.mask_transform_stride = mask_transform_stride

        self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
                                            num_heads, dropout)
        self.attention_norm = build_norm_layer(
            dict(type='LN'), in_channels * conv_kernel_size**2)[1]
        self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)

        if feat_transform_cfg is not None:
            kernel_size = feat_transform_cfg.pop('kernel_size', 1)
            transform_channels = in_channels
            self.feat_transform = ConvModule(
                transform_channels,
                in_channels,
                kernel_size,
                stride=feat_gather_stride,
                padding=int(feat_gather_stride // 2),
                **feat_transform_cfg)
        else:
            self.feat_transform = None

        if self.with_ffn:
            self.ffn = FFN(
                in_channels,
                feedforward_channels,
                num_ffn_fcs,
                act_cfg=ffn_act_cfg,
                dropout=dropout)
            self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.mask_fcs = nn.ModuleList()
        for _ in range(num_mask_fcs):
            self.mask_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.mask_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.mask_fcs.append(build_activation_layer(act_cfg))

        self.fc_mask = nn.Linear(in_channels, out_channels)

    def init_weights(self):
        """Use Xavier initialization for all weight parameters and set the
        classification head bias to a specific value when using focal loss."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                # adopt the default initialization for
                # the weight and bias of the layer norm
                pass
        if self.kernel_init:
            logger = get_root_logger()
            logger.info(
                'mask kernel in mask head is normal initialized by std 0.01')
            nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)

    def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
        """Forward function of Dynamic Instance Interactive Head.

        Args:
            x (Tensor): Feature map from FPN with shape
                (batch_size, feature_dimensions, H, W).
            proposal_feat (Tensor): Intermediate feature from the previous
                stage with shape
                (batch_size, num_proposals, feature_dimensions).
            mask_preds (Tensor): Mask prediction from the previous stage with
                shape (batch_size, num_proposals, H, W).

        Returns:
            Tuple: The first tensor is the predicted mask with shape
                (N, num_classes, H, W), the second tensor is the dynamic
                kernel with shape (N, num_classes, channels, K, K).
        """
        N, num_proposals = proposal_feat.shape[:2]
        if self.feat_transform is not None:
            x = self.feat_transform(x)

        C, H, W = x.shape[-3:]

        mask_h, mask_w = mask_preds.shape[-2:]
        if mask_h != H or mask_w != W:
            gather_mask = F.interpolate(
                mask_preds, (H, W), align_corners=False, mode='bilinear')
        else:
            gather_mask = mask_preds

        sigmoid_masks = gather_mask.softmax(dim=1)

        # Group Feature Assembling. Eq.(3) in original paper.
        # einsum is faster than bmm by 30%
        x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)

        # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
        proposal_feat = proposal_feat.reshape(N, num_proposals,
                                              self.in_channels,
                                              -1).permute(0, 1, 3, 2)
        obj_feat = self.kernel_update_conv(x_feat, proposal_feat)

        # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
        obj_feat = self.attention_norm(self.attention(obj_feat))
        # [N, B, K*K*C] -> [B, N, K*K*C]
        obj_feat = obj_feat.permute(1, 0, 2)

        # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
        obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)

        # FFN
        if self.with_ffn:
            obj_feat = self.ffn_norm(self.ffn(obj_feat))

        mask_feat = obj_feat

        for reg_layer in self.mask_fcs:
            mask_feat = reg_layer(mask_feat)

        # [B, N, K*K, C] -> [B, N, C, K*K]
        mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)

        if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
            mask_x = F.interpolate(
                x, scale_factor=0.5, mode='bilinear', align_corners=False)
            H, W = mask_x.shape[-2:]
        else:
            mask_x = x
        # group conv is 5x faster than unfold and uses about 1/5 memory
        # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
        # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
        # but in real training group conv is slower than concat batch
        # so we keep using concat batch.
        # fold_x = F.unfold(
        #     mask_x,
        #     self.conv_kernel_size,
        #     padding=int(self.conv_kernel_size // 2))
        # mask_feat = mask_feat.reshape(N, num_proposals, -1)
        # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
        # [B, N, C, K*K] -> [B*N, C, K, K]
        mask_feat = mask_feat.reshape(N, num_proposals, C,
                                      self.conv_kernel_size,
                                      self.conv_kernel_size)
        # per-image dynamic convolution with the updated kernels
        new_mask_preds = []
        for i in range(N):
            new_mask_preds.append(
                F.conv2d(
                    mask_x[i:i + 1],
                    mask_feat[i],
                    padding=int(self.conv_kernel_size // 2)))

        new_mask_preds = torch.cat(new_mask_preds, dim=0)
        new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
        if self.mask_transform_stride == 2:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                scale_factor=2,
                mode='bilinear',
                align_corners=False)

        if mask_shape is not None and mask_shape[0] != H:
            new_mask_preds = F.interpolate(
                new_mask_preds,
                mask_shape,
                align_corners=False,
                mode='bilinear')

        return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
            N, num_proposals, self.in_channels, self.conv_kernel_size,
            self.conv_kernel_size)


@HEADS.register_module()
class IterativeDecodeHead(BaseDecodeHead):
    """K-Net: Towards Unified Image Segmentation.

    This head is the implementation of
    `K-Net <https://arxiv.org/abs/2106.14855>`_.

    Args:
        num_stages (int): The number of stages (kernel update heads)
            in IterativeDecodeHead. Default: 3.
        kernel_generate_head (dict): Config of the kernel generation head,
            which generates mask predictions, dynamic kernels and class
            predictions for the subsequent kernel update heads.
        kernel_update_head (list[dict]): Configs of the kernel update heads,
            which iteratively refine the dynamic kernels and mask predictions.
    """

    def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
                 **kwargs):
        super(BaseDecodeHead, self).__init__(**kwargs)
        assert num_stages == len(kernel_update_head)
        self.num_stages = num_stages
        self.kernel_generate_head = build_head(kernel_generate_head)
        self.kernel_update_head = nn.ModuleList()
        self.align_corners = self.kernel_generate_head.align_corners
        self.num_classes = self.kernel_generate_head.num_classes
        self.input_transform = self.kernel_generate_head.input_transform
        self.ignore_index = self.kernel_generate_head.ignore_index

        for head_cfg in kernel_update_head:
            self.kernel_update_head.append(build_head(head_cfg))

    def forward(self, inputs):
        """Forward function."""
        feats = self.kernel_generate_head._forward_feature(inputs)
        sem_seg = self.kernel_generate_head.cls_seg(feats)
        seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
        seg_kernels = seg_kernels[None].expand(
            feats.size(0), *seg_kernels.size())

        stage_segs = [sem_seg]
        for i in range(self.num_stages):
            sem_seg, seg_kernels = self.kernel_update_head[i](feats,
                                                              seg_kernels,
                                                              sem_seg)
            stage_segs.append(sem_seg)
        if self.training:
            return stage_segs
        # only return the prediction of the last stage during testing
        return stage_segs[-1]

    def losses(self, seg_logit, seg_label):
        """Compute the segmentation loss of every stage."""
        losses = dict()
        for i, logit in enumerate(seg_logit):
            loss = self.kernel_generate_head.losses(logit, seg_label)
            for k, v in loss.items():
                losses[f'{k}.s{i}'] = v

        return losses
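Per element, the gated combination in `KernelUpdator.forward` (Eq.(5) of the paper) reduces to `out = sigma(gate_update) * param_out + sigma(gate_input) * input_out`, where the gates are computed from `F^G = input_in * param_in`. A toy scalar sketch with no learned layers (all names here are illustrative stand-ins, not the module's real parameters):

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated_update(param_out, input_out, gate_feat, w_in=1.0, w_up=1.0):
    """Scalar sketch of the K-Net gate mechanism.

    gate_feat stands in for F^G = input_in * param_in; w_in and w_up
    stand in for the learned input_gate / update_gate projections.
    """
    input_gate = sigmoid(w_in * gate_feat)
    update_gate = sigmoid(w_up * gate_feat)
    # Eq.(5): gated sum of the kernel-derived and feature-derived streams
    return update_gate * param_out + input_gate * input_out
```

With `gate_feat = 0` both gates are 0.5 and the two streams are averaged with equal weight; a strongly positive `gate_feat` lets both streams pass nearly unattenuated.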
@@ -22,6 +22,7 @@ Import:
 - configs/hrnet/hrnet.yml
 - configs/icnet/icnet.yml
 - configs/isanet/isanet.yml
+- configs/knet/knet.yml
 - configs/mobilenet_v2/mobilenet_v2.yml
 - configs/mobilenet_v3/mobilenet_v3.yml
 - configs/nonlocal_net/nonlocal_net.yml
@@ -0,0 +1,195 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead,
                                                 KernelUpdateHead)
from .utils import to_cuda

num_stages = 3
conv_kernel_size = 1

kernel_updator_cfg = dict(
    type='KernelUpdator',
    in_channels=16,
    feat_channels=16,
    out_channels=16,
    gate_norm_act=True,
    activate_out=True,
    act_cfg=dict(type='ReLU', inplace=True),
    norm_cfg=dict(type='LN'))


def test_knet_head():
    # test init function of kernel update head
    kernel_update_head = KernelUpdateHead(
        num_classes=150,
        num_ffn_fcs=2,
        num_heads=8,
        num_mask_fcs=1,
        feedforward_channels=128,
        in_channels=32,
        out_channels=32,
        dropout=0.0,
        conv_kernel_size=conv_kernel_size,
        ffn_act_cfg=dict(type='ReLU', inplace=True),
        with_ffn=True,
        feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None),
        kernel_init=True,
        kernel_updator_cfg=kernel_updator_cfg)
    kernel_update_head.init_weights()

    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()
    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)

    # test whether only the prediction of the last stage
    # is returned during testing
    with torch.no_grad():
        head.eval()
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 27, 16)

    # test K-Net without `feat_transform_cfg`
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=None,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()

    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)

    # test K-Net with
    # self.mask_transform_stride == 2 and self.feat_gather_stride == 1
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                mask_transform_stride=2,
                feat_gather_stride=1,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()

    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 26, 16)

    # test loss function in K-Net
    fake_label = torch.ones_like(
        outputs[-1][:, 0:1, :, :], dtype=torch.int16).long()
    loss = head.losses(seg_logit=outputs, seg_label=fake_label)
    assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0'])
    assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1'])
    assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2'])
    assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3'])