From 8ac51f7416eea21b12fc2d28db46b2214d31304e Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Thu, 4 Nov 2021 09:45:44 +0000 Subject: [PATCH] fix: adapt to release 2.3 --- docs/zh_CN/others/feature_visiualization.md | 58 +- .../download_resnet50_pretrained.sh | 2 - .../feature_maps_visualization/fm_vis.py | 11 +- .../feature_maps_visualization/resnet.py | 572 +++++++++++++----- 4 files changed, 440 insertions(+), 203 deletions(-) delete mode 100644 ppcls/utils/feature_maps_visualization/download_resnet50_pretrained.sh diff --git a/docs/zh_CN/others/feature_visiualization.md b/docs/zh_CN/others/feature_visiualization.md index 8a7229e8d..65717ea80 100644 --- a/docs/zh_CN/others/feature_visiualization.md +++ b/docs/zh_CN/others/feature_visiualization.md @@ -6,43 +6,45 @@ ## 二、准备工作 -首先需要选定研究的模型,本文设定ResNet50作为研究模型,将resnet.py从[模型库](../../../ppcls/arch/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../zh_CN/models/models_intro), 复制resnet50的模型链接,使用下列命令下载并解压预训练模型。 +首先需要选定研究的模型,本文设定ResNet50作为研究模型,将模型组网代码[resnet.py](../../../ppcls/arch/backbone/legendary_models/resnet.py)拷贝到[目录](../../../ppcls/utils/feature_maps_visualization/)下,并下载[ResNet50预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams),或使用以下命令下载。 ```bash -wget The Link for Pretrained Model -tar -xf Downloaded Pretrained Model +wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams ``` -以resnet50为例: -```bash -wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar -tar -xf ResNet50_pretrained.tar -``` +其他模型网络结构代码及预训练模型请自行下载:[模型库](../../../ppcls/arch/backbone/),[预训练模型](../models/models_intro.md)。 ## 三、修改模型 找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。 -在fm_vis.py中修改模型的名字。 +在ResNet50的forward函数中指定要可视化的特征图 -在ResNet50的__init__函数中定义self.fm ```python -self.fm = None + def forward(self, x): + with paddle.static.amp.fp16_guard(): + if self.data_format == "NHWC": + x = paddle.transpose(x, [0, 2, 3, 1]) + x.stop_gradient = True + x = self.stem(x) + fm = x + x = self.max_pool(x) + x = self.blocks(x) + x = self.avg_pool(x) + x = self.flatten(x) + x = self.fc(x) + return x, fm ``` -在ResNet50的forward函数中指定特征图 + +然后修改代码[fm_vis.py](../../../ppcls/utils/feature_maps_visualization/fm_vis.py),引入 `ResNet50`,实例化 `net` 对象: + ```python -def forward(self, inputs): - y = self.conv(inputs) - self.fm = y - y = self.pool2d_max(y) - for bottleneck_block in self.bottleneck_block_list: - y = bottleneck_block(y) - y = self.avg_pool(y) - y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output]) - y = self.out(y) - return y, self.fm +from resnet import ResNet50 +net = ResNet50() ``` -执行函数 + +最后执行函数 + ```bash python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \ -c channel_num -p pretrained model \ @@ -51,9 +53,10 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test --save_path where to save \ --use_gpu whether to use gpu ``` + 参数说明: + `-i`:待预测的图片文件路径,如 `./test.jpeg` -+ `-c`:特征图维度,如 `./resnet50_vd/model` ++ `-c`:特征图维度,如 `5` + `-p`:权重文件路径,如 `./ResNet50_pretrained/` + `--interpolation`: 图像插值方式, 默认值 1 + `--save_path`:保存路径,如:`./tools/` @@ -63,7 +66,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test * 输入图片: -![](../../../docs/images/feature_maps/feature_visualization_input.jpg) +![](../../images/feature_maps/feature_visualization_input.jpg) * 运行下面的特征图可视化脚本 @@ -75,10 +78,9 @@ python tools/feature_maps_visualization/fm_vis.py \ --show=True \ --interpolation=1 \ --save_path="./output.png" \ - --use_gpu=False \ - --load_static_weights=True + --use_gpu=False ``` * 输出特征图保存为`output.png`,如下所示。 -![](../../../docs/images/feature_maps/feature_visualization_output.jpg) +![](../../images/feature_maps/feature_visualization_output.jpg) diff --git a/ppcls/utils/feature_maps_visualization/download_resnet50_pretrained.sh b/ppcls/utils/feature_maps_visualization/download_resnet50_pretrained.sh deleted file mode 100644 index 286c2400a..000000000 --- a/ppcls/utils/feature_maps_visualization/download_resnet50_pretrained.sh +++ /dev/null @@ -1,2 +0,0 @@ -wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar -tar -xf ResNet50_pretrained.tar \ No newline at end of file diff --git a/ppcls/utils/feature_maps_visualization/fm_vis.py b/ppcls/utils/feature_maps_visualization/fm_vis.py index d8ee125f0..a5368b10e 100644 --- a/ppcls/utils/feature_maps_visualization/fm_vis.py +++ b/ppcls/utils/feature_maps_visualization/fm_vis.py @@ -19,7 +19,7 @@ import os import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../../..'))) import paddle from paddle.distributed import ParallelEnv @@ -33,18 +33,13 @@ def parse_args(): return v.lower() in ("true", "t", "1") parser = argparse.ArgumentParser() - parser.add_argument("-i", "--image_file", type=str) + parser.add_argument("-i", "--image_file", required=True, type=str) parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("--show", type=str2bool, default=False) parser.add_argument("--interpolation", type=int, default=1) parser.add_argument("--save_path", type=str, default=None) parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument( - "--load_static_weights", - type=str2bool, - default=False, - help='Whether to load the pretrained weights saved in static mode') return parser.parse_args() @@ -79,7 +74,7 @@ def main(): place = paddle.set_device(place) net = ResNet50() - load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) + load_dygraph_pretrain(net, args.pretrained_model) img = cv2.imread(args.image_file, cv2.IMREAD_COLOR) data = preprocess(img, operators) diff --git a/ppcls/utils/feature_maps_visualization/resnet.py b/ppcls/utils/feature_maps_visualization/resnet.py index b0171f849..b75881414 100644 --- a/ppcls/utils/feature_maps_visualization/resnet.py +++ b/ppcls/utils/feature_maps_visualization/resnet.py @@ -1,4 +1,4 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,126 +12,204 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import numpy as np import paddle from paddle import ParamAttr import paddle.nn as nn -import paddle.nn.functional as F -from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import Conv2D, BatchNorm, Linear from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn.initializer import Uniform - import math -__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"] +from ppcls.arch.backbone.base.theseus_layer import TheseusLayer +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "ResNet18": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams", + "ResNet18_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams", + "ResNet34": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams", + "ResNet34_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams", + "ResNet50": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams", + "ResNet50_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams", + "ResNet101": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams", + "ResNet101_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams", + "ResNet152": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams", + "ResNet152_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams", + "ResNet200_vd": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams", +} + +__all__ = MODEL_URLS.keys() +''' +ResNet config: dict. + key: depth of ResNet. + values: config's dict of specific model. + keys: + block_type: Two different blocks in ResNet, BasicBlock and BottleneckBlock are optional. + block_depth: The number of blocks in different stages in ResNet. + num_channels: The number of channels to enter the next stage. +''' +NET_CONFIG = { + "18": { + "block_type": "BasicBlock", + "block_depth": [2, 2, 2, 2], + "num_channels": [64, 64, 128, 256] + }, + "34": { + "block_type": "BasicBlock", + "block_depth": [3, 4, 6, 3], + "num_channels": [64, 64, 128, 256] + }, + "50": { + "block_type": "BottleneckBlock", + "block_depth": [3, 4, 6, 3], + "num_channels": [64, 256, 512, 1024] + }, + "101": { + "block_type": "BottleneckBlock", + "block_depth": [3, 4, 23, 3], + "num_channels": [64, 256, 512, 1024] + }, + "152": { + "block_type": "BottleneckBlock", + "block_depth": [3, 8, 36, 3], + "num_channels": [64, 256, 512, 1024] + }, + "200": { + "block_type": "BottleneckBlock", + "block_depth": [3, 12, 48, 3], + "num_channels": [64, 256, 512, 1024] + }, +} -class ConvBNLayer(nn.Layer): +class ConvBNLayer(TheseusLayer): def __init__(self, num_channels, num_filters, filter_size, stride=1, groups=1, + is_vd_mode=False, act=None, - name=None): - super(ConvBNLayer, self).__init__() - - self._conv = Conv2D( + lr_mult=1.0, + data_format="NCHW"): + super().__init__() + self.is_vd_mode = is_vd_mode + self.act = act + self.avg_pool = AvgPool2D( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.conv = Conv2D( in_channels=num_channels, out_channels=num_filters, kernel_size=filter_size, stride=stride, padding=(filter_size - 1) // 2, groups=groups, - weight_attr=ParamAttr(name=name + "_weights"), - bias_attr=False) - if name == "conv1": - bn_name = "bn_" + name - else: - bn_name = "bn" + name[3:] - self._batch_norm = BatchNorm( + weight_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=False, + data_format=data_format) + self.bn = BatchNorm( num_filters, - act=act, - param_attr=ParamAttr(name=bn_name + "_scale"), - bias_attr=ParamAttr(bn_name + "_offset"), - moving_mean_name=bn_name + "_mean", - moving_variance_name=bn_name + "_variance") + param_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult), + data_layout=data_format) + self.relu = nn.ReLU() - def forward(self, inputs): - y = self._conv(inputs) - y = self._batch_norm(y) - return y + def forward(self, x): + if self.is_vd_mode: + x = self.avg_pool(x) + x = self.conv(x) + x = self.bn(x) + if self.act: + x = self.relu(x) + return x -class BottleneckBlock(nn.Layer): +class BottleneckBlock(TheseusLayer): def __init__(self, num_channels, num_filters, stride, shortcut=True, - name=None): - super(BottleneckBlock, self).__init__() + if_first=False, + lr_mult=1.0, + data_format="NCHW"): + super().__init__() self.conv0 = ConvBNLayer( num_channels=num_channels, num_filters=num_filters, filter_size=1, act="relu", - name=name + "_branch2a") + lr_mult=lr_mult, + data_format=data_format) self.conv1 = ConvBNLayer( num_channels=num_filters, num_filters=num_filters, filter_size=3, stride=stride, act="relu", - name=name + "_branch2b") + lr_mult=lr_mult, + data_format=data_format) self.conv2 = ConvBNLayer( num_channels=num_filters, num_filters=num_filters * 4, filter_size=1, act=None, - name=name + "_branch2c") + lr_mult=lr_mult, + data_format=data_format) if not shortcut: self.short = ConvBNLayer( num_channels=num_channels, num_filters=num_filters * 4, filter_size=1, - stride=stride, - name=name + "_branch1") - + stride=stride if if_first else 1, + is_vd_mode=False if if_first else True, + lr_mult=lr_mult, + data_format=data_format) + self.relu = nn.ReLU() self.shortcut = shortcut - self._num_channels_out = num_filters * 4 - - def forward(self, inputs): - y = self.conv0(inputs) - conv1 = self.conv1(y) - conv2 = self.conv2(conv1) + def forward(self, x): + identity = x + x = self.conv0(x) + x = self.conv1(x) + x = self.conv2(x) if self.shortcut: - short = inputs + short = identity else: - short = self.short(inputs) - - y = paddle.add(x=short, y=conv2) - y = F.relu(y) - return y + short = self.short(identity) + x = paddle.add(x=x, y=short) + x = self.relu(x) + return x -class BasicBlock(nn.Layer): +class BasicBlock(TheseusLayer): def __init__(self, num_channels, num_filters, stride, shortcut=True, - name=None): - super(BasicBlock, self).__init__() + if_first=False, + lr_mult=1.0, + data_format="NCHW"): + super().__init__() + self.stride = stride self.conv0 = ConvBNLayer( num_channels=num_channels, @@ -139,155 +217,319 @@ class BasicBlock(nn.Layer): filter_size=3, stride=stride, act="relu", - name=name + "_branch2a") + lr_mult=lr_mult, + data_format=data_format) self.conv1 = ConvBNLayer( num_channels=num_filters, num_filters=num_filters, filter_size=3, act=None, - name=name + "_branch2b") - + lr_mult=lr_mult, + data_format=data_format) if not shortcut: self.short = ConvBNLayer( num_channels=num_channels, num_filters=num_filters, filter_size=1, - stride=stride, - name=name + "_branch1") - + stride=stride if if_first else 1, + is_vd_mode=False if if_first else True, + lr_mult=lr_mult, + data_format=data_format) self.shortcut = shortcut + self.relu = nn.ReLU() - def forward(self, inputs): - y = self.conv0(inputs) - conv1 = self.conv1(y) - + def forward(self, x): + identity = x + x = self.conv0(x) + x = self.conv1(x) if self.shortcut: - short = inputs + short = identity else: - short = self.short(inputs) - y = paddle.add(x=short, y=conv1) - y = F.relu(y) - return y + short = self.short(identity) + x = paddle.add(x=x, y=short) + x = self.relu(x) + return x -class ResNet(nn.Layer): - def __init__(self, layers=50, class_dim=1000): - super(ResNet, self).__init__() +class ResNet(TheseusLayer): + """ + ResNet + Args: + config: dict. config of ResNet. + version: str="vb". Different version of ResNet, version vd can perform better. + class_num: int=1000. The number of classes. + lr_mult_list: list. Control the learning rate of different stages. + Returns: + model: nn.Layer. Specific ResNet model depends on args. + """ - self.layers = layers - supported_layers = [18, 34, 50, 101, 152] - assert layers in supported_layers, \ - "supported layers are {} but input layer is {}".format( - supported_layers, layers) + def __init__(self, + config, + version="vb", + class_num=1000, + lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], + data_format="NCHW", + input_image_channel=3, + return_patterns=None): + super().__init__() - if layers == 18: - depth = [2, 2, 2, 2] - elif layers == 34 or layers == 50: - depth = [3, 4, 6, 3] - elif layers == 101: - depth = [3, 4, 23, 3] - elif layers == 152: - depth = [3, 8, 36, 3] - num_channels = [64, 256, 512, - 1024] if layers >= 50 else [64, 64, 128, 256] - num_filters = [64, 128, 256, 512] + self.cfg = config + self.lr_mult_list = lr_mult_list + self.is_vd_mode = version == "vd" + self.class_num = class_num + self.num_filters = [64, 128, 256, 512] + self.block_depth = self.cfg["block_depth"] + self.block_type = self.cfg["block_type"] + self.num_channels = self.cfg["num_channels"] + self.channels_mult = 1 if self.num_channels[-1] == 256 else 4 - self.feature_map = None + assert isinstance(self.lr_mult_list, ( + list, tuple + )), "lr_mult_list should be in (list, tuple) but got {}".format( + type(self.lr_mult_list)) + assert len(self.lr_mult_list + ) == 5, "lr_mult_list length should be 5 but got {}".format( + len(self.lr_mult_list)) - self.conv = ConvBNLayer( - num_channels=3, - num_filters=64, - filter_size=7, - stride=2, - act="relu", - name="conv1") - self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1) + self.stem_cfg = { + #num_channels, num_filters, filter_size, stride + "vb": [[input_image_channel, 64, 7, 2]], + "vd": + [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]] + } - self.block_list = [] - if layers >= 50: - for block in range(len(depth)): - shortcut = False - for i in range(depth[block]): - if layers in [101, 152] and block == 2: - if i == 0: - conv_name = "res" + str(block + 2) + "a" - else: - conv_name = "res" + str(block + 2) + "b" + str(i) - else: - conv_name = "res" + str(block + 2) + chr(97 + i) - bottleneck_block = self.add_sublayer( - conv_name, - BottleneckBlock( - num_channels=num_channels[block] - if i == 0 else num_filters[block] * 4, - num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 else 1, - shortcut=shortcut, - name=conv_name)) - self.block_list.append(bottleneck_block) - shortcut = True - else: - for block in range(len(depth)): - shortcut = False - for i in range(depth[block]): - conv_name = "res" + str(block + 2) + chr(97 + i) - basic_block = self.add_sublayer( - conv_name, - BasicBlock( - num_channels=num_channels[block] - if i == 0 else num_filters[block], - num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 else 1, - shortcut=shortcut, - name=conv_name)) - self.block_list.append(basic_block) - shortcut = True + self.stem = nn.Sequential(* [ + ConvBNLayer( + num_channels=in_c, + num_filters=out_c, + filter_size=k, + stride=s, + act="relu", + lr_mult=self.lr_mult_list[0], + data_format=data_format) + for in_c, out_c, k, s in self.stem_cfg[version] + ]) - self.pool2d_avg = AdaptiveAvgPool2D(1) + self.max_pool = MaxPool2D( + kernel_size=3, stride=2, padding=1, data_format=data_format) + block_list = [] + for block_idx in range(len(self.block_depth)): + shortcut = False + for i in range(self.block_depth[block_idx]): + block_list.append(globals()[self.block_type]( + num_channels=self.num_channels[block_idx] if i == 0 else + self.num_filters[block_idx] * self.channels_mult, + num_filters=self.num_filters[block_idx], + stride=2 if i == 0 and block_idx != 0 else 1, + shortcut=shortcut, + if_first=block_idx == i == 0 if version == "vd" else True, + lr_mult=self.lr_mult_list[block_idx + 1], + data_format=data_format)) + shortcut = True + self.blocks = nn.Sequential(*block_list) - self.pool2d_avg_channels = num_channels[-1] * 2 + self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format) + self.flatten = nn.Flatten() + self.avg_pool_channels = self.num_channels[-1] * 2 + stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0) + self.fc = Linear( + self.avg_pool_channels, + self.class_num, + weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv))) - stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0) + self.data_format = data_format + if return_patterns is not None: + self.update_res(return_patterns) + self.register_forward_post_hook(self._return_dict_hook) - self.out = Linear( - self.pool2d_avg_channels, - class_dim, - weight_attr=ParamAttr( - initializer=Uniform(-stdv, stdv), name="fc_0.w_0"), - bias_attr=ParamAttr(name="fc_0.b_0")) - - def forward(self, inputs): - y = self.conv(inputs) - y = self.pool2d_max(y) - self.feature_map = y - for block in self.block_list: - y = block(y) - y = self.pool2d_avg(y) - y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels]) - y = self.out(y) - return y, self.feature_map + def forward(self, x): + with paddle.static.amp.fp16_guard(): + if self.data_format == "NHWC": + x = paddle.transpose(x, [0, 2, 3, 1]) + x.stop_gradient = True + x = self.stem(x) + fm = x + x = self.max_pool(x) + x = self.blocks(x) + x = self.avg_pool(x) + x = self.flatten(x) + x = self.fc(x) + return x, fm -def ResNet18(**args): - model = ResNet(layers=18, **args) +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def ResNet18(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet18 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet18` model depends on args. + """ + model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld) return model -def ResNet34(**args): - model = ResNet(layers=34, **args) +def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet18_vd + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet18_vd` model depends on args. + """ + model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld) return model -def ResNet50(**args): - model = ResNet(layers=50, **args) +def ResNet34(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet34 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet34` model depends on args. + """ + model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld) return model -def ResNet101(**args): - model = ResNet(layers=101, **args) +def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet34_vd + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet34_vd` model depends on args. + """ + model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld) return model -def ResNet152(**args): - model = ResNet(layers=152, **args) +def ResNet50(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet50 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet50` model depends on args. + """ + model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld) + return model + + +def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet50_vd + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet50_vd` model depends on args. + """ + model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld) + return model + + +def ResNet101(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet101 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet101` model depends on args. + """ + model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld) + return model + + +def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet101_vd + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet101_vd` model depends on args. + """ + model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld) + return model + + +def ResNet152(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet152 + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet152` model depends on args. + """ + model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld) + return model + + +def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet152_vd + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet152_vd` model depends on args. + """ + model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld) + return model + + +def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs): + """ + ResNet200_vd + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `ResNet200_vd` model depends on args. + """ + model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld) return model