commit
5aba5f898e
|
@ -0,0 +1,87 @@
|
|||
# 特征图可视化指南
|
||||
|
||||
## 一、概述
|
||||
|
||||
特征图是输入图片在卷积网络中的特征表达,对特征图的研究可以有利于我们对于模型的理解与设计,所以基于动态图我们使用本工具来可视化特征图。
|
||||
|
||||
## 二、准备工作
|
||||
|
||||
首先需要选定研究的模型,本文设定ResNet50作为研究模型,将模型组网代码[resnet.py](../../../ppcls/arch/backbone/legendary_models/resnet.py)拷贝到[目录](../../../ppcls/utils/feature_maps_visualization/)下,并下载[ResNet50预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams),或使用以下命令下载。
|
||||
|
||||
```bash
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams
|
||||
```
|
||||
|
||||
其他模型网络结构代码及预训练模型请自行下载:[模型库](../../../ppcls/arch/backbone/),[预训练模型](../models/models_intro.md)。
|
||||
|
||||
## 三、修改模型
|
||||
|
||||
找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。
|
||||
|
||||
在ResNet50的forward函数中指定要可视化的特征图
|
||||
|
||||
```python
|
||||
def forward(self, x):
|
||||
with paddle.static.amp.fp16_guard():
|
||||
if self.data_format == "NHWC":
|
||||
x = paddle.transpose(x, [0, 2, 3, 1])
|
||||
x.stop_gradient = True
|
||||
x = self.stem(x)
|
||||
fm = x
|
||||
x = self.max_pool(x)
|
||||
x = self.blocks(x)
|
||||
x = self.avg_pool(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x, fm
|
||||
```
|
||||
|
||||
然后修改代码[fm_vis.py](../../../ppcls/utils/feature_maps_visualization/fm_vis.py),引入 `ResNet50`,实例化 `net` 对象:
|
||||
|
||||
```python
|
||||
from resnet import ResNet50
|
||||
net = ResNet50()
|
||||
```
|
||||
|
||||
最后执行函数
|
||||
|
||||
```bash
|
||||
python tools/feature_maps_visualization/fm_vis.py \
|
||||
-i the image you want to test \
|
||||
-c channel_num -p pretrained model \
|
||||
--show whether to show \
|
||||
--interpolation interpolation method\
|
||||
--save_path where to save \
|
||||
--use_gpu whether to use gpu
|
||||
```
|
||||
|
||||
参数说明:
|
||||
+ `-i`:待预测的图片文件路径,如 `./test.jpeg`
|
||||
+ `-c`:特征图维度,如 `5`
|
||||
+ `-p`:权重文件路径,如 `./ResNet50_pretrained`
|
||||
+ `--interpolation`: 图像插值方式, 默认值 1
|
||||
+ `--save_path`:保存路径,如:`./tools/`
|
||||
+ `--use_gpu`:是否使用 GPU 预测,默认值:True
|
||||
|
||||
## 四、结果
|
||||
|
||||
* 输入图片:
|
||||
|
||||

|
||||
|
||||
* 运行下面的特征图可视化脚本
|
||||
|
||||
```
|
||||
python tools/feature_maps_visualization/fm_vis.py \
|
||||
-i ./docs/images/feature_maps/feature_visualization_input.jpg \
|
||||
-c 5 \
|
||||
-p pretrained/ResNet50_pretrained/ \
|
||||
--show=True \
|
||||
--interpolation=1 \
|
||||
--save_path="./output.png" \
|
||||
--use_gpu=False
|
||||
```
|
||||
|
||||
* 输出特征图保存为`output.png`,如下所示。
|
||||
|
||||

|
|
@ -0,0 +1,62 @@
|
|||
# 图像分类昆仑模型介绍(持续更新中)
|
||||
|
||||
## 前言
|
||||
|
||||
* 本文档介绍了目前昆仑支持的模型以及如何在昆仑设备上训练这些模型。支持昆仑的PaddlePaddle安装参考install_kunlun(https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/paddle/install/install_Kunlun_zh.md)
|
||||
|
||||
## 昆仑训练
|
||||
* 数据来源和预训练模型参考[quick_start](../quick_start/quick_start_classification_new_user.md)。昆仑训练效果与CPU/GPU对齐。
|
||||
|
||||
### ResNet50
|
||||
* 命令:
|
||||
|
||||
```shell
|
||||
python3.7 ppcls/static/train.py \
|
||||
-c ppcls/configs/quick_start/kunlun/ResNet50_vd_finetune_kunlun.yaml \
|
||||
-o use_gpu=False \
|
||||
-o use_xpu=True \
|
||||
-o is_distributed=False
|
||||
```
|
||||
|
||||
与cpu/gpu训练的区别是加上-o use_xpu=True, 表示执行在昆仑设备上。
|
||||
|
||||
### MobileNetV3
|
||||
* 命令:
|
||||
|
||||
```shell
|
||||
python3.7 ppcls/static/train.py \
|
||||
-c ppcls/configs/quick_start/MobileNetV3_large_x1_0.yaml \
|
||||
-o use_gpu=False \
|
||||
-o use_xpu=True \
|
||||
-o is_distributed=False
|
||||
```
|
||||
|
||||
### HRNet
|
||||
* 命令:
|
||||
|
||||
```shell
|
||||
python3.7 ppcls/static/train.py \
|
||||
-c ppcls/configs/quick_start/kunlun/HRNet_W18_C_finetune_kunlun.yaml \
|
||||
-o is_distributed=False \
|
||||
-o use_xpu=True \
|
||||
-o use_gpu=False
|
||||
```
|
||||
|
||||
|
||||
### VGG16/19
|
||||
* 命令:
|
||||
|
||||
```shell
|
||||
python3.7 ppcls/static/train.py \
|
||||
-c ppcls/configs/quick_start/VGG16_finetune_kunlun.yaml \
|
||||
-o use_gpu=False \
|
||||
-o use_xpu=True \
|
||||
-o is_distributed=False
|
||||
```
|
||||
```shell
|
||||
python3.7 ppcls/static/train.py \
|
||||
-c ppcls/configs/quick_start/VGG19_finetune_kunlun.yaml \
|
||||
-o use_gpu=False \
|
||||
-o use_xpu=True \
|
||||
-o is_distributed=False
|
||||
```
|
|
@ -0,0 +1,58 @@
|
|||
# 使用DALI加速训练
|
||||
|
||||
## 前言
|
||||
[NVIDIA数据加载库](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html)(The NVIDIA Data Loading Library,DALI)是用于数据加载和预处理的开源库,用于加速深度学习训练、推理过程,它可以直接构建飞桨Paddle的DataLoader数据读取器。
|
||||
|
||||
由于深度学习程序在训练阶段依赖大量数据,这些数据需要经过加载、预处理等操作后,才能送入训练程序,而这些操作通常在CPU完成,因此限制了训练速度进一步提高,特别是在batch_size较大时,数据读取可能成为训练速度的瓶颈。DALI可以基于GPU的高并行特性实现数据加载及预处理操作,可以进一步提高训练速度。
|
||||
|
||||
## 安装DALI
|
||||
目前DALI仅支持Linux x64平台,且CUDA版本大于等于10.2。
|
||||
|
||||
* 对于CUDA 10:
|
||||
|
||||
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100
|
||||
|
||||
* 对于CUDA 11.0:
|
||||
|
||||
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110
|
||||
|
||||
关于更多DALI安装的信息,可以参考[DALI官方](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html)。
|
||||
|
||||
## 使用DALI
|
||||
PaddleClas支持在静态图训练方式中使用DALI加速,由于DALI仅支持GPU训练,因此需要设置GPU,且DALI需要占用GPU显存,需要为DALI预留显存。使用DALI训练只需在训练配置文件中设置字段`use_dali=True`,或通过以下命令启动训练即可:
|
||||
|
||||
```shell
|
||||
# 设置用于训练的GPU卡号
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
python ppcls/static/train.py -c ppcls/configs/ImageNet/ResNet/ResNet50.yaml -o use_dali=True
|
||||
```
|
||||
|
||||
也可以使用多卡训练:
|
||||
|
||||
```shell
|
||||
# 设置用于训练的GPU卡号
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
# 设置用于神经网络训练的显存大小,可根据具体情况设置,一般可设置为0.8或0.7,剩余显存则预留DALI使用
|
||||
export FLAGS_fraction_of_gpu_memory_to_use=0.80
|
||||
|
||||
python -m paddle.distributed.launch \
|
||||
--gpus="0,1,2,3,4,5,6,7" \
|
||||
ppcls/static/train.py \
|
||||
-c ./ppcls/configs/ImageNet/ResNet/ResNet50.yaml \
|
||||
-o use_dali=True
|
||||
```
|
||||
|
||||
## 使用FP16训练
|
||||
在上述基础上,使用FP16半精度训练,可以进一步提高速度,可以参考下面的配置与运行命令。
|
||||
|
||||
```shell
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export FLAGS_fraction_of_gpu_memory_to_use=0.8
|
||||
|
||||
python -m paddle.distributed.launch \
|
||||
--gpus="0,1,2,3,4,5,6,7" \
|
||||
ppcls/static/train.py \
|
||||
-c ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml
|
||||
```
|
|
@ -0,0 +1,86 @@
|
|||
# 图像分类迁移学习
|
||||
|
||||
迁移学习是机器学习领域的一个重要分支,广泛应用于文本、图像等各种领域,此处我们主要介绍的是图像分类领域的迁移学习,也就是我们常说的域迁移,比如将 ImageNet 分类模型迁移到我们自己场景的图像分类任务上,如花卉分类。
|
||||
|
||||
## 一、 超参搜索
|
||||
|
||||
ImageNet 作为业界常用的图像分类数据被大家广泛使用,已经总结出一系列经验性的超参,使用这些超参往往能够得到不错的训练精度,而这些经验性的参数在迁移到自己的业务中时,有时效果不佳。有两种常用的超参搜索方法可以用于获得更好的模型超参。
|
||||
|
||||
### 1.1 网格搜索
|
||||
|
||||
网格搜索,即穷举搜索,通过查找搜索空间内所有的点,确定最优值。方法简单有效,但当搜索空间较大时,需要消耗大量的计算资源。
|
||||
|
||||
### 1.2 贝叶斯搜索
|
||||
|
||||
贝叶斯搜索,即贝叶斯优化,在搜索空间中随机选取超参数点,采用高斯过程,即根据上一个超参数点的结果,更新当前的先验信息,计算前面n个超参数点的后验概率分布,得到搜索空间中每一个超参数点的期望均值和方差,其中期望均值越大表示接近最优指标的可能性越大,方差越大表示不确定性越大。通常将选择期望均值大的超参数点称为`exporitation`,选择方差大的超参数点称为`exploration`。在贝叶斯优化中通过定义`acquisition function`权衡期望均值和方差。贝叶斯搜索认为当前选择的超参数点是处于最大值可能出现的位置。
|
||||
|
||||
------
|
||||
|
||||
基于上述两种搜索方案,我们在8个开源数据集上将固定一组参数实验以及两种搜索方案做了对比实验,参照[1]的实验方案,我们对4个超参数进行搜索,搜索空间及实验结果如下所示:
|
||||
|
||||
- 固定参数:
|
||||
|
||||
```
|
||||
初始学习率lr=0.003,l2 decay=1e-4,label smoothing=False,mixup=False
|
||||
```
|
||||
|
||||
- 超参搜索空间:
|
||||
|
||||
```
|
||||
初始学习率lr: [0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]
|
||||
|
||||
L2 decay: [1e-3, 3e-4, 1e-4, 3e-5, 1e-5, 3e-6, 1e-6]
|
||||
|
||||
Label smoothing: [False, True]
|
||||
|
||||
Mixup: [False, True]
|
||||
```
|
||||
|
||||
网格搜索的搜索次数为196次,而贝叶斯搜索通过设置最大迭代次数(`max_iter`)和是否重复搜索(`de_duplication`)来确定搜索次数。我们设计了系列实验,baseline为ImageNet1k校验集Top1 Acc为79.12%的ResNet50_vd预训练模型,并固定超参,在新数据集上finetune得到的模型。下表给出了固定参数、网格搜索以及贝叶斯搜索的精度与搜索次数对比。
|
||||
|
||||
- 精度与搜索次数对比:
|
||||
|
||||
| 数据集 | 固定参数 | 网格搜索 | 网格搜索次数 | 贝叶斯搜索 | 贝叶斯搜索次数|
|
||||
| ------------------ | -------- | -------- | -------- | -------- | ---------- |
|
||||
| Oxford-IIIT-Pets | 93.64% | 94.55% | 196 | 94.04% | 20 |
|
||||
| Oxford-102-Flowers | 96.08% | 97.69% | 196 | 97.49% | 20 |
|
||||
| Food101 | 87.07% | 87.52% | 196 | 87.33% | 23 |
|
||||
| SUN397 | 63.27% | 64.84% | 196 | 64.55% | 20 |
|
||||
| Caltech101 | 91.71% | 92.54% | 196 | 92.16% | 14 |
|
||||
| DTD | 76.87% | 77.53% | 196 | 77.47% | 13 |
|
||||
| Stanford Cars | 85.14% | 92.72% | 196 | 92.72% | 25 |
|
||||
| FGVC Aircraft | 80.32% | 88.45% | 196 | 88.36% | 20 |
|
||||
|
||||
|
||||
- 上述实验验证了贝叶斯搜索相比网格搜索,在减少搜索次数10倍左右条件下,精度只下降0%~0.4%。
|
||||
- 当搜索空间进一步扩大时,例如将是否进行AutoAugment,RandAugment,Cutout, Cutmix以及Dropout这些正则化策略作为选择时,贝叶斯搜索能够在获取较优精度的前提下,有效地降低搜索次数。
|
||||
|
||||
## 二、 大规模分类模型
|
||||
|
||||
在实际应用中,由于训练数据的匮乏,往往将ImageNet1k数据集训练的分类模型作为预训练模型,进行图像分类的迁移学习。为了进一步助力解决实际问题,基于ResNet50_vd, 百度开源了自研的大规模分类预训练模型,其中训练数据为10万个类别,4300万张图片。10万类预训练模型的下载地址:[**下载地址**](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_10w_pretrained.pdparams)
|
||||
|
||||
我们在6个自有采集的数据集上进行迁移学习实验,采用一组固定参数以及网格搜索方式,其中训练轮数设置为20epochs,选用ResNet50_vd模型,ImageNet预训练精度为79.12%。实验数据集参数以及模型精度的对比结果如下:
|
||||
|
||||
固定参数:
|
||||
|
||||
```
|
||||
初始学习率lr=0.001,l2 decay=1e-4,label smoothing=False,mixup=False
|
||||
```
|
||||
|
||||
| 数据集 | 数据统计 | **ImageNet预训练模型 <br />固定参数Top-1/参数搜索Top-1** | **大规模分类预训练模型<br />固定参数Top-1/参数搜索Top-1** |
|
||||
| --------------- | ----------------------------------------- | -------------------------------------------------------- | --------------------------------------------------------- |
|
||||
| 花卉 | class:102<br />train:5789<br />valid:2396 | 0.7779/0.9883 | 0.9892/0.9954 |
|
||||
| 手绘简笔画 | Class:18<br />train:1007<br />valid:432 | 0.8795/0.9196 | 0.9107/0.9219 |
|
||||
| 植物叶子 | class:6<br />train:5256<br />valid:2278 | 0.8212/0.8482 | 0.8385/0.8659 |
|
||||
| 集装箱车辆 | Class:115<br />train:4879<br />valid:2094 | 0.6230/0.9556 | 0.9524/0.9702 |
|
||||
| 椅子 | class:5<br />train:169<br />valid:78 | 0.8557/0.9688 | 0.9077/0.9792 |
|
||||
| 地质 | class:4<br />train:671<br />valid:296 | 0.5719/0.8094 | 0.6781/0.8219 |
|
||||
|
||||
- 通过上述的实验验证了当使用一组固定参数时,相比于ImageNet预训练模型,使用大规模分类模型作为预训练模型在大多数情况下能够提升模型在新的数据集上得效果,通过参数搜索可以进一步提升精度。
|
||||
|
||||
|
||||
## 参考文献
|
||||
|
||||
[1] Kornblith, Simon, Jonathon Shlens, and Quoc V. Le. "Do better imagenet models transfer better?." *Proceedings of the IEEE conference on computer vision and pattern recognition*. 2019.
|
||||
|
||||
[2] Kolesnikov, Alexander, et al. "Large Scale Learning of General Visual Representations for Transfer." *arXiv preprint arXiv:1912.11370* (2019).
|
|
@ -1,2 +0,0 @@
|
|||
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
|
||||
tar -xf ResNet50_pretrained.tar
|
|
@ -19,7 +19,7 @@ import os
|
|||
import sys
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../..')))
|
||||
|
||||
import paddle
|
||||
from paddle.distributed import ParallelEnv
|
||||
|
@ -33,18 +33,13 @@ def parse_args():
|
|||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--image_file", type=str)
|
||||
parser.add_argument("-i", "--image_file", required=True, type=str)
|
||||
parser.add_argument("-c", "--channel_num", type=int)
|
||||
parser.add_argument("-p", "--pretrained_model", type=str)
|
||||
parser.add_argument("--show", type=str2bool, default=False)
|
||||
parser.add_argument("--interpolation", type=int, default=1)
|
||||
parser.add_argument("--save_path", type=str, default=None)
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument(
|
||||
"--load_static_weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help='Whether to load the pretrained weights saved in static mode')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -79,7 +74,7 @@ def main():
|
|||
place = paddle.set_device(place)
|
||||
|
||||
net = ResNet50()
|
||||
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
|
||||
load_dygraph_pretrain(net, args.pretrained_model)
|
||||
|
||||
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
|
||||
data = preprocess(img, operators)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,126 +12,204 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import Uniform
|
||||
|
||||
import math
|
||||
|
||||
__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
"ResNet18":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
|
||||
"ResNet18_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
|
||||
"ResNet34":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
|
||||
"ResNet34_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
|
||||
"ResNet50":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
|
||||
"ResNet50_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
|
||||
"ResNet101":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
|
||||
"ResNet101_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
|
||||
"ResNet152":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
|
||||
"ResNet152_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
|
||||
"ResNet200_vd":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
|
||||
}
|
||||
|
||||
__all__ = MODEL_URLS.keys()
|
||||
'''
|
||||
ResNet config: dict.
|
||||
key: depth of ResNet.
|
||||
values: config's dict of specific model.
|
||||
keys:
|
||||
block_type: Two different blocks in ResNet, BasicBlock and BottleneckBlock are optional.
|
||||
block_depth: The number of blocks in different stages in ResNet.
|
||||
num_channels: The number of channels to enter the next stage.
|
||||
'''
|
||||
NET_CONFIG = {
|
||||
"18": {
|
||||
"block_type": "BasicBlock",
|
||||
"block_depth": [2, 2, 2, 2],
|
||||
"num_channels": [64, 64, 128, 256]
|
||||
},
|
||||
"34": {
|
||||
"block_type": "BasicBlock",
|
||||
"block_depth": [3, 4, 6, 3],
|
||||
"num_channels": [64, 64, 128, 256]
|
||||
},
|
||||
"50": {
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 4, 6, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
"101": {
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 4, 23, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
"152": {
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 8, 36, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
"200": {
|
||||
"block_type": "BottleneckBlock",
|
||||
"block_depth": [3, 12, 48, 3],
|
||||
"num_channels": [64, 256, 512, 1024]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
class ConvBNLayer(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self._conv = Conv2D(
|
||||
lr_mult=1.0,
|
||||
data_format="NCHW"):
|
||||
super().__init__()
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self.act = act
|
||||
self.avg_pool = AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self.conv = Conv2D(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = BatchNorm(
|
||||
weight_attr=ParamAttr(learning_rate=lr_mult),
|
||||
bias_attr=False,
|
||||
data_format=data_format)
|
||||
self.bn = BatchNorm(
|
||||
num_filters,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + "_scale"),
|
||||
bias_attr=ParamAttr(bn_name + "_offset"),
|
||||
moving_mean_name=bn_name + "_mean",
|
||||
moving_variance_name=bn_name + "_variance")
|
||||
param_attr=ParamAttr(learning_rate=lr_mult),
|
||||
bias_attr=ParamAttr(learning_rate=lr_mult),
|
||||
data_layout=data_format)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
def forward(self, x):
|
||||
if self.is_vd_mode:
|
||||
x = self.avg_pool(x)
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.act:
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
class BottleneckBlock(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
stride,
|
||||
shortcut=True,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
if_first=False,
|
||||
lr_mult=1.0,
|
||||
data_format="NCHW"):
|
||||
super().__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
act="relu",
|
||||
name=name + "_branch2a")
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
stride=stride,
|
||||
act="relu",
|
||||
name=name + "_branch2b")
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.conv2 = ConvBNLayer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters * 4,
|
||||
filter_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters * 4,
|
||||
filter_size=1,
|
||||
stride=stride,
|
||||
name=name + "_branch1")
|
||||
|
||||
stride=stride if if_first else 1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.relu = nn.ReLU()
|
||||
self.shortcut = shortcut
|
||||
|
||||
self._num_channels_out = num_filters * 4
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv0(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
short = identity
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
|
||||
y = paddle.add(x=short, y=conv2)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
short = self.short(identity)
|
||||
x = paddle.add(x=x, y=short)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
class BasicBlock(TheseusLayer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
stride,
|
||||
shortcut=True,
|
||||
name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if_first=False,
|
||||
lr_mult=1.0,
|
||||
data_format="NCHW"):
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride
|
||||
self.conv0 = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
|
@ -139,155 +217,319 @@ class BasicBlock(nn.Layer):
|
|||
filter_size=3,
|
||||
stride=stride,
|
||||
act="relu",
|
||||
name=name + "_branch2a")
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=num_filters,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
stride=stride,
|
||||
name=name + "_branch1")
|
||||
|
||||
stride=stride if if_first else 1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
lr_mult=lr_mult,
|
||||
data_format=data_format)
|
||||
self.shortcut = shortcut
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv0(x)
|
||||
x = self.conv1(x)
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
short = identity
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv1)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
short = self.short(identity)
|
||||
x = paddle.add(x=x, y=short)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
|
||||
def __init__(self, layers=50, class_dim=1000):
|
||||
super(ResNet, self).__init__()
|
||||
class ResNet(TheseusLayer):
|
||||
"""
|
||||
ResNet
|
||||
Args:
|
||||
config: dict. config of ResNet.
|
||||
version: str="vb". Different version of ResNet, version vd can perform better.
|
||||
class_num: int=1000. The number of classes.
|
||||
lr_mult_list: list. Control the learning rate of different stages.
|
||||
Returns:
|
||||
model: nn.Layer. Specific ResNet model depends on args.
|
||||
"""
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [18, 34, 50, 101, 152]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
supported_layers, layers)
|
||||
def __init__(self,
|
||||
config,
|
||||
version="vb",
|
||||
class_num=1000,
|
||||
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
data_format="NCHW",
|
||||
input_image_channel=3,
|
||||
return_patterns=None):
|
||||
super().__init__()
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
depth = [3, 4, 6, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
num_channels = [64, 256, 512,
|
||||
1024] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
self.cfg = config
|
||||
self.lr_mult_list = lr_mult_list
|
||||
self.is_vd_mode = version == "vd"
|
||||
self.class_num = class_num
|
||||
self.num_filters = [64, 128, 256, 512]
|
||||
self.block_depth = self.cfg["block_depth"]
|
||||
self.block_type = self.cfg["block_type"]
|
||||
self.num_channels = self.cfg["num_channels"]
|
||||
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
|
||||
|
||||
self.feature_map = None
|
||||
assert isinstance(self.lr_mult_list, (
|
||||
list, tuple
|
||||
)), "lr_mult_list should be in (list, tuple) but got {}".format(
|
||||
type(self.lr_mult_list))
|
||||
assert len(self.lr_mult_list
|
||||
) == 5, "lr_mult_list length should be 5 but got {}".format(
|
||||
len(self.lr_mult_list))
|
||||
|
||||
self.conv = ConvBNLayer(
|
||||
num_channels=3,
|
||||
num_filters=64,
|
||||
filter_size=7,
|
||||
stride=2,
|
||||
act="relu",
|
||||
name="conv1")
|
||||
self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
self.stem_cfg = {
|
||||
#num_channels, num_filters, filter_size, stride
|
||||
"vb": [[input_image_channel, 64, 7, 2]],
|
||||
"vd":
|
||||
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
|
||||
}
|
||||
|
||||
self.block_list = []
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
conv_name,
|
||||
BottleneckBlock(
|
||||
num_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
num_filters=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
name=conv_name))
|
||||
self.block_list.append(bottleneck_block)
|
||||
shortcut = True
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
basic_block = self.add_sublayer(
|
||||
conv_name,
|
||||
BasicBlock(
|
||||
num_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block],
|
||||
num_filters=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
name=conv_name))
|
||||
self.block_list.append(basic_block)
|
||||
shortcut = True
|
||||
self.stem = nn.Sequential(* [
|
||||
ConvBNLayer(
|
||||
num_channels=in_c,
|
||||
num_filters=out_c,
|
||||
filter_size=k,
|
||||
stride=s,
|
||||
act="relu",
|
||||
lr_mult=self.lr_mult_list[0],
|
||||
data_format=data_format)
|
||||
for in_c, out_c, k, s in self.stem_cfg[version]
|
||||
])
|
||||
|
||||
self.pool2d_avg = AdaptiveAvgPool2D(1)
|
||||
self.max_pool = MaxPool2D(
|
||||
kernel_size=3, stride=2, padding=1, data_format=data_format)
|
||||
block_list = []
|
||||
for block_idx in range(len(self.block_depth)):
|
||||
shortcut = False
|
||||
for i in range(self.block_depth[block_idx]):
|
||||
block_list.append(globals()[self.block_type](
|
||||
num_channels=self.num_channels[block_idx] if i == 0 else
|
||||
self.num_filters[block_idx] * self.channels_mult,
|
||||
num_filters=self.num_filters[block_idx],
|
||||
stride=2 if i == 0 and block_idx != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block_idx == i == 0 if version == "vd" else True,
|
||||
lr_mult=self.lr_mult_list[block_idx + 1],
|
||||
data_format=data_format))
|
||||
shortcut = True
|
||||
self.blocks = nn.Sequential(*block_list)
|
||||
|
||||
self.pool2d_avg_channels = num_channels[-1] * 2
|
||||
self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
|
||||
self.flatten = nn.Flatten()
|
||||
self.avg_pool_channels = self.num_channels[-1] * 2
|
||||
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
|
||||
self.fc = Linear(
|
||||
self.avg_pool_channels,
|
||||
self.class_num,
|
||||
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
|
||||
|
||||
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
|
||||
self.data_format = data_format
|
||||
if return_patterns is not None:
|
||||
self.update_res(return_patterns)
|
||||
self.register_forward_post_hook(self._return_dict_hook)
|
||||
|
||||
self.out = Linear(
|
||||
self.pool2d_avg_channels,
|
||||
class_dim,
|
||||
weight_attr=ParamAttr(
|
||||
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
|
||||
bias_attr=ParamAttr(name="fc_0.b_0"))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv(inputs)
|
||||
y = self.pool2d_max(y)
|
||||
self.feature_map = y
|
||||
for block in self.block_list:
|
||||
y = block(y)
|
||||
y = self.pool2d_avg(y)
|
||||
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
|
||||
y = self.out(y)
|
||||
return y, self.feature_map
|
||||
def forward(self, x):
|
||||
with paddle.static.amp.fp16_guard():
|
||||
if self.data_format == "NHWC":
|
||||
x = paddle.transpose(x, [0, 2, 3, 1])
|
||||
x.stop_gradient = True
|
||||
x = self.stem(x)
|
||||
fm = x
|
||||
x = self.max_pool(x)
|
||||
x = self.blocks(x)
|
||||
x = self.avg_pool(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x, fm
|
||||
|
||||
|
||||
def ResNet18(**args):
|
||||
model = ResNet(layers=18, **args)
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def ResNet18(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet18
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet18` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet34(**args):
|
||||
model = ResNet(layers=34, **args)
|
||||
def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet18_vd
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet18_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet50(**args):
|
||||
model = ResNet(layers=50, **args)
|
||||
def ResNet34(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet34
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet34` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet101(**args):
|
||||
model = ResNet(layers=101, **args)
|
||||
def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet34_vd
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet34_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet152(**args):
|
||||
model = ResNet(layers=152, **args)
|
||||
def ResNet50(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet50
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet50` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet50_vd
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet50_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet101(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet101
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet101` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet101_vd
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet101_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet152(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet152
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet152` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet152_vd
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet152_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ResNet200_vd
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ResNet200_vd` model depends on args.
|
||||
"""
|
||||
model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue