fix: adapt to release 2.3

pull/1396/head
gaotingquan 2021-11-04 09:45:44 +00:00
parent 677f6aea7e
commit 8ac51f7416
4 changed files with 440 additions and 203 deletions

View File

@ -6,43 +6,45 @@
## 二、准备工作
首先需要选定研究的模型本文设定ResNet50作为研究模型resnet.py从[模型库](../../../ppcls/arch/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../zh_CN/models/models_intro), 复制resnet50的模型链接使用下列命令下载并解压预训练模型
首先需要选定研究的模型本文设定ResNet50作为研究模型模型组网代码[resnet.py](../../../ppcls/arch/backbone/legendary_models/resnet.py)拷贝到[目录](../../../ppcls/utils/feature_maps_visualization/)下,并下载[ResNet50预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams),或使用以下命令下载
```bash
wget The Link for Pretrained Model
tar -xf Downloaded Pretrained Model
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams
```
以resnet50为例
```bash
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
```
其他模型网络结构代码及预训练模型请自行下载:[模型库](../../../ppcls/arch/backbone/)[预训练模型](../models/models_intro.md)。
## 三、修改模型
找到我们所需要的特征图位置设置self.fm将其fetch出来本文以resnet50中的stem层之后的特征图为例。
fm_vis.py中修改模型的名字。
ResNet50的forward函数中指定要可视化的特征图
在ResNet50的__init__函数中定义self.fm
```python
self.fm = None
def forward(self, x):
with paddle.static.amp.fp16_guard():
if self.data_format == "NHWC":
x = paddle.transpose(x, [0, 2, 3, 1])
x.stop_gradient = True
x = self.stem(x)
fm = x
x = self.max_pool(x)
x = self.blocks(x)
x = self.avg_pool(x)
x = self.flatten(x)
x = self.fc(x)
return x, fm
```
在ResNet50的forward函数中指定特征图
然后修改代码[fm_vis.py](../../../ppcls/utils/feature_maps_visualization/fm_vis.py),引入 `ResNet50`,实例化 `net` 对象:
```python
def forward(self, inputs):
y = self.conv(inputs)
self.fm = y
y = self.pool2d_max(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.avg_pool(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = self.out(y)
return y, self.fm
from resnet import ResNet50
net = ResNet50()
```
执行函数
最后执行函数
```bash
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
-c channel_num -p pretrained model \
@ -51,9 +53,10 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
--save_path where to save \
--use_gpu whether to use gpu
```
参数说明:
+ `-i`:待预测的图片文件路径,如 `./test.jpeg`
+ `-c`:特征图维度,如 `./resnet50_vd/model`
+ `-c`:特征图维度,如 `5`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--interpolation`: 图像插值方式, 默认值 1
+ `--save_path`:保存路径,如:`./tools/`
@ -63,7 +66,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
* 输入图片:
![](../../../docs/images/feature_maps/feature_visualization_input.jpg)
![](../../images/feature_maps/feature_visualization_input.jpg)
* 运行下面的特征图可视化脚本
@ -75,10 +78,9 @@ python tools/feature_maps_visualization/fm_vis.py \
--show=True \
--interpolation=1 \
--save_path="./output.png" \
--use_gpu=False \
--load_static_weights=True
--use_gpu=False
```
* 输出特征图保存为`output.png`,如下所示。
![](../../../docs/images/feature_maps/feature_visualization_output.jpg)
![](../../images/feature_maps/feature_visualization_output.jpg)

View File

@ -1,2 +0,0 @@
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar

View File

@ -19,7 +19,7 @@ import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../..')))
import paddle
from paddle.distributed import ParallelEnv
@ -33,18 +33,13 @@ def parse_args():
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("-i", "--image_file", required=True, type=str)
parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument(
"--load_static_weights",
type=str2bool,
default=False,
help='Whether to load the pretrained weights saved in static mode')
return parser.parse_args()
@ -79,7 +74,7 @@ def main():
place = paddle.set_device(place)
net = ResNet50()
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
load_dygraph_pretrain(net, args.pretrained_model)
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
data = preprocess(img, operators)

View File

@ -1,4 +1,4 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,126 +12,204 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import Conv2D, BatchNorm, Linear
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math
__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"ResNet18":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
"ResNet18_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
"ResNet34":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
"ResNet34_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
"ResNet50":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
"ResNet50_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
"ResNet101":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
"ResNet101_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
"ResNet152":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
"ResNet152_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
"ResNet200_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
}
__all__ = MODEL_URLS.keys()
'''
ResNet config: dict.
key: depth of ResNet.
values: config's dict of specific model.
keys:
block_type: Two different blocks in ResNet, BasicBlock and BottleneckBlock are optional.
block_depth: The number of blocks in different stages in ResNet.
num_channels: The number of channels to enter the next stage.
'''
NET_CONFIG = {
"18": {
"block_type": "BasicBlock",
"block_depth": [2, 2, 2, 2],
"num_channels": [64, 64, 128, 256]
},
"34": {
"block_type": "BasicBlock",
"block_depth": [3, 4, 6, 3],
"num_channels": [64, 64, 128, 256]
},
"50": {
"block_type": "BottleneckBlock",
"block_depth": [3, 4, 6, 3],
"num_channels": [64, 256, 512, 1024]
},
"101": {
"block_type": "BottleneckBlock",
"block_depth": [3, 4, 23, 3],
"num_channels": [64, 256, 512, 1024]
},
"152": {
"block_type": "BottleneckBlock",
"block_depth": [3, 8, 36, 3],
"num_channels": [64, 256, 512, 1024]
},
"200": {
"block_type": "BottleneckBlock",
"block_depth": [3, 12, 48, 3],
"num_channels": [64, 256, 512, 1024]
},
}
class ConvBNLayer(nn.Layer):
class ConvBNLayer(TheseusLayer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.is_vd_mode = is_vd_mode
self.act = act
self.avg_pool = AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=False,
data_format=data_format)
self.bn = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(bn_name + "_offset"),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
param_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult),
data_layout=data_format)
self.relu = nn.ReLU()
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
def forward(self, x):
if self.is_vd_mode:
x = self.avg_pool(x)
x = self.conv(x)
x = self.bn(x)
if self.act:
x = self.relu(x)
return x
class BottleneckBlock(nn.Layer):
class BottleneckBlock(TheseusLayer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BottleneckBlock, self).__init__()
if_first=False,
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act="relu",
name=name + "_branch2a")
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2b")
lr_mult=lr_mult,
data_format=data_format)
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
lr_mult=lr_mult,
data_format=data_format)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride,
name=name + "_branch1")
stride=stride if if_first else 1,
is_vd_mode=False if if_first else True,
lr_mult=lr_mult,
data_format=data_format)
self.relu = nn.ReLU()
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
def forward(self, x):
identity = x
x = self.conv0(x)
x = self.conv1(x)
x = self.conv2(x)
if self.shortcut:
short = inputs
short = identity
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
short = self.short(identity)
x = paddle.add(x=x, y=short)
x = self.relu(x)
return x
class BasicBlock(nn.Layer):
class BasicBlock(TheseusLayer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
if_first=False,
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
@ -139,155 +217,319 @@ class BasicBlock(nn.Layer):
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2a")
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
lr_mult=lr_mult,
data_format=data_format)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride,
name=name + "_branch1")
stride=stride if if_first else 1,
is_vd_mode=False if if_first else True,
lr_mult=lr_mult,
data_format=data_format)
self.shortcut = shortcut
self.relu = nn.ReLU()
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
def forward(self, x):
identity = x
x = self.conv0(x)
x = self.conv1(x)
if self.shortcut:
short = inputs
short = identity
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
short = self.short(identity)
x = paddle.add(x=x, y=short)
x = self.relu(x)
return x
class ResNet(nn.Layer):
def __init__(self, layers=50, class_dim=1000):
super(ResNet, self).__init__()
class ResNet(TheseusLayer):
"""
ResNet
Args:
config: dict. config of ResNet.
version: str="vb". Different version of ResNet, version vd can perform better.
class_num: int=1000. The number of classes.
lr_mult_list: list. Control the learning rate of different stages.
Returns:
model: nn.Layer. Specific ResNet model depends on args.
"""
self.layers = layers
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
def __init__(self,
config,
version="vb",
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW",
input_image_channel=3,
return_patterns=None):
super().__init__()
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.cfg = config
self.lr_mult_list = lr_mult_list
self.is_vd_mode = version == "vd"
self.class_num = class_num
self.num_filters = [64, 128, 256, 512]
self.block_depth = self.cfg["block_depth"]
self.block_type = self.cfg["block_type"]
self.num_channels = self.cfg["num_channels"]
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
self.feature_map = None
assert isinstance(self.lr_mult_list, (
list, tuple
)), "lr_mult_list should be in (list, tuple) but got {}".format(
type(self.lr_mult_list))
assert len(self.lr_mult_list
) == 5, "lr_mult_list length should be 5 but got {}".format(
len(self.lr_mult_list))
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act="relu",
name="conv1")
self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stem_cfg = {
#num_channels, num_filters, filter_size, stride
"vb": [[input_image_channel, 64, 7, 2]],
"vd":
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
}
self.block_list = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(bottleneck_block)
shortcut = True
else:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name))
self.block_list.append(basic_block)
shortcut = True
self.stem = nn.Sequential(* [
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
filter_size=k,
stride=s,
act="relu",
lr_mult=self.lr_mult_list[0],
data_format=data_format)
for in_c, out_c, k, s in self.stem_cfg[version]
])
self.pool2d_avg = AdaptiveAvgPool2D(1)
self.max_pool = MaxPool2D(
kernel_size=3, stride=2, padding=1, data_format=data_format)
block_list = []
for block_idx in range(len(self.block_depth)):
shortcut = False
for i in range(self.block_depth[block_idx]):
block_list.append(globals()[self.block_type](
num_channels=self.num_channels[block_idx] if i == 0 else
self.num_filters[block_idx] * self.channels_mult,
num_filters=self.num_filters[block_idx],
stride=2 if i == 0 and block_idx != 0 else 1,
shortcut=shortcut,
if_first=block_idx == i == 0 if version == "vd" else True,
lr_mult=self.lr_mult_list[block_idx + 1],
data_format=data_format))
shortcut = True
self.blocks = nn.Sequential(*block_list)
self.pool2d_avg_channels = num_channels[-1] * 2
self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
self.flatten = nn.Flatten()
self.avg_pool_channels = self.num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
self.fc = Linear(
self.avg_pool_channels,
self.class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.data_format = data_format
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
self.feature_map = y
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y, self.feature_map
def forward(self, x):
with paddle.static.amp.fp16_guard():
if self.data_format == "NHWC":
x = paddle.transpose(x, [0, 2, 3, 1])
x.stop_gradient = True
x = self.stem(x)
fm = x
x = self.max_pool(x)
x = self.blocks(x)
x = self.avg_pool(x)
x = self.flatten(x)
x = self.fc(x)
return x, fm
def ResNet18(**args):
model = ResNet(layers=18, **args)
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def ResNet18(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet18
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet18` model depends on args.
"""
model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
return model
def ResNet34(**args):
model = ResNet(layers=34, **args)
def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet18_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet18_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
return model
def ResNet50(**args):
model = ResNet(layers=50, **args)
def ResNet34(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet34
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet34` model depends on args.
"""
model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
return model
def ResNet101(**args):
model = ResNet(layers=101, **args)
def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet34_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet34_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
return model
def ResNet152(**args):
model = ResNet(layers=152, **args)
def ResNet50(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet50
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet50` model depends on args.
"""
model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
return model
def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet50_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet50_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
return model
def ResNet101(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet101
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet101` model depends on args.
"""
model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
return model
def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet101_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet101_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
return model
def ResNet152(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet152
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet152` model depends on args.
"""
model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
return model
def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet152_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet152_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
return model
def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet200_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet200_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
return model