[Feature] Support HRNet and add pre-trained models (#660)
* Support HRNet * Add HRNet configs * Fix a bug in backward * Add configs and update docs. * Not use bias in conv before batch norm * Defaults to use `norm_eval=False` * Add unit tests and support out_channels in HRFuseScales * Update checkpoint path * Update docstring. * Remove incorrect files * Improve according to commentspull/679/head
parent
dc456a0c2c
commit
5de480ea9e
docs/en
mmcls/models
backbones
tests/test_models
test_backbones
|
@ -80,7 +80,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
|
||||||
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
|
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
|
||||||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
||||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
||||||
- [ ] HRNet
|
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ MMClassification 是一款基于 PyTorch 的开源图像分类工具箱,是 [O
|
||||||
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
|
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
|
||||||
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
|
||||||
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
|
||||||
- [ ] HRNet
|
- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Model settings: ImageClassifier built from an HRNet ('w18') backbone,
# HRFuseScales + GlobalAveragePooling necks and a LinearClsHead.
model = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'HRNet', 'arch': 'w18'},
    'neck': [
        {'type': 'HRFuseScales', 'in_channels': (18, 36, 72, 144)},
        {'type': 'GlobalAveragePooling'},
    ],
    'head': {
        'type': 'LinearClsHead',
        'in_channels': 2048,
        'num_classes': 1000,
        'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Model settings: ImageClassifier built from an HRNet ('w30') backbone,
# HRFuseScales + GlobalAveragePooling necks and a LinearClsHead.
model = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'HRNet', 'arch': 'w30'},
    'neck': [
        {'type': 'HRFuseScales', 'in_channels': (30, 60, 120, 240)},
        {'type': 'GlobalAveragePooling'},
    ],
    'head': {
        'type': 'LinearClsHead',
        'in_channels': 2048,
        'num_classes': 1000,
        'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Model settings: ImageClassifier built from an HRNet ('w32') backbone,
# HRFuseScales + GlobalAveragePooling necks and a LinearClsHead.
model = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'HRNet', 'arch': 'w32'},
    'neck': [
        {'type': 'HRFuseScales', 'in_channels': (32, 64, 128, 256)},
        {'type': 'GlobalAveragePooling'},
    ],
    'head': {
        'type': 'LinearClsHead',
        'in_channels': 2048,
        'num_classes': 1000,
        'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Model settings: ImageClassifier built from an HRNet ('w40') backbone,
# HRFuseScales + GlobalAveragePooling necks and a LinearClsHead.
model = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'HRNet', 'arch': 'w40'},
    'neck': [
        {'type': 'HRFuseScales', 'in_channels': (40, 80, 160, 320)},
        {'type': 'GlobalAveragePooling'},
    ],
    'head': {
        'type': 'LinearClsHead',
        'in_channels': 2048,
        'num_classes': 1000,
        'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Model settings: ImageClassifier built from an HRNet ('w44') backbone,
# HRFuseScales + GlobalAveragePooling necks and a LinearClsHead.
model = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'HRNet', 'arch': 'w44'},
    'neck': [
        {'type': 'HRFuseScales', 'in_channels': (44, 88, 176, 352)},
        {'type': 'GlobalAveragePooling'},
    ],
    'head': {
        'type': 'LinearClsHead',
        'in_channels': 2048,
        'num_classes': 1000,
        'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Model settings: ImageClassifier built from an HRNet ('w48') backbone,
# HRFuseScales + GlobalAveragePooling necks and a LinearClsHead.
model = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'HRNet', 'arch': 'w48'},
    'neck': [
        {'type': 'HRFuseScales', 'in_channels': (48, 96, 192, 384)},
        {'type': 'GlobalAveragePooling'},
    ],
    'head': {
        'type': 'LinearClsHead',
        'in_channels': 2048,
        'num_classes': 1000,
        'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Model settings: ImageClassifier built from an HRNet ('w64') backbone,
# HRFuseScales + GlobalAveragePooling necks and a LinearClsHead.
model = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'HRNet', 'arch': 'w64'},
    'neck': [
        {'type': 'HRFuseScales', 'in_channels': (64, 128, 256, 512)},
        {'type': 'GlobalAveragePooling'},
    ],
    'head': {
        'type': 'LinearClsHead',
        'in_channels': 2048,
        'num_classes': 1000,
        'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
# HRNet
|
||||||
|
|
||||||
|
> [Deep High-Resolution Representation Learning for Visual Recognition](https://arxiv.org/abs/1908.07919v2)
|
||||||
|
<!-- [ALGORITHM] -->
|
||||||
|
|
||||||
|
## Abstract
|
||||||
|
|
||||||
|
High-resolution representations are essential for position-sensitive vision problems, such as human pose estimation, semantic segmentation, and object detection. Existing state-of-the-art frameworks first encode the input image as a low-resolution representation through a subnetwork that is formed by connecting high-to-low resolution convolutions *in series* (e.g., ResNet, VGGNet), and then recover the high-resolution representation from the encoded low-resolution representation. Instead, our proposed network, named as High-Resolution Network (HRNet), maintains high-resolution representations through the whole process. There are two key characteristics: (i) Connect the high-to-low resolution convolution streams *in parallel*; (ii) Repeatedly exchange the information across resolutions. The benefit is that the resulting representation is semantically richer and spatially more precise. We show the superiority of the proposed HRNet in a wide range of applications, including human pose estimation, semantic segmentation, and object detection, suggesting that the HRNet is a stronger backbone for computer vision problems.
|
||||||
|
|
||||||
|
<div align=center>
|
||||||
|
<img src="https://user-images.githubusercontent.com/26739999/149920446-cbe05670-989d-4fe6-accc-df20ae2984eb.png" width="100%"/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Results and models
|
||||||
|
|
||||||
|
### ImageNet-1k
|
||||||
|
|
||||||
|
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||||
|
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||||
|
| HRNet-W18\* | 21.30 | 4.33 | 76.75 | 93.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32_in1k_20220120-0c10b180.pth) |
|
||||||
|
| HRNet-W30\* | 37.71 | 8.17 | 78.19 | 94.22 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w30_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w30_3rdparty_8xb32_in1k_20220120-8aa3832f.pth) |
|
||||||
|
| HRNet-W32\* | 41.23 | 8.99 | 78.44 | 94.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w32_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w32_3rdparty_8xb32_in1k_20220120-c394f1ab.pth) |
|
||||||
|
| HRNet-W40\* | 57.55 | 12.77 | 78.94 | 94.47 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w40_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w40_3rdparty_8xb32_in1k_20220120-9a2dbfc5.pth) |
|
||||||
|
| HRNet-W44\* | 67.06 | 14.96 | 78.88 | 94.37 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w44_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w44_3rdparty_8xb32_in1k_20220120-35d07f73.pth) |
|
||||||
|
| HRNet-W48\* | 77.47 | 17.36 | 79.32 | 94.52 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32_in1k_20220120-e555ef50.pth) |
|
||||||
|
| HRNet-W64\* | 128.06 | 29.00 | 79.46 | 94.65 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w64_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w64_3rdparty_8xb32_in1k_20220120-19126642.pth) |
|
||||||
|
| HRNet-W18 (ssld)\* | 21.30 | 4.33 | 81.06 | 95.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32-ssld_in1k_20220120-455f69ea.pth) |
|
||||||
|
| HRNet-W48 (ssld)\* | 77.47 | 17.36 | 83.63 | 96.79 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32-ssld_in1k_20220120-d0459c38.pth) |
|
||||||
|
|
||||||
|
*Models with \* are converted from the [official repo](https://github.com/HRNet/HRNet-Image-Classification). The config files of these models are only for inference. We don't guarantee that training with these config files reproduces the reported accuracy, and we welcome you to contribute your reproduction results.*
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{WangSCJDZLMTWLX19,
|
||||||
|
title={Deep High-Resolution Representation Learning for Visual Recognition},
|
||||||
|
author={Jingdong Wang and Ke Sun and Tianheng Cheng and
|
||||||
|
Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and
|
||||||
|
Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao},
|
||||||
|
journal={TPAMI},
|
||||||
|
year={2019}
|
||||||
|
}
|
||||||
|
```
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Base configs this config is composed from, in order: the HRNet-W18
# model, the ImageNet dataset settings, the schedule and the default runtime.
# NOTE(review): the schedule base is named ``bs256`` while the README links
# this config as ``4xb32`` (4 x 32 = 128) -- confirm the intended batch size.
_base_ = [
    '../_base_/models/hrnet/hrnet-w18.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Base configs this config is composed from, in order: the HRNet-W30
# model, the ImageNet dataset settings, the schedule and the default runtime.
# NOTE(review): the schedule base is named ``bs256`` while the README links
# this config as ``4xb32`` (4 x 32 = 128) -- confirm the intended batch size.
_base_ = [
    '../_base_/models/hrnet/hrnet-w30.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Base configs this config is composed from, in order: the HRNet-W32
# model, the ImageNet dataset settings, the schedule and the default runtime.
# NOTE(review): the schedule base is named ``bs256`` while the README links
# this config as ``4xb32`` (4 x 32 = 128) -- confirm the intended batch size.
_base_ = [
    '../_base_/models/hrnet/hrnet-w32.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Base configs this config is composed from, in order: the HRNet-W40
# model, the ImageNet dataset settings, the schedule and the default runtime.
# NOTE(review): the schedule base is named ``bs256`` while the README links
# this config as ``4xb32`` (4 x 32 = 128) -- confirm the intended batch size.
_base_ = [
    '../_base_/models/hrnet/hrnet-w40.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Base configs this config is composed from, in order: the HRNet-W44
# model, the ImageNet dataset settings, the schedule and the default runtime.
# NOTE(review): the schedule base is named ``bs256`` while the README links
# this config as ``4xb32`` (4 x 32 = 128) -- confirm the intended batch size.
_base_ = [
    '../_base_/models/hrnet/hrnet-w44.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Base configs this config is composed from, in order: the HRNet-W48
# model, the ImageNet dataset settings, the schedule and the default runtime.
# NOTE(review): the schedule base is named ``bs256`` while the README links
# this config as ``4xb32`` (4 x 32 = 128) -- confirm the intended batch size.
_base_ = [
    '../_base_/models/hrnet/hrnet-w48.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
|
|
@ -0,0 +1,6 @@
|
||||||
|
# Base configs this config is composed from, in order: the HRNet-W64
# model, the ImageNet dataset settings, the schedule and the default runtime.
# NOTE(review): the schedule base is named ``bs256`` while the README links
# this config as ``4xb32`` (4 x 32 = 128) -- confirm the intended batch size.
_base_ = [
    '../_base_/models/hrnet/hrnet-w64.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]
|
|
@ -0,0 +1,162 @@
|
||||||
|
Collections:
|
||||||
|
- Name: HRNet
|
||||||
|
Metadata:
|
||||||
|
Training Data: ImageNet-1k
|
||||||
|
Architecture:
|
||||||
|
- Batch Normalization
|
||||||
|
- Convolution
|
||||||
|
- ReLU
|
||||||
|
- Residual Connection
|
||||||
|
Paper:
|
||||||
|
URL: https://arxiv.org/abs/1908.07919v2
|
||||||
|
Title: "Deep High-Resolution Representation Learning for Visual Recognition"
|
||||||
|
README: configs/hrnet/README.md
|
||||||
|
Code:
|
||||||
|
URL: https://github.com/open-mmlab/mmclassification/blob/v0.20.0/mmcls/models/backbones/hrnet.py
|
||||||
|
Version: v0.20.0
|
||||||
|
|
||||||
|
Models:
|
||||||
|
- Name: hrnet-w18_3rdparty_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 4330397932
|
||||||
|
Parameters: 21295164
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 76.75
|
||||||
|
Top 5 Accuracy: 93.44
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32_in1k_20220120-0c10b180.pth
|
||||||
|
Config: configs/hrnet/hrnet-w18_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33cMkPimlmClRvmpw
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w30_3rdparty_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 8168305684
|
||||||
|
Parameters: 37708380
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 78.19
|
||||||
|
Top 5 Accuracy: 94.22
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w30_3rdparty_8xb32_in1k_20220120-8aa3832f.pth
|
||||||
|
Config: configs/hrnet/hrnet-w30_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33cQoACCEfrzcSaVI
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w32_3rdparty_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 8986267584
|
||||||
|
Parameters: 41228840
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 78.44
|
||||||
|
Top 5 Accuracy: 94.19
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w32_3rdparty_8xb32_in1k_20220120-c394f1ab.pth
|
||||||
|
Config: configs/hrnet/hrnet-w32_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33dYBMemi9xOUFR0w
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w40_3rdparty_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 12767574064
|
||||||
|
Parameters: 57553320
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 78.94
|
||||||
|
Top 5 Accuracy: 94.47
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w40_3rdparty_8xb32_in1k_20220120-9a2dbfc5.pth
|
||||||
|
Config: configs/hrnet/hrnet-w40_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33ck0gvo5jfoWBOPo
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w44_3rdparty_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 14963902632
|
||||||
|
Parameters: 67061144
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 78.88
|
||||||
|
Top 5 Accuracy: 94.37
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w44_3rdparty_8xb32_in1k_20220120-35d07f73.pth
|
||||||
|
Config: configs/hrnet/hrnet-w44_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33czZQ0woUb980gRs
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w48_3rdparty_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 17364014752
|
||||||
|
Parameters: 77466024
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 79.32
|
||||||
|
Top 5 Accuracy: 94.52
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32_in1k_20220120-e555ef50.pth
|
||||||
|
Config: configs/hrnet/hrnet-w48_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33dKvqI6pBZlifgJk
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w64_3rdparty_8xb32_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 29002298752
|
||||||
|
Parameters: 128056104
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 79.46
|
||||||
|
Top 5 Accuracy: 94.65
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w64_3rdparty_8xb32_in1k_20220120-19126642.pth
|
||||||
|
Config: configs/hrnet/hrnet-w64_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33gQbJsUPTIj3rQu99
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w18_3rdparty_8xb32-ssld_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 4330397932
|
||||||
|
Parameters: 21295164
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 81.06
|
||||||
|
Top 5 Accuracy: 95.7
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32-ssld_in1k_20220120-455f69ea.pth
|
||||||
|
Config: configs/hrnet/hrnet-w18_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W18_C_ssld_pretrained.pth
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
||||||
|
- Name: hrnet-w48_3rdparty_8xb32-ssld_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 17364014752
|
||||||
|
Parameters: 77466024
|
||||||
|
In Collection: HRNet
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 83.63
|
||||||
|
Top 5 Accuracy: 96.79
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32-ssld_in1k_20220120-d0459c38.pth
|
||||||
|
Config: configs/hrnet/hrnet-w48_4xb32_in1k.py
|
||||||
|
Converted From:
|
||||||
|
Weights: https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W48_C_ssld_pretrained.pth
|
||||||
|
Code: https://github.com/HRNet/HRNet-Image-Classification
|
|
@ -119,6 +119,15 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
||||||
| ConvNeXt-L\* | 197.77 | 34.37 | 84.30 | 96.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth) |
|
| ConvNeXt-L\* | 197.77 | 34.37 | 84.30 | 96.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth) |
|
||||||
| ConvNeXt-L\* | 197.77 | 34.37 | 86.61 | 98.04 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth) |
|
| ConvNeXt-L\* | 197.77 | 34.37 | 86.61 | 98.04 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth) |
|
||||||
| ConvNeXt-XL\* | 350.20 | 60.93 | 86.97 | 98.20 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-xlarge_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth) |
|
| ConvNeXt-XL\* | 350.20 | 60.93 | 86.97 | 98.20 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-xlarge_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth) |
|
||||||
|
| HRNet-W18\* | 21.30 | 4.33 | 76.75 | 93.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32_in1k_20220120-0c10b180.pth) |
|
||||||
|
| HRNet-W30\* | 37.71 | 8.17 | 78.19 | 94.22 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w30_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w30_3rdparty_8xb32_in1k_20220120-8aa3832f.pth) |
|
||||||
|
| HRNet-W32\* | 41.23 | 8.99 | 78.44 | 94.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w32_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w32_3rdparty_8xb32_in1k_20220120-c394f1ab.pth) |
|
||||||
|
| HRNet-W40\* | 57.55 | 12.77 | 78.94 | 94.47 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w40_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w40_3rdparty_8xb32_in1k_20220120-9a2dbfc5.pth) |
|
||||||
|
| HRNet-W44\* | 67.06 | 14.96 | 78.88 | 94.37 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w44_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w44_3rdparty_8xb32_in1k_20220120-35d07f73.pth) |
|
||||||
|
| HRNet-W48\* | 77.47 | 17.36 | 79.32 | 94.52 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32_in1k_20220120-e555ef50.pth) |
|
||||||
|
| HRNet-W64\* | 128.06 | 29.00 | 79.46 | 94.65 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w64_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w64_3rdparty_8xb32_in1k_20220120-19126642.pth) |
|
||||||
|
| HRNet-W18 (ssld)\* | 21.30 | 4.33 | 81.06 | 95.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32-ssld_in1k_20220120-455f69ea.pth) |
|
||||||
|
| HRNet-W48 (ssld)\* | 77.47 | 17.36 | 83.63 | 96.79 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32-ssld_in1k_20220120-d0459c38.pth) |
|
||||||
|
|
||||||
*Models with \* are converted from other repos, others are trained by ourselves.*
|
*Models with \* are converted from other repos, others are trained by ourselves.*
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ from .conformer import Conformer
|
||||||
from .convnext import ConvNeXt
|
from .convnext import ConvNeXt
|
||||||
from .deit import DistilledVisionTransformer
|
from .deit import DistilledVisionTransformer
|
||||||
from .efficientnet import EfficientNet
|
from .efficientnet import EfficientNet
|
||||||
|
from .hrnet import HRNet
|
||||||
from .lenet import LeNet5
|
from .lenet import LeNet5
|
||||||
from .mlp_mixer import MlpMixer
|
from .mlp_mixer import MlpMixer
|
||||||
from .mobilenet_v2 import MobileNetV2
|
from .mobilenet_v2 import MobileNetV2
|
||||||
|
@ -33,5 +34,5 @@ __all__ = [
|
||||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
|
||||||
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
|
||||||
'EfficientNet', 'ConvNeXt'
|
'EfficientNet', 'ConvNeXt', 'HRNet'
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,563 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||||
|
from mmcv.runner import BaseModule, ModuleList, Sequential
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from ..builder import BACKBONES
|
||||||
|
from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion
|
||||||
|
|
||||||
|
|
||||||
|
class HRModule(BaseModule):
|
||||||
|
"""High-Resolution Module for HRNet.
|
||||||
|
|
||||||
|
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
|
||||||
|
is in this module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_branches (int): The number of branches.
|
||||||
|
block (``BaseModule``): Convolution block module.
|
||||||
|
num_blocks (tuple): The number of blocks in each branch.
|
||||||
|
The length must be equal to ``num_branches``.
|
||||||
|
num_channels (tuple): The number of base channels in each branch.
|
||||||
|
The length must be equal to ``num_branches``.
|
||||||
|
multiscale_output (bool): Whether to output multi-level features
|
||||||
|
produced by multiple branches. If False, only the first level
|
||||||
|
feature will be output. Defaults to True.
|
||||||
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
|
memory while slowing down the training speed. Defaults to False.
|
||||||
|
conv_cfg (dict, optional): Dictionary to construct and config conv
|
||||||
|
layer. Defaults to None.
|
||||||
|
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||||
|
Defaults to ``dict(type='BN')``.
|
||||||
|
block_init_cfg (dict, optional): The initialization configs of every
|
||||||
|
blocks. Defaults to None.
|
||||||
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_branches,
|
||||||
|
block,
|
||||||
|
num_blocks,
|
||||||
|
in_channels,
|
||||||
|
num_channels,
|
||||||
|
multiscale_output=True,
|
||||||
|
with_cp=False,
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
block_init_cfg=None,
|
||||||
|
init_cfg=None):
|
||||||
|
super(HRModule, self).__init__(init_cfg)
|
||||||
|
self.block_init_cfg = block_init_cfg
|
||||||
|
self._check_branches(num_branches, num_blocks, in_channels,
|
||||||
|
num_channels)
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.num_branches = num_branches
|
||||||
|
|
||||||
|
self.multiscale_output = multiscale_output
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.with_cp = with_cp
|
||||||
|
self.branches = self._make_branches(num_branches, block, num_blocks,
|
||||||
|
num_channels)
|
||||||
|
self.fuse_layers = self._make_fuse_layers()
|
||||||
|
self.relu = nn.ReLU(inplace=False)
|
||||||
|
|
||||||
|
def _check_branches(self, num_branches, num_blocks, in_channels,
|
||||||
|
num_channels):
|
||||||
|
if num_branches != len(num_blocks):
|
||||||
|
error_msg = f'NUM_BRANCHES({num_branches}) ' \
|
||||||
|
f'!= NUM_BLOCKS({len(num_blocks)})'
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
if num_branches != len(num_channels):
|
||||||
|
error_msg = f'NUM_BRANCHES({num_branches}) ' \
|
||||||
|
f'!= NUM_CHANNELS({len(num_channels)})'
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
if num_branches != len(in_channels):
|
||||||
|
error_msg = f'NUM_BRANCHES({num_branches}) ' \
|
||||||
|
f'!= NUM_INCHANNELS({len(in_channels)})'
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
||||||
|
branches = []
|
||||||
|
|
||||||
|
for i in range(num_branches):
|
||||||
|
out_channels = num_channels[i] * get_expansion(block)
|
||||||
|
branches.append(
|
||||||
|
ResLayer(
|
||||||
|
block=block,
|
||||||
|
num_blocks=num_blocks[i],
|
||||||
|
in_channels=self.in_channels[i],
|
||||||
|
out_channels=out_channels,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
with_cp=self.with_cp,
|
||||||
|
init_cfg=self.block_init_cfg,
|
||||||
|
))
|
||||||
|
|
||||||
|
return ModuleList(branches)
|
||||||
|
|
||||||
|
def _make_fuse_layers(self):
|
||||||
|
if self.num_branches == 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
num_branches = self.num_branches
|
||||||
|
in_channels = self.in_channels
|
||||||
|
fuse_layers = []
|
||||||
|
num_out_branches = num_branches if self.multiscale_output else 1
|
||||||
|
for i in range(num_out_branches):
|
||||||
|
fuse_layer = []
|
||||||
|
for j in range(num_branches):
|
||||||
|
if j > i:
|
||||||
|
# Upsample the feature maps of smaller scales.
|
||||||
|
fuse_layer.append(
|
||||||
|
nn.Sequential(
|
||||||
|
build_conv_layer(
|
||||||
|
self.conv_cfg,
|
||||||
|
in_channels[j],
|
||||||
|
in_channels[i],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias=False),
|
||||||
|
build_norm_layer(self.norm_cfg, in_channels[i])[1],
|
||||||
|
nn.Upsample(
|
||||||
|
scale_factor=2**(j - i), mode='nearest')))
|
||||||
|
elif j == i:
|
||||||
|
# Keep the feature map with the same scale.
|
||||||
|
fuse_layer.append(None)
|
||||||
|
else:
|
||||||
|
# Downsample the feature maps of larger scales.
|
||||||
|
conv_downsamples = []
|
||||||
|
for k in range(i - j):
|
||||||
|
# Use stacked convolution layers to downsample.
|
||||||
|
if k == i - j - 1:
|
||||||
|
conv_downsamples.append(
|
||||||
|
nn.Sequential(
|
||||||
|
build_conv_layer(
|
||||||
|
self.conv_cfg,
|
||||||
|
in_channels[j],
|
||||||
|
in_channels[i],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
bias=False),
|
||||||
|
build_norm_layer(self.norm_cfg,
|
||||||
|
in_channels[i])[1]))
|
||||||
|
else:
|
||||||
|
conv_downsamples.append(
|
||||||
|
nn.Sequential(
|
||||||
|
build_conv_layer(
|
||||||
|
self.conv_cfg,
|
||||||
|
in_channels[j],
|
||||||
|
in_channels[j],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
bias=False),
|
||||||
|
build_norm_layer(self.norm_cfg,
|
||||||
|
in_channels[j])[1],
|
||||||
|
nn.ReLU(inplace=False)))
|
||||||
|
fuse_layer.append(nn.Sequential(*conv_downsamples))
|
||||||
|
fuse_layers.append(nn.ModuleList(fuse_layer))
|
||||||
|
|
||||||
|
return nn.ModuleList(fuse_layers)
|
||||||
|
|
||||||
|
def forward(self, x):
    """Forward function.

    Runs every branch on its own input and then, for each output scale,
    sums the contributions of all branches before applying ``self.relu``.

    Args:
        x (Sequence): One input per branch, indexed by branch. Lists and
            tuples are both accepted; the sequence is never modified.

    Returns:
        list: A single-element list holding the branch output when there
        is only one branch, otherwise one fused feature per entry of
        ``self.fuse_layers``.
    """
    if self.num_branches == 1:
        return [self.branches[0](x[0])]

    # Keep the branch outputs in a fresh list rather than writing them
    # back into ``x``: the caller's sequence stays untouched and
    # immutable inputs such as tuples work as well.
    branch_outs = [
        self.branches[i](x[i]) for i in range(self.num_branches)
    ]

    x_fuse = []
    for i, fuse_layer in enumerate(self.fuse_layers):
        # Starting from the int 0 makes the first ``+=`` produce a new
        # object, so later in-place additions never touch a branch
        # output.
        y = 0
        for j, feat in enumerate(branch_outs):
            if i == j:
                # Same scale: the feature is added as it is.
                y += feat
            else:
                y += fuse_layer[j](feat)
        x_fuse.append(self.relu(y))
    return x_fuse
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONES.register_module()
class HRNet(BaseModule):
    """HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`_.

    Args:
        arch (str): Name of a preset architecture, one of 'w18', 'w30',
            'w32', 'w40', 'w44', 'w48' and 'w64'. It is ignored when
            ``extra`` is given. Defaults to 'w32'.
        extra (dict, optional): Explicit per-stage configuration. It must
            contain the keys 'stage1' to 'stage4', and every stage
            configuration must provide 5 keys:

            - num_modules (int): The number of HRModule in this stage.
            - num_branches (int): The number of branches in the HRModule.
            - block (str): The type of convolution block, either
              'BOTTLENECK' or 'BASIC'.
            - num_blocks (tuple): The number of blocks in each branch.
              Its length must equal num_branches.
            - num_channels (tuple): The number of base channels in each
              branch. Its length must equal num_branches.

            Defaults to None.
        in_channels (int): Number of input image channels. Defaults to 3.
        conv_cfg (dict, optional): Config used to build conv layers.
            Defaults to None.
        norm_cfg (dict): Config used to build norm layers.
            Defaults to ``dict(type='BN')``.
        norm_eval (bool): Whether to keep batch norm layers in eval mode
            while the model is in training mode, i.e. freeze their running
            stats (mean and var). Only batch norm and its variants are
            affected. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Defaults to False.
        zero_init_residual (bool): Whether to zero-initialize the last norm
            layer of each residual block so that the block starts out as
            an identity. Defaults to False.
        multiscale_output (bool): Whether to output the features of all
            branches. If False, only the first level feature is output.
            Defaults to True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to Kaiming init for ``Conv2d`` and constant 1 for
            ``_BatchNorm`` and ``GroupNorm`` layers.

    Example:
        >>> import torch
        >>> from mmcls.models import HRNet
        >>> extra = dict(
        ...     stage1=dict(
        ...         num_modules=1,
        ...         num_branches=1,
        ...         block='BOTTLENECK',
        ...         num_blocks=(4, ),
        ...         num_channels=(64, )),
        ...     stage2=dict(
        ...         num_modules=1,
        ...         num_branches=2,
        ...         block='BASIC',
        ...         num_blocks=(4, 4),
        ...         num_channels=(32, 64)),
        ...     stage3=dict(
        ...         num_modules=4,
        ...         num_branches=3,
        ...         block='BASIC',
        ...         num_blocks=(4, 4, 4),
        ...         num_channels=(32, 64, 128)),
        ...     stage4=dict(
        ...         num_modules=3,
        ...         num_branches=4,
        ...         block='BASIC',
        ...         num_blocks=(4, 4, 4, 4),
        ...         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra=extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    # Maps the ``block`` string of a stage config to the block class.
    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
    # Preset architectures; one row per stage, in the order stage1..stage4.
    arch_zoo = {
        # num_modules, num_branches, block, num_blocks, num_channels
        'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (18, 36)],
                [4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]],
        'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (30, 60)],
                [4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]],
        'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (32, 64)],
                [4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]],
        'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (40, 80)],
                [4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]],
        'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (44, 88)],
                [4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]],
        'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (48, 96)],
                [4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]],
        'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
                [1, 2, 'BASIC', (4, 4), (64, 128)],
                [4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)],
                [3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]],
    }  # yapf:disable

    def __init__(self,
                 arch='w32',
                 extra=None,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False,
                 multiscale_output=True,
                 init_cfg=[
                     dict(type='Kaiming', layer='Conv2d'),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super().__init__(init_cfg)

        extra = self.parse_arch(arch, extra)

        # Every one of the 4 stages must be configured, and the per-branch
        # settings of a stage must have one entry per branch.
        for stage_idx in range(1, 5):
            assert f'stage{stage_idx}' in extra, \
                f'Missing stage{stage_idx} config in "extra".'
            cfg = extra[f'stage{stage_idx}']
            branch_count = cfg['num_branches']
            assert len(cfg['num_blocks']) == branch_count and \
                len(cfg['num_channels']) == branch_count

        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # -------------------- stem net --------------------
        # Two stride-2 3x3 convs, each followed by a norm layer.
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.conv2 = build_conv_layer(
            self.conv_cfg,
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # -------------------- stage 1 --------------------
        self.stage1_cfg = self.extra['stage1']
        stage1_widths = self.stage1_cfg['num_channels']
        stage1_block_name = self.stage1_cfg['block']
        stage1_depths = self.stage1_cfg['num_blocks']

        block = self.blocks_dict[stage1_block_name]
        num_channels = [
            width * get_expansion(block) for width in stage1_widths
        ]
        # To align with the original code, use layer1 instead of stage1 here.
        self.layer1 = ResLayer(
            block,
            in_channels=64,
            out_channels=num_channels[0],
            num_blocks=stage1_depths[0])
        pre_num_channels = num_channels

        # -------------------- stage 2~4 --------------------
        for stage_idx in range(2, 5):
            stage_cfg = self.extra[f'stage{stage_idx}']
            stage_widths = stage_cfg['num_channels']
            block = self.blocks_dict[stage_cfg['block']]
            # Only the last stage may drop the extra output scales.
            if stage_idx == 4:
                stage_multiscale = multiscale_output
            else:
                stage_multiscale = True

            num_channels = [
                width * get_expansion(block) for width in stage_widths
            ]
            # The transition layer from the previous stage to this stage.
            self.add_module(
                f'transition{stage_idx-1}',
                self._make_transition_layer(pre_num_channels, num_channels))
            self.add_module(
                f'stage{stage_idx}',
                self._make_stage(
                    stage_cfg,
                    num_channels,
                    multiscale_output=stage_multiscale))

            pre_num_channels = num_channels

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Build the layers that adapt one stage's outputs to the next.

        Args:
            num_channels_pre_layer (Sequence[int]): Channels of each branch
                of the previous stage.
            num_channels_cur_layer (Sequence[int]): Channels of each branch
                of the current stage.

        Returns:
            nn.ModuleList: One layer per branch of the current stage.
        """
        pre_branch_count = len(num_channels_pre_layer)

        transition_layers = []
        for i, cur_channels in enumerate(num_channels_cur_layer):
            if i >= pre_branch_count:
                # For new scale branches, add stacked downsample conv
                # blocks. For example, num_branches_pre = 2, for the 4th
                # branch, add stacked two downsample conv blocks.
                stack_depth = i + 1 - pre_branch_count
                stacked = []
                for j in range(stack_depth):
                    in_channels = num_channels_pre_layer[-1]
                    # Only the last conv of the stack changes the width.
                    if j == stack_depth - 1:
                        out_channels = cur_channels
                    else:
                        out_channels = in_channels
                    stacked.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*stacked))
            elif cur_channels != num_channels_pre_layer[i]:
                # For existing scale branches, add a conv block when the
                # channels are not the same.
                transition_layers.append(
                    nn.Sequential(
                        build_conv_layer(
                            self.conv_cfg,
                            num_channels_pre_layer[i],
                            cur_channels,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            bias=False),
                        build_norm_layer(self.norm_cfg, cur_channels)[1],
                        nn.ReLU(inplace=True)))
            else:
                # Same scale and same width: pass the feature through.
                transition_layers.append(nn.Identity())

        return nn.ModuleList(transition_layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Build one stage as a sequence of HRModules.

        Args:
            layer_config (dict): Stage configuration with the keys
                'num_modules', 'num_branches', 'num_blocks', 'num_channels'
                and 'block'.
            in_channels (list[int]): Input channels of each branch.
            multiscale_output (bool): Whether the last module of the stage
                outputs all scales. Defaults to True.

        Returns:
            Sequential: The stacked HRModules.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        # Name of the last norm layer of each supported block type; it is
        # the layer reset to zero when ``zero_init_residual`` is enabled.
        last_norm_names = {BasicBlock: 'norm2', Bottleneck: 'norm3'}
        block_init_cfg = None
        if self.zero_init_residual and block in last_norm_names:
            block_init_cfg = dict(
                type='Constant',
                val=0,
                override=dict(name=last_norm_names[block]))

        hr_modules = []
        for module_idx in range(num_modules):
            # multi_scale_output is only used for the last module
            is_last = module_idx == num_modules - 1
            reset_multiscale_output = not (is_last and not multiscale_output)

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    block_init_cfg=block_init_cfg))

        return Sequential(*hr_modules)

    def forward(self, x):
        """Forward function."""
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        x_list = [self.layer1(x)]

        for stage_idx in range(2, 5):
            # Apply transition. Branches that did not exist in the previous
            # stage are fed from its last (lowest-resolution) output.
            transition = getattr(self, f'transition{stage_idx-1}')
            last = len(x_list) - 1
            inputs = [
                layer(x_list[min(j, last)])
                for j, layer in enumerate(transition)
            ]
            # Forward HRModule
            x_list = getattr(self, f'stage{stage_idx}')(inputs)

        return tuple(x_list)

    def train(self, mode=True):
        """Switch the training mode while keeping the batch norm layers
        frozen when ``norm_eval`` is set."""
        super().train(mode)
        if not (mode and self.norm_eval):
            return
        for m in self.modules():
            # trick: eval have effect on BatchNorm only
            if isinstance(m, _BatchNorm):
                m.eval()

    def parse_arch(self, arch, extra=None):
        """Return the per-stage configuration to build the network from.

        Args:
            arch (str): Key of a preset in ``arch_zoo``. Only used when
                ``extra`` is None.
            extra (dict, optional): Explicit configuration. When given, it
                is returned unchanged. Defaults to None.

        Returns:
            dict: Configuration with the keys 'stage1', 'stage2', ...
        """
        if extra is not None:
            return extra

        assert arch in self.arch_zoo, \
            ('Invalid arch, please choose arch from '
             f'{list(self.arch_zoo.keys())}, or specify `extra` '
             'argument directly.')

        # Field names, in the column order used by ``arch_zoo``.
        fields = ('num_modules', 'num_branches', 'block', 'num_blocks',
                  'num_channels')
        return {
            f'stage{stage_idx}': dict(zip(fields, stage_setting))
            for stage_idx, stage_setting in enumerate(
                self.arch_zoo[arch], start=1)
        }
|
|
@ -5,6 +5,7 @@ import torch.utils.checkpoint as cp
|
||||||
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
|
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
|
||||||
constant_init)
|
constant_init)
|
||||||
from mmcv.cnn.bricks import DropPath
|
from mmcv.cnn.bricks import DropPath
|
||||||
|
from mmcv.runner import BaseModule
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
|
|
||||||
from ..builder import BACKBONES
|
from ..builder import BACKBONES
|
||||||
|
@ -13,7 +14,7 @@ from .base_backbone import BaseBackbone
|
||||||
eps = 1.0e-5
|
eps = 1.0e-5
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock(nn.Module):
|
class BasicBlock(BaseModule):
|
||||||
"""BasicBlock for ResNet.
|
"""BasicBlock for ResNet.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -47,8 +48,9 @@ class BasicBlock(nn.Module):
|
||||||
with_cp=False,
|
with_cp=False,
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
drop_path_rate=0.0):
|
drop_path_rate=0.0,
|
||||||
super(BasicBlock, self).__init__()
|
init_cfg=None):
|
||||||
|
super(BasicBlock, self).__init__(init_cfg=init_cfg)
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.expansion = expansion
|
self.expansion = expansion
|
||||||
|
@ -130,7 +132,7 @@ class BasicBlock(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Bottleneck(nn.Module):
|
class Bottleneck(BaseModule):
|
||||||
"""Bottleneck block for ResNet.
|
"""Bottleneck block for ResNet.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -164,8 +166,9 @@ class Bottleneck(nn.Module):
|
||||||
with_cp=False,
|
with_cp=False,
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=dict(type='BN'),
|
norm_cfg=dict(type='BN'),
|
||||||
drop_path_rate=0.0):
|
drop_path_rate=0.0,
|
||||||
super(Bottleneck, self).__init__()
|
init_cfg=None):
|
||||||
|
super(Bottleneck, self).__init__(init_cfg=init_cfg)
|
||||||
assert style in ['pytorch', 'caffe']
|
assert style in ['pytorch', 'caffe']
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .gap import GlobalAveragePooling
|
from .gap import GlobalAveragePooling
|
||||||
|
from .hr_fuse import HRFuseScales
|
||||||
|
|
||||||
__all__ = ['GlobalAveragePooling']
|
__all__ = ['GlobalAveragePooling', 'HRFuseScales']
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.cnn.bricks import ConvModule
|
||||||
|
from mmcv.runner import BaseModule
|
||||||
|
|
||||||
|
from ..backbones.resnet import Bottleneck, ResLayer
|
||||||
|
from ..builder import NECKS
|
||||||
|
|
||||||
|
|
||||||
|
@NECKS.register_module()
class HRFuseScales(BaseModule):
    """Fuse feature map of multiple scales in HRNet.

    Every input scale is first widened by a bottleneck layer. The widened
    maps are then merged from the first scale to the last one by repeatedly
    downsampling the running feature and adding the next scale. A final
    1x1 conv block maps the result to ``out_channels``.

    Args:
        in_channels (list[int]): The input channels of all scales. Between
            1 and 4 scales are supported.
        out_channels (int): The channels of fused feature map.
            Defaults to 2048.
        norm_cfg (dict): dictionary to construct norm layers.
            Defaults to ``dict(type='BN', momentum=0.1)``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels=2048,
                 norm_cfg=dict(type='BN', momentum=0.1),
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
        super(HRFuseScales, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg

        block_type = Bottleneck
        # Fixed width of each scale after the increase layers.
        stage_channels = [128, 256, 512, 1024]
        num_scales = len(in_channels)
        assert 1 <= num_scales <= len(stage_channels), \
            (f'HRFuseScales supports 1 to {len(stage_channels)} scales, '
             f'but got {num_scales}.')

        # Increase the channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        # NOTE(review): ``norm_cfg`` is not forwarded to ResLayer here, so
        # these layers use ResLayer's own default norm settings - confirm
        # that this is intended.
        increase_layers = []
        for i in range(num_scales):
            increase_layers.append(
                ResLayer(
                    block_type,
                    in_channels=in_channels[i],
                    out_channels=stage_channels[i],
                    num_blocks=1,
                    stride=1,
                ))
        self.increase_layers = nn.ModuleList(increase_layers)

        # Downsample feature maps in each scale.
        downsample_layers = []
        for i in range(num_scales - 1):
            downsample_layers.append(
                ConvModule(
                    in_channels=stage_channels[i],
                    out_channels=stage_channels[i + 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    bias=False,
                ))
        self.downsample_layers = nn.ModuleList(downsample_layers)

        # The final conv block before final classifier linear layer.
        # Its input width is the width of the last fused scale, so that
        # configurations with fewer than 4 scales get a matching layer.
        self.final_layer = ConvModule(
            in_channels=stage_channels[num_scales - 1],
            out_channels=self.out_channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            bias=False,
        )

    def forward(self, x):
        """Fuse the multi-scale features into a single feature map.

        Args:
            x (tuple): One feature map per entry of ``in_channels``.

        Returns:
            tuple: A single-element tuple holding the fused feature map.
        """
        assert isinstance(x, tuple) and len(x) == len(self.in_channels)

        feat = self.increase_layers[0](x[0])
        for i in range(len(self.downsample_layers)):
            # Halve the running feature and add the next, widened scale.
            feat = self.downsample_layers[i](feat) + \
                self.increase_layers[i + 1](x[i + 1])

        return (self.final_layer(feat), )
|
|
@ -19,3 +19,4 @@ Import:
|
||||||
- configs/twins/metafile.yml
|
- configs/twins/metafile.yml
|
||||||
- configs/efficientnet/metafile.yml
|
- configs/efficientnet/metafile.yml
|
||||||
- configs/convnext/metafile.yml
|
- configs/convnext/metafile.yml
|
||||||
|
- configs/hrnet/metafile.yml
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.nn.modules import GroupNorm
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from mmcls.models.backbones import HRNet
|
||||||
|
|
||||||
|
|
||||||
|
def is_norm(modules):
    """Return True if ``modules`` is a GroupNorm or a batch norm layer."""
    return isinstance(modules, (GroupNorm, _BatchNorm))
|
||||||
|
|
||||||
|
|
||||||
|
def check_norm_state(modules, train_state):
    """Return True if every batch norm layer in ``modules`` has its
    ``training`` flag equal to ``train_state``.

    Modules that are not batch norm layers are ignored.
    """
    return all(
        mod.training == train_state for mod in modules
        if isinstance(mod, _BatchNorm))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('base_channels', [18, 30, 32, 40, 44, 48, 64])
def test_hrnet_arch_zoo(base_channels):
    """Check the output shapes of every preset HRNet architecture."""

    cfg_ori = dict(arch=f'w{base_channels}')

    # Test HRNet model with input size of 224
    model = HRNet(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(3, 3, 224, 224)
    outs = model(imgs)
    assert isinstance(outs, tuple)
    # Without this check the loop below would also pass when some (or all)
    # of the four output levels are missing.
    assert len(outs) == 4

    # Each level doubles the channels and halves the spatial size.
    out_channels = base_channels
    out_size = 56
    for out in outs:
        assert out.shape == (3, out_channels, out_size, out_size)
        out_channels = out_channels * 2
        out_size = out_size // 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_hrnet_custom_arch():
    """Check the output shapes of an HRNet built from a custom ``extra``."""

    cfg_ori = dict(
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BOTTLENECK',
                num_blocks=(4, 4, 2),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 3, 4, 4),
                num_channels=(32, 64, 152, 256)),
        ), )

    # Test HRNet model with input size of 224
    model = HRNet(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(3, 3, 224, 224)
    outs = model(imgs)
    out_channels = (32, 64, 152, 256)
    out_size = 56
    assert isinstance(outs, tuple)
    # ``zip`` stops at the shorter sequence, so make sure that no output
    # level is missing before comparing the shapes.
    assert len(outs) == len(out_channels)
    for out, out_channel in zip(outs, out_channels):
        assert out.shape == (3, out_channel, out_size, out_size)
        out_size = out_size // 2
|
|
@ -2,7 +2,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmcls.models.necks import GlobalAveragePooling
|
from mmcls.models.necks import GlobalAveragePooling, HRFuseScales
|
||||||
|
|
||||||
|
|
||||||
def test_gap_neck():
|
def test_gap_neck():
|
||||||
|
@ -37,3 +37,24 @@ def test_gap_neck():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# dim must in [1, 2, 3]
|
# dim must in [1, 2, 3]
|
||||||
GlobalAveragePooling(dim='other')
|
GlobalAveragePooling(dim='other')
|
||||||
|
|
||||||
|
|
||||||
|
def test_hr_fuse_scales():
    """Check the input validation and the output shape of HRFuseScales."""

    in_channels = (18, 32, 64, 128)
    neck = HRFuseScales(in_channels=in_channels, out_channels=1024)

    # One feature map per scale; the spatial size halves at every level,
    # starting from 56.
    inputs = [
        torch.rand(3, channels, 56 // 2**level, 56 // 2**level)
        for level, channels in enumerate(in_channels)
    ]

    # A list of features is rejected, the neck expects a tuple.
    with pytest.raises(AssertionError):
        neck(inputs)

    outs = neck(tuple(inputs))
    assert isinstance(outs, tuple)
    assert len(outs) == 1
    assert outs[0].shape == (3, 1024, 7, 7)
|
||||||
|
|
Loading…
Reference in New Issue