[Feature] Support HRNet and add pre-trained models (#660)

* Support HRNet

* Add HRNet configs

* Fix a bug in backward

* Add configs and update docs.

* Not use bias in conv before batch norm

* Defaults to use `norm_eval=False`

* Add unit tests and support out_channels in HRFuseScales

* Update checkpoint path

* Update docstring.

* Remove incorrect files

* Improve according to comments
Ma Zerun 2022-01-28 10:54:14 +08:00 committed by GitHub
parent dc456a0c2c
commit 5de480ea9e
27 changed files with 1138 additions and 11 deletions

@@ -80,7 +80,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
-- [ ] HRNet
+- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
</details>

@@ -79,7 +79,7 @@ MMClassification is an open-source image classification toolbox based on PyTorch, and is a member of the [O
- [x] [Twins](https://github.com/open-mmlab/mmclassification/tree/master/configs/twins)
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
- [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext)
-- [ ] HRNet
+- [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet)
</details>

@@ -0,0 +1,15 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='HRNet', arch='w18'),
neck=[
dict(type='HRFuseScales', in_channels=(18, 36, 72, 144)),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='LinearClsHead',
in_channels=2048,
num_classes=1000,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
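
A hypothetical snippet showing how a base model config like the one above can be materialized into a module (a sketch, not part of this PR, assuming an mmcls v0.x + mmcv environment; the path mirrors this file's location under `configs/_base_/models/hrnet/`):

```python
from mmcv import Config

from mmcls.models import build_classifier

# Load the base model config and build the classifier from it.
cfg = Config.fromfile('configs/_base_/models/hrnet/hrnet-w18.py')
model = build_classifier(cfg.model)
# -> ImageClassifier: HRNet-W18 backbone + HRFuseScales/GAP necks + linear head
```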

@@ -0,0 +1,15 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='HRNet', arch='w30'),
neck=[
dict(type='HRFuseScales', in_channels=(30, 60, 120, 240)),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='LinearClsHead',
in_channels=2048,
num_classes=1000,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

@@ -0,0 +1,15 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='HRNet', arch='w32'),
neck=[
dict(type='HRFuseScales', in_channels=(32, 64, 128, 256)),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='LinearClsHead',
in_channels=2048,
num_classes=1000,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

@@ -0,0 +1,15 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='HRNet', arch='w40'),
neck=[
dict(type='HRFuseScales', in_channels=(40, 80, 160, 320)),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='LinearClsHead',
in_channels=2048,
num_classes=1000,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

@@ -0,0 +1,15 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='HRNet', arch='w44'),
neck=[
dict(type='HRFuseScales', in_channels=(44, 88, 176, 352)),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='LinearClsHead',
in_channels=2048,
num_classes=1000,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

@@ -0,0 +1,15 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='HRNet', arch='w48'),
neck=[
dict(type='HRFuseScales', in_channels=(48, 96, 192, 384)),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='LinearClsHead',
in_channels=2048,
num_classes=1000,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

@@ -0,0 +1,15 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(type='HRNet', arch='w64'),
neck=[
dict(type='HRFuseScales', in_channels=(64, 128, 256, 512)),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='LinearClsHead',
in_channels=2048,
num_classes=1000,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

@@ -0,0 +1,43 @@
# HRNet

> [Deep High-Resolution Representation Learning for Visual Recognition](https://arxiv.org/abs/1908.07919v2)

<!-- [ALGORITHM] -->

## Abstract
High-resolution representations are essential for position-sensitive vision problems, such as human pose estimation, semantic segmentation, and object detection. Existing state-of-the-art frameworks first encode the input image as a low-resolution representation through a subnetwork that is formed by connecting high-to-low resolution convolutions *in series* (e.g., ResNet, VGGNet), and then recover the high-resolution representation from the encoded low-resolution representation. Instead, our proposed network, named as High-Resolution Network (HRNet), maintains high-resolution representations through the whole process. There are two key characteristics: (i) Connect the high-to-low resolution convolution streams *in parallel*; (ii) Repeatedly exchange the information across resolutions. The benefit is that the resulting representation is semantically richer and spatially more precise. We show the superiority of the proposed HRNet in a wide range of applications, including human pose estimation, semantic segmentation, and object detection, suggesting that the HRNet is a stronger backbone for computer vision problems.
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/149920446-cbe05670-989d-4fe6-accc-df20ae2984eb.png" width="100%"/>
</div>
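
As a quick illustration of the parallel streams described above, here is a minimal sketch (assuming the `HRNet` backbone added by this PR is importable from `mmcls.models`) that prints the four output scales of HRNet-W18 for a 224x224 input:

```python
import torch

from mmcls.models import HRNet

model = HRNet(arch='w18')
model.eval()
with torch.no_grad():
    outs = model(torch.rand(1, 3, 224, 224))
for out in outs:
    print(tuple(out.shape))
# (1, 18, 56, 56)
# (1, 36, 28, 28)
# (1, 72, 14, 14)
# (1, 144, 7, 7)
```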
## Results and models

### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| HRNet-W18\* | 21.30 | 4.33 | 76.75 | 93.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32_in1k_20220120-0c10b180.pth) |
| HRNet-W30\* | 37.71 | 8.17 | 78.19 | 94.22 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w30_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w30_3rdparty_8xb32_in1k_20220120-8aa3832f.pth) |
| HRNet-W32\* | 41.23 | 8.99 | 78.44 | 94.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w32_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w32_3rdparty_8xb32_in1k_20220120-c394f1ab.pth) |
| HRNet-W40\* | 57.55 | 12.77 | 78.94 | 94.47 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w40_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w40_3rdparty_8xb32_in1k_20220120-9a2dbfc5.pth) |
| HRNet-W44\* | 67.06 | 14.96 | 78.88 | 94.37 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w44_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w44_3rdparty_8xb32_in1k_20220120-35d07f73.pth) |
| HRNet-W48\* | 77.47 | 17.36 | 79.32 | 94.52 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32_in1k_20220120-e555ef50.pth) |
| HRNet-W64\* | 128.06 | 29.00 | 79.46 | 94.65 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w64_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w64_3rdparty_8xb32_in1k_20220120-19126642.pth) |
| HRNet-W18 (ssld)\* | 21.30 | 4.33 | 81.06 | 95.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32-ssld_in1k_20220120-455f69ea.pth) |
| HRNet-W48 (ssld)\* | 77.47 | 17.36 | 83.63 | 96.79 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32-ssld_in1k_20220120-d0459c38.pth) |

*Models with \* are converted from the [official repo](https://github.com/HRNet/HRNet-Image-Classification). The config files of these models are only for inference. We haven't reproduced the training accuracy with these configs, and you are welcome to contribute your reproduction results.*
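
To try one of these checkpoints directly, something like the following should work (a sketch using the mmcls v0.x inference API; `demo/demo.JPEG` is the sample image shipped in the repo):

```python
from mmcls.apis import inference_model, init_model

config = 'configs/hrnet/hrnet-w18_4xb32_in1k.py'
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32_in1k_20220120-0c10b180.pth'

# Build the model, load the converted weights and classify one image.
model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')
print(result['pred_class'], result['pred_score'])
```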
## Citation
```
@article{WangSCJDZLMTWLX19,
title={Deep High-Resolution Representation Learning for Visual Recognition},
author={Jingdong Wang and Ke Sun and Tianheng Cheng and
Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and
Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao},
  journal={TPAMI},
  year={2019}
}
```

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/hrnet/hrnet-w18.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]
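
Such a top-level config only composes `_base_` files; a small sketch (assuming the usual mmcls repo layout and an mmcv install) of inspecting the merged result:

```python
from mmcv import Config

# Merge the _base_ model/dataset/schedule/runtime files listed above.
cfg = Config.fromfile('configs/hrnet/hrnet-w18_4xb32_in1k.py')
print(cfg.model.backbone)        # {'type': 'HRNet', 'arch': 'w18'}
print(cfg.data.samples_per_gpu)  # 32 per GPU, hence the "b32" in the name
```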

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/hrnet/hrnet-w30.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/hrnet/hrnet-w32.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/hrnet/hrnet-w40.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/hrnet/hrnet-w44.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/hrnet/hrnet-w48.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]

@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/hrnet/hrnet-w64.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/schedules/imagenet_bs256_coslr.py',
'../_base_/default_runtime.py'
]

@@ -0,0 +1,162 @@
Collections:
- Name: HRNet
Metadata:
Training Data: ImageNet-1k
Architecture:
- Batch Normalization
- Convolution
- ReLU
- Residual Connection
Paper:
URL: https://arxiv.org/abs/1908.07919v2
Title: "Deep High-Resolution Representation Learning for Visual Recognition"
README: configs/hrnet/README.md
Code:
URL: https://github.com/open-mmlab/mmclassification/blob/v0.20.0/mmcls/models/backbones/hrnet.py
Version: v0.20.0
Models:
- Name: hrnet-w18_3rdparty_8xb32_in1k
Metadata:
FLOPs: 4330397932
Parameters: 21295164
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 76.75
Top 5 Accuracy: 93.44
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32_in1k_20220120-0c10b180.pth
Config: configs/hrnet/hrnet-w18_4xb32_in1k.py
Converted From:
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33cMkPimlmClRvmpw
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w30_3rdparty_8xb32_in1k
Metadata:
FLOPs: 8168305684
Parameters: 37708380
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.19
Top 5 Accuracy: 94.22
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w30_3rdparty_8xb32_in1k_20220120-8aa3832f.pth
Config: configs/hrnet/hrnet-w30_4xb32_in1k.py
Converted From:
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33cQoACCEfrzcSaVI
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w32_3rdparty_8xb32_in1k
Metadata:
FLOPs: 8986267584
Parameters: 41228840
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.44
Top 5 Accuracy: 94.19
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w32_3rdparty_8xb32_in1k_20220120-c394f1ab.pth
Config: configs/hrnet/hrnet-w32_4xb32_in1k.py
Converted From:
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33dYBMemi9xOUFR0w
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w40_3rdparty_8xb32_in1k
Metadata:
FLOPs: 12767574064
Parameters: 57553320
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.94
Top 5 Accuracy: 94.47
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w40_3rdparty_8xb32_in1k_20220120-9a2dbfc5.pth
Config: configs/hrnet/hrnet-w40_4xb32_in1k.py
Converted From:
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33ck0gvo5jfoWBOPo
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w44_3rdparty_8xb32_in1k
Metadata:
FLOPs: 14963902632
Parameters: 67061144
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.88
Top 5 Accuracy: 94.37
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w44_3rdparty_8xb32_in1k_20220120-35d07f73.pth
Config: configs/hrnet/hrnet-w44_4xb32_in1k.py
Converted From:
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33czZQ0woUb980gRs
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w48_3rdparty_8xb32_in1k
Metadata:
FLOPs: 17364014752
Parameters: 77466024
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.32
Top 5 Accuracy: 94.52
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32_in1k_20220120-e555ef50.pth
Config: configs/hrnet/hrnet-w48_4xb32_in1k.py
Converted From:
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33dKvqI6pBZlifgJk
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w64_3rdparty_8xb32_in1k
Metadata:
FLOPs: 29002298752
Parameters: 128056104
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.46
Top 5 Accuracy: 94.65
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w64_3rdparty_8xb32_in1k_20220120-19126642.pth
Config: configs/hrnet/hrnet-w64_4xb32_in1k.py
Converted From:
Weights: https://1drv.ms/u/s!Aus8VCZ_C_33gQbJsUPTIj3rQu99
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w18_3rdparty_8xb32-ssld_in1k
Metadata:
FLOPs: 4330397932
Parameters: 21295164
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.06
Top 5 Accuracy: 95.7
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32-ssld_in1k_20220120-455f69ea.pth
Config: configs/hrnet/hrnet-w18_4xb32_in1k.py
Converted From:
Weights: https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W18_C_ssld_pretrained.pth
Code: https://github.com/HRNet/HRNet-Image-Classification
- Name: hrnet-w48_3rdparty_8xb32-ssld_in1k
Metadata:
FLOPs: 17364014752
Parameters: 77466024
In Collection: HRNet
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.63
Top 5 Accuracy: 96.79
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32-ssld_in1k_20220120-d0459c38.pth
Config: configs/hrnet/hrnet-w48_4xb32_in1k.py
Converted From:
Weights: https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W48_C_ssld_pretrained.pth
Code: https://github.com/HRNet/HRNet-Image-Classification

@@ -119,6 +119,15 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| ConvNeXt-L\* | 197.77 | 34.37 | 84.30 | 96.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_3rdparty_64xb64_in1k_20220124-f8a0ded0.pth) |
| ConvNeXt-L\* | 197.77 | 34.37 | 86.61 | 98.04 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-large_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-large_in21k-pre-3rdparty_64xb64_in1k_20220124-2412403d.pth) |
| ConvNeXt-XL\* | 350.20 | 60.93 | 86.97 | 98.20 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/convnext/convnext-xlarge_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/convnext/convnext-xlarge_in21k-pre-3rdparty_64xb64_in1k_20220124-76b6863d.pth) |
| HRNet-W18\* | 21.30 | 4.33 | 76.75 | 93.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32_in1k_20220120-0c10b180.pth) |
| HRNet-W30\* | 37.71 | 8.17 | 78.19 | 94.22 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w30_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w30_3rdparty_8xb32_in1k_20220120-8aa3832f.pth) |
| HRNet-W32\* | 41.23 | 8.99 | 78.44 | 94.19 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w32_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w32_3rdparty_8xb32_in1k_20220120-c394f1ab.pth) |
| HRNet-W40\* | 57.55 | 12.77 | 78.94 | 94.47 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w40_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w40_3rdparty_8xb32_in1k_20220120-9a2dbfc5.pth) |
| HRNet-W44\* | 67.06 | 14.96 | 78.88 | 94.37 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w44_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w44_3rdparty_8xb32_in1k_20220120-35d07f73.pth) |
| HRNet-W48\* | 77.47 | 17.36 | 79.32 | 94.52 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32_in1k_20220120-e555ef50.pth) |
| HRNet-W64\* | 128.06 | 29.00 | 79.46 | 94.65 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w64_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w64_3rdparty_8xb32_in1k_20220120-19126642.pth) |
| HRNet-W18 (ssld)\* | 21.30 | 4.33 | 81.06 | 95.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w18_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w18_3rdparty_8xb32-ssld_in1k_20220120-455f69ea.pth) |
| HRNet-W48 (ssld)\* | 77.47 | 17.36 | 83.63 | 96.79 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/hrnet/hrnet-w48_4xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/hrnet/hrnet-w48_3rdparty_8xb32-ssld_in1k_20220120-d0459c38.pth) |
*Models with \* are converted from other repos, while the others are trained by ourselves.*

@@ -4,6 +4,7 @@ from .conformer import Conformer
from .convnext import ConvNeXt
from .deit import DistilledVisionTransformer
from .efficientnet import EfficientNet
+from .hrnet import HRNet
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
@@ -33,5 +34,5 @@ __all__ = [
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
-    'EfficientNet', 'ConvNeXt'
+    'EfficientNet', 'ConvNeXt', 'HRNet'
]

@@ -0,0 +1,563 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck, ResLayer, get_expansion
class HRModule(BaseModule):
"""High-Resolution Module for HRNet.
    In this module, every branch is a sequence of BasicBlocks/Bottlenecks, and
    the features of all branches are fused/exchanged at the end of the module.
Args:
num_branches (int): The number of branches.
block (``BaseModule``): Convolution block module.
num_blocks (tuple): The number of blocks in each branch.
The length must be equal to ``num_branches``.
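        in_channels (tuple): The number of input channels in each branch.
            The length must be equal to ``num_branches``.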
num_channels (tuple): The number of base channels in each branch.
The length must be equal to ``num_branches``.
multiscale_output (bool): Whether to output multi-level features
produced by multiple branches. If False, only the first level
feature will be output. Defaults to True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
conv_cfg (dict, optional): Dictionary to construct and config conv
layer. Defaults to None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to ``dict(type='BN')``.
        block_init_cfg (dict, optional): The initialization config of each
            block. Defaults to None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
num_branches,
block,
num_blocks,
in_channels,
num_channels,
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
block_init_cfg=None,
init_cfg=None):
super(HRModule, self).__init__(init_cfg)
self.block_init_cfg = block_init_cfg
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)
self.in_channels = in_channels
self.num_branches = num_branches
self.multiscale_output = multiscale_output
self.norm_cfg = norm_cfg
self.conv_cfg = conv_cfg
self.with_cp = with_cp
self.branches = self._make_branches(num_branches, block, num_blocks,
num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(inplace=False)
def _check_branches(self, num_branches, num_blocks, in_channels,
num_channels):
if num_branches != len(num_blocks):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_BLOCKS({len(num_blocks)})'
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_CHANNELS({len(num_channels)})'
raise ValueError(error_msg)
if num_branches != len(in_channels):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_INCHANNELS({len(in_channels)})'
raise ValueError(error_msg)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
out_channels = num_channels[i] * get_expansion(block)
branches.append(
ResLayer(
block=block,
num_blocks=num_blocks[i],
in_channels=self.in_channels[i],
out_channels=out_channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
with_cp=self.with_cp,
init_cfg=self.block_init_cfg,
))
return ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
in_channels = self.in_channels
fuse_layers = []
num_out_branches = num_branches if self.multiscale_output else 1
for i in range(num_out_branches):
fuse_layer = []
for j in range(num_branches):
if j > i:
# Upsample the feature maps of smaller scales.
fuse_layer.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=1,
stride=1,
padding=0,
bias=False),
build_norm_layer(self.norm_cfg, in_channels[i])[1],
nn.Upsample(
scale_factor=2**(j - i), mode='nearest')))
elif j == i:
# Keep the feature map with the same scale.
fuse_layer.append(None)
else:
# Downsample the feature maps of larger scales.
conv_downsamples = []
for k in range(i - j):
# Use stacked convolution layers to downsample.
if k == i - j - 1:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[i])[1]))
else:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[j],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[j])[1],
nn.ReLU(inplace=False)))
fuse_layer.append(nn.Sequential(*conv_downsamples))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def forward(self, x):
"""Forward function."""
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = 0
for j in range(self.num_branches):
if i == j:
y += x[j]
else:
y += self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
@BACKBONES.register_module()
class HRNet(BaseModule):
"""HRNet backbone.
`High-Resolution Representations for Labeling Pixels and Regions
<https://arxiv.org/abs/1904.04514>`_.
Args:
arch (str): The preset HRNet architecture, includes 'w18', 'w30',
'w32', 'w40', 'w44', 'w48', 'w64'. It will only be used if
extra is ``None``. Defaults to 'w32'.
extra (dict, optional): Detailed configuration for each stage of HRNet.
There must be 4 stages, the configuration for each stage must have
5 keys:
- num_modules (int): The number of HRModule in this stage.
- num_branches (int): The number of branches in the HRModule.
- block (str): The type of convolution block. Please choose between
'BOTTLENECK' and 'BASIC'.
- num_blocks (tuple): The number of blocks in each branch.
The length must be equal to num_branches.
- num_channels (tuple): The number of base channels in each branch.
The length must be equal to num_branches.
Defaults to None.
in_channels (int): Number of input image channels. Defaults to 3.
conv_cfg (dict, optional): Dictionary to construct and config conv
layer. Defaults to None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to ``dict(type='BN')``.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Defaults to False.
multiscale_output (bool): Whether to output multi-level features
produced by multiple branches. If False, only the first level
feature will be output. Defaults to True.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
Example:
>>> import torch
>>> from mmcls.models import HRNet
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,
>>> num_branches=1,
>>> block='BOTTLENECK',
>>> num_blocks=(4, ),
>>> num_channels=(64, )),
>>> stage2=dict(
>>> num_modules=1,
>>> num_branches=2,
>>> block='BASIC',
>>> num_blocks=(4, 4),
>>> num_channels=(32, 64)),
>>> stage3=dict(
>>> num_modules=4,
>>> num_branches=3,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4),
>>> num_channels=(32, 64, 128)),
>>> stage4=dict(
>>> num_modules=3,
>>> num_branches=4,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4, 4),
>>> num_channels=(32, 64, 128, 256)))
>>> self = HRNet(extra, in_channels=1)
>>> self.eval()
>>> inputs = torch.rand(1, 1, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 32, 8, 8)
(1, 64, 4, 4)
(1, 128, 2, 2)
(1, 256, 1, 1)
"""
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
arch_zoo = {
# num_modules, num_branches, block, num_blocks, num_channels
'w18': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
[1, 2, 'BASIC', (4, 4), (18, 36)],
[4, 3, 'BASIC', (4, 4, 4), (18, 36, 72)],
[3, 4, 'BASIC', (4, 4, 4, 4), (18, 36, 72, 144)]],
'w30': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
[1, 2, 'BASIC', (4, 4), (30, 60)],
[4, 3, 'BASIC', (4, 4, 4), (30, 60, 120)],
[3, 4, 'BASIC', (4, 4, 4, 4), (30, 60, 120, 240)]],
'w32': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
[1, 2, 'BASIC', (4, 4), (32, 64)],
[4, 3, 'BASIC', (4, 4, 4), (32, 64, 128)],
[3, 4, 'BASIC', (4, 4, 4, 4), (32, 64, 128, 256)]],
'w40': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
[1, 2, 'BASIC', (4, 4), (40, 80)],
[4, 3, 'BASIC', (4, 4, 4), (40, 80, 160)],
[3, 4, 'BASIC', (4, 4, 4, 4), (40, 80, 160, 320)]],
'w44': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
[1, 2, 'BASIC', (4, 4), (44, 88)],
[4, 3, 'BASIC', (4, 4, 4), (44, 88, 176)],
[3, 4, 'BASIC', (4, 4, 4, 4), (44, 88, 176, 352)]],
'w48': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
[1, 2, 'BASIC', (4, 4), (48, 96)],
[4, 3, 'BASIC', (4, 4, 4), (48, 96, 192)],
[3, 4, 'BASIC', (4, 4, 4, 4), (48, 96, 192, 384)]],
'w64': [[1, 1, 'BOTTLENECK', (4, ), (64, )],
[1, 2, 'BASIC', (4, 4), (64, 128)],
[4, 3, 'BASIC', (4, 4, 4), (64, 128, 256)],
[3, 4, 'BASIC', (4, 4, 4, 4), (64, 128, 256, 512)]],
} # yapf:disable
def __init__(self,
arch='w32',
extra=None,
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN'),
norm_eval=False,
with_cp=False,
zero_init_residual=False,
multiscale_output=True,
init_cfg=[
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
super(HRNet, self).__init__(init_cfg)
extra = self.parse_arch(arch, extra)
# Assert configurations of 4 stages are in extra
for i in range(1, 5):
assert f'stage{i}' in extra, f'Missing stage{i} config in "extra".'
# Assert whether the length of `num_blocks` and `num_channels` are
# equal to `num_branches`
cfg = extra[f'stage{i}']
assert len(cfg['num_blocks']) == cfg['num_branches'] and \
len(cfg['num_channels']) == cfg['num_branches']
self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
# -------------------- stem net --------------------
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
out_channels=64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
in_channels=64,
out_channels=64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
# -------------------- stage 1 --------------------
self.stage1_cfg = self.extra['stage1']
base_channels = self.stage1_cfg['num_channels']
block_type = self.stage1_cfg['block']
num_blocks = self.stage1_cfg['num_blocks']
block = self.blocks_dict[block_type]
num_channels = [
channel * get_expansion(block) for channel in base_channels
]
# To align with the original code, use layer1 instead of stage1 here.
self.layer1 = ResLayer(
block,
in_channels=64,
out_channels=num_channels[0],
num_blocks=num_blocks[0])
pre_num_channels = num_channels
# -------------------- stage 2~4 --------------------
for i in range(2, 5):
stage_cfg = self.extra[f'stage{i}']
base_channels = stage_cfg['num_channels']
block = self.blocks_dict[stage_cfg['block']]
multiscale_output_ = multiscale_output if i == 4 else True
num_channels = [
channel * get_expansion(block) for channel in base_channels
]
            # The transition layer between the previous stage and this stage.
transition = self._make_transition_layer(pre_num_channels,
num_channels)
self.add_module(f'transition{i-1}', transition)
stage = self._make_stage(
stage_cfg, num_channels, multiscale_output=multiscale_output_)
self.add_module(f'stage{i}', stage)
pre_num_channels = num_channels
@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
return getattr(self, self.norm1_name)
@property
def norm2(self):
"""nn.Module: the normalization layer named "norm2" """
return getattr(self, self.norm2_name)
def _make_transition_layer(self, num_channels_pre_layer,
num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
# For existing scale branches,
# add conv block when the channels are not the same.
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
num_channels_pre_layer[i],
num_channels_cur_layer[i],
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
num_channels_cur_layer[i])[1],
nn.ReLU(inplace=True)))
else:
transition_layers.append(nn.Identity())
else:
                # For new scale branches, add stacked downsample conv blocks.
                # For example, if num_branches_pre is 2, then for the 4th
                # branch, two downsample conv blocks are stacked.
conv_downsamples = []
for j in range(i + 1 - num_branches_pre):
in_channels = num_channels_pre_layer[-1]
out_channels = num_channels_cur_layer[i] \
if j == i - num_branches_pre else in_channels
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, out_channels)[1],
nn.ReLU(inplace=True)))
transition_layers.append(nn.Sequential(*conv_downsamples))
return nn.ModuleList(transition_layers)
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
num_modules = layer_config['num_modules']
num_branches = layer_config['num_branches']
num_blocks = layer_config['num_blocks']
num_channels = layer_config['num_channels']
block = self.blocks_dict[layer_config['block']]
hr_modules = []
block_init_cfg = None
if self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))
for i in range(num_modules):
# multi_scale_output is only used for the last module
if not multiscale_output and i == num_modules - 1:
reset_multiscale_output = False
else:
reset_multiscale_output = True
hr_modules.append(
HRModule(
num_branches,
block,
num_blocks,
in_channels,
num_channels,
reset_multiscale_output,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg,
block_init_cfg=block_init_cfg))
return Sequential(*hr_modules)
def forward(self, x):
"""Forward function."""
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = [x]
for i in range(2, 5):
# Apply transition
transition = getattr(self, f'transition{i-1}')
inputs = []
for j, layer in enumerate(transition):
if j < len(x_list):
inputs.append(layer(x_list[j]))
else:
inputs.append(layer(x_list[-1]))
# Forward HRModule
stage = getattr(self, f'stage{i}')
x_list = stage(inputs)
return tuple(x_list)
def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization
layer freezed."""
super(HRNet, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval() has an effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def parse_arch(self, arch, extra=None):
if extra is not None:
return extra
assert arch in self.arch_zoo, \
('Invalid arch, please choose arch from '
f'{list(self.arch_zoo.keys())}, or specify `extra` '
'argument directly.')
extra = dict()
for i, stage_setting in enumerate(self.arch_zoo[arch], start=1):
extra[f'stage{i}'] = dict(
num_modules=stage_setting[0],
num_branches=stage_setting[1],
block=stage_setting[2],
num_blocks=stage_setting[3],
num_channels=stage_setting[4],
)
return extra

@@ -5,6 +5,7 @@ import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init)
from mmcv.cnn.bricks import DropPath
+from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
@@ -13,7 +14,7 @@ from .base_backbone import BaseBackbone
eps = 1.0e-5
-class BasicBlock(nn.Module):
+class BasicBlock(BaseModule):
"""BasicBlock for ResNet.
Args:
@@ -47,8 +48,9 @@ class BasicBlock(nn.Module):
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
-                 drop_path_rate=0.0):
-        super(BasicBlock, self).__init__()
+                 drop_path_rate=0.0,
+                 init_cfg=None):
+        super(BasicBlock, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.expansion = expansion
@@ -130,7 +132,7 @@ class BasicBlock(nn.Module):
return out
-class Bottleneck(nn.Module):
+class Bottleneck(BaseModule):
"""Bottleneck block for ResNet.
Args:
@@ -164,8 +166,9 @@ class Bottleneck(nn.Module):
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
-                 drop_path_rate=0.0):
-        super(Bottleneck, self).__init__()
+                 drop_path_rate=0.0,
+                 init_cfg=None):
+        super(Bottleneck, self).__init__(init_cfg=init_cfg)
assert style in ['pytorch', 'caffe']
self.in_channels = in_channels

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .gap import GlobalAveragePooling
+from .hr_fuse import HRFuseScales
-__all__ = ['GlobalAveragePooling']
+__all__ = ['GlobalAveragePooling', 'HRFuseScales']

@@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn.bricks import ConvModule
from mmcv.runner import BaseModule
from ..backbones.resnet import Bottleneck, ResLayer
from ..builder import NECKS
@NECKS.register_module()
class HRFuseScales(BaseModule):
"""Fuse feature map of multiple scales in HRNet.
Args:
in_channels (list[int]): The input channels of all scales.
out_channels (int): The channels of fused feature map.
Defaults to 2048.
        norm_cfg (dict): Dictionary to construct and config norm layers.
            Defaults to ``dict(type='BN', momentum=0.1)``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.
"""
def __init__(self,
in_channels,
out_channels=2048,
norm_cfg=dict(type='BN', momentum=0.1),
init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
super(HRFuseScales, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.norm_cfg = norm_cfg
block_type = Bottleneck
out_channels = [128, 256, 512, 1024]
# Increase the channels on each resolution
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
increase_layers = []
for i in range(len(in_channels)):
increase_layers.append(
ResLayer(
block_type,
in_channels=in_channels[i],
out_channels=out_channels[i],
num_blocks=1,
stride=1,
))
self.increase_layers = nn.ModuleList(increase_layers)
# Downsample feature maps in each scale.
downsample_layers = []
for i in range(len(in_channels) - 1):
downsample_layers.append(
ConvModule(
in_channels=out_channels[i],
out_channels=out_channels[i + 1],
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
bias=False,
))
self.downsample_layers = nn.ModuleList(downsample_layers)
        # The final conv block before the classifier linear layer.
self.final_layer = ConvModule(
in_channels=out_channels[3],
out_channels=self.out_channels,
kernel_size=1,
norm_cfg=self.norm_cfg,
bias=False,
)
def forward(self, x):
assert isinstance(x, tuple) and len(x) == len(self.in_channels)
feat = self.increase_layers[0](x[0])
for i in range(len(self.downsample_layers)):
feat = self.downsample_layers[i](feat) + \
self.increase_layers[i + 1](x[i + 1])
return (self.final_layer(feat), )
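
A quick shape check for this neck (a sketch, not part of the PR; the input sizes below match HRNet-W18 outputs for a 224x224 image):

```python
import torch

from mmcls.models.necks import HRFuseScales

neck = HRFuseScales(in_channels=(18, 36, 72, 144))
neck.eval()
# Four scales: 56x56, 28x28, 14x14 and 7x7 feature maps.
feats = tuple(
    torch.rand(1, c, 56 // 2**i, 56 // 2**i)
    for i, c in enumerate((18, 36, 72, 144)))
out = neck(feats)
print(out[0].shape)  # torch.Size([1, 2048, 7, 7])
```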

@@ -19,3 +19,4 @@ Import:
- configs/twins/metafile.yml
- configs/efficientnet/metafile.yml
- configs/convnext/metafile.yml
+- configs/hrnet/metafile.yml

@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import HRNet
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
@pytest.mark.parametrize('base_channels', [18, 30, 32, 40, 44, 48, 64])
def test_hrnet_arch_zoo(base_channels):
cfg_ori = dict(arch=f'w{base_channels}')
# Test HRNet model with input size of 224
model = HRNet(**cfg_ori)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(3, 3, 224, 224)
outs = model(imgs)
out_channels = base_channels
out_size = 56
assert isinstance(outs, tuple)
for out in outs:
assert out.shape == (3, out_channels, out_size, out_size)
out_channels = out_channels * 2
out_size = out_size // 2
def test_hrnet_custom_arch():
cfg_ori = dict(
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BOTTLENECK',
num_blocks=(4, 4, 2),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 3, 4, 4),
num_channels=(32, 64, 152, 256)),
), )
# Test HRNet model with input size of 224
model = HRNet(**cfg_ori)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(3, 3, 224, 224)
outs = model(imgs)
out_channels = (32, 64, 152, 256)
out_size = 56
assert isinstance(outs, tuple)
for out, out_channel in zip(outs, out_channels):
assert out.shape == (3, out_channel, out_size, out_size)
out_size = out_size // 2

@@ -2,7 +2,7 @@
import pytest
import torch
-from mmcls.models.necks import GlobalAveragePooling
+from mmcls.models.necks import GlobalAveragePooling, HRFuseScales
def test_gap_neck():
@@ -37,3 +37,24 @@ def test_gap_neck():
with pytest.raises(AssertionError):
# dim must in [1, 2, 3]
GlobalAveragePooling(dim='other')
def test_hr_fuse_scales():
in_channels = (18, 32, 64, 128)
neck = HRFuseScales(in_channels=in_channels, out_channels=1024)
feat_size = 56
inputs = []
for in_channel in in_channels:
input_tensor = torch.rand(3, in_channel, feat_size, feat_size)
inputs.append(input_tensor)
feat_size = feat_size // 2
with pytest.raises(AssertionError):
neck(inputs)
outs = neck(tuple(inputs))
assert isinstance(outs, tuple)
assert len(outs) == 1
assert outs[0].shape == (3, 1024, 7, 7)