Add ESNet code and pretrained models
parent
507a9213f6
commit
d868a63764
|
@ -82,6 +82,7 @@
|
|||
| MobileNetV3_small_x1_0_ssld | 0.713 | 0.682 | 0.031 | 6.546 | 0.123 | 2.94 | 12 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_ssld_pretrained.pdparams) |
|
||||
| GhostNet_x1_3_ssld | 0.794 | 0.757 | 0.037 | 19.983 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) |
|
||||
|
||||
|
||||
<a name="Intel-CPU端知识蒸馏模型"></a>
|
||||
#### Intel CPU端知识蒸馏模型
|
||||
|
||||
|
@ -180,6 +181,10 @@ ResNet及其Vd系列模型的精度、速度指标如下表所示,更多关于
|
|||
| GhostNet_<br>x1_0 | 0.7402 | 0.9165 | 13.5587 | 0.294 | 5.2 | 20 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_0_pretrained.pdparams) |
|
||||
| GhostNet_<br>x1_3 | 0.7579 | 0.9254 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_pretrained.pdparams) |
|
||||
| GhostNet_<br>x1_3_ssld | 0.7938 | 0.9449 | 19.9825 | 0.44 | 7.3 | 29 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) |
|
||||
| ESNet_x0_25 | 62.48 | 83.46 || 0.031 | 2.83 | 11 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x0_25_pretrained.pdparams) |
|
||||
| ESNet_x0_5 | 68.82 | 88.04 || 0.067 | 3.25 | 13 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x0_5_pretrained.pdparams) |
|
||||
| ESNet_x0_75 | 72.24 | 90.45 || 0.124 | 3.87 | 15 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x0_75_pretrained.pdparams) |
|
||||
| ESNet_x1_0 | 73.92 | 91.40 || 0.197 | 4.64 | 18 |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ESNet_x1_0_pretrained.pdparams) |
|
||||
|
||||
|
||||
<a name="SEResNeXt与Res2Net系列"></a>
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# ESNet系列
|
||||
|
||||
## 概述
|
||||
|
||||
ESNet(Enhanced ShuffleNet)是百度自研的一个轻量级网络,该网络在ShuffleNetV2的基础上融合了MobileNetV3、GhostNet、PPLCNet的优点,组合成了一个在ARM设备上速度更快、精度更高的网络,由于其出色的表现,所以在PaddleDetection推出的[PP-PicoDet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/picodet)使用了该模型做backbone,配合更强的目标检测算法,最终的指标一举刷新了目标检测模型在ARM设备上的SOTA指标。
|
||||
|
||||
## 精度、FLOPs和参数量
|
||||
|
||||
| Models | Top1 | Top5 | FLOPs<br>(M) | Params<br/>(M) |
|
||||
|:--:|:--:|:--:|:--:|:--:|
|
||||
| ESNet_x0_25 | 62.48 | 83.46 | 30.9 | 2.83 |
|
||||
| ESNet_x0_5 | 68.82 | 88.04 | 67.3 | 3.25 |
|
||||
| ESNet_x0_75 | 72.24 | 90.45 | 123.7 | 3.87 |
|
||||
| ESNet_x1_0 | 73.92 | 91.40 | 197.3 | 4.64 |
|
||||
|
||||
关于Inference speed等信息,敬请期待。
|
|
@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19
|
|||
from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3
|
||||
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
|
||||
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
|
||||
from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
|
||||
|
||||
from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc
|
||||
from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d
|
||||
|
|
|
@ -0,0 +1,359 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import math
|
||||
import paddle
|
||||
from paddle import ParamAttr, reshape, transpose, concat, split
|
||||
import paddle.nn as nn
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
from paddle.regularizer import L2Decay
|
||||
|
||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
|
||||
MODEL_URLS = {
|
||||
"ESNet_x0_25":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_25_pretrained.pdparams",
|
||||
"ESNet_x0_5":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_5_pretrained.pdparams",
|
||||
"ESNet_x0_75":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_75_pretrained.pdparams",
|
||||
"ESNet_x1_0":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x1_0_pretrained.pdparams",
|
||||
}
|
||||
|
||||
__all__ = list(MODEL_URLS.keys())
|
||||
|
||||
|
||||
def channel_shuffle(x, groups):
|
||||
batch_size, num_channels, height, width = x.shape[0:4]
|
||||
channels_per_group = num_channels // groups
|
||||
x = reshape(
|
||||
x=x, shape=[batch_size, groups, channels_per_group, height, width])
|
||||
x = transpose(x=x, perm=[0, 2, 1, 3, 4])
|
||||
x = reshape(x=x, shape=[batch_size, num_channels, height, width])
|
||||
return x
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNLayer(TheseusLayer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
if_act=True):
|
||||
super().__init__()
|
||||
self.conv = Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = BatchNorm(
|
||||
out_channels,
|
||||
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
self.if_act = if_act
|
||||
self.hardswish = nn.Hardswish()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.if_act:
|
||||
x = self.hardswish(x)
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(TheseusLayer):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super().__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=channel,
|
||||
out_channels=channel // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=channel // reduction,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.hardsigmoid = nn.Hardsigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.hardsigmoid(x)
|
||||
x = paddle.multiply(x=identity, y=x)
|
||||
return x
|
||||
|
||||
|
||||
class ESBlock1(TheseusLayer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels):
|
||||
super().__init__()
|
||||
self.pw_1_1 = ConvBNLayer(
|
||||
in_channels=in_channels // 2,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1,
|
||||
stride=1)
|
||||
self.dw_1 = ConvBNLayer(
|
||||
in_channels=out_channels // 2,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
groups=out_channels // 2,
|
||||
if_act=False)
|
||||
self.se = SEModule(out_channels)
|
||||
|
||||
self.pw_1_2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1,
|
||||
stride=1)
|
||||
|
||||
def forward(self, x):
|
||||
x1, x2 = split(
|
||||
x,
|
||||
num_or_sections=[x.shape[1] // 2, x.shape[1] // 2],
|
||||
axis=1)
|
||||
x2 = self.pw_1_1(x2)
|
||||
x3 = self.dw_1(x2)
|
||||
x3 = concat([x2, x3], axis=1)
|
||||
x3 = self.se(x3)
|
||||
x3 = self.pw_1_2(x3)
|
||||
x = concat([x1, x3], axis=1)
|
||||
return channel_shuffle(x, 2)
|
||||
|
||||
|
||||
class ESBlock2(TheseusLayer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels):
|
||||
super().__init__()
|
||||
|
||||
# branch1
|
||||
self.dw_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
groups=in_channels,
|
||||
if_act=False)
|
||||
self.pw_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1,
|
||||
stride=1)
|
||||
# branch2
|
||||
self.pw_2_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1)
|
||||
self.dw_2 = ConvBNLayer(
|
||||
in_channels=out_channels // 2,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
groups=out_channels // 2,
|
||||
if_act=False)
|
||||
self.se = SEModule(out_channels // 2)
|
||||
self.pw_2_2 = ConvBNLayer(
|
||||
in_channels=out_channels // 2,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1)
|
||||
self.concat_dw = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
groups=out_channels)
|
||||
self.concat_pw = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.dw_1(x)
|
||||
x1 = self.pw_1(x1)
|
||||
x2 = self.pw_2_1(x)
|
||||
x2 = self.dw_2(x2)
|
||||
x2 = self.se(x2)
|
||||
x2 = self.pw_2_2(x2)
|
||||
x = concat([x1, x2], axis=1)
|
||||
x = self.concat_dw(x)
|
||||
x = self.concat_pw(x)
|
||||
return x
|
||||
|
||||
|
||||
class ESNet(TheseusLayer):
|
||||
def __init__(self, class_num=1000, scale=1.0, dropout_prob=0.2, class_expand=1280):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.class_num = class_num
|
||||
self.class_expand = class_expand
|
||||
stage_repeats = [3, 7, 3]
|
||||
stage_out_channels = [-1, 24, make_divisible(116*scale), make_divisible(232*scale), make_divisible(464*scale), 1024]
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=3,
|
||||
out_channels=stage_out_channels[1],
|
||||
kernel_size=3,
|
||||
stride=2)
|
||||
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
block_list = []
|
||||
for stage_id, num_repeat in enumerate(stage_repeats):
|
||||
for i in range(num_repeat):
|
||||
if i == 0:
|
||||
block = ESBlock2(
|
||||
in_channels=stage_out_channels[stage_id + 1],
|
||||
out_channels=stage_out_channels[stage_id + 2])
|
||||
else:
|
||||
block = ESBlock1(
|
||||
in_channels=stage_out_channels[stage_id + 2],
|
||||
out_channels=stage_out_channels[stage_id + 2])
|
||||
block_list.append(block)
|
||||
self.blocks = nn.Sequential(*block_list)
|
||||
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=stage_out_channels[-2],
|
||||
out_channels=stage_out_channels[-1],
|
||||
kernel_size=1)
|
||||
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
|
||||
self.last_conv = Conv2D(
|
||||
in_channels=stage_out_channels[-1],
|
||||
out_channels=self.class_expand,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias_attr=False)
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
|
||||
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
|
||||
self.fc = Linear(self.class_expand, self.class_num)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.max_pool(x)
|
||||
x = self.blocks(x)
|
||||
x = self.conv2(x)
|
||||
x = self.avg_pool(x)
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def ESNet_x0_25(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ESNet_x0_25
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ESNet_x0_25` model depends on args.
|
||||
"""
|
||||
model = ESNet(scale=0.25, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_25"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ESNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ESNet_x0_5
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ESNet_x0_5` model depends on args.
|
||||
"""
|
||||
model = ESNet(scale=0.5, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_5"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ESNet_x0_75(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ESNet_x0_75
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ESNet_x0_75` model depends on args.
|
||||
"""
|
||||
model = ESNet(scale=0.75, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_75"], use_ssld)
|
||||
return model
|
||||
|
||||
|
||||
def ESNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
ESNet_x1_0
|
||||
Args:
|
||||
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
|
||||
If str, means the path of the pretrained model.
|
||||
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `ESNet_x1_0` model depends on args.
|
||||
"""
|
||||
model = ESNet(scale=1.0, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x1_0"], use_ssld)
|
||||
return model
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
class_num: 1000
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 360
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
# model architecture
|
||||
Arch:
|
||||
name: ESNet_x0_25
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.8
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.00003
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 512
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: docs/images/whl/demo.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
|
@ -0,0 +1,129 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
class_num: 1000
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 360
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
# model architecture
|
||||
Arch:
|
||||
name: ESNet_x0_5
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.8
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.00003
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 512
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: docs/images/whl/demo.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
|
@ -0,0 +1,129 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
class_num: 1000
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 360
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
# model architecture
|
||||
Arch:
|
||||
name: ESNet_x0_75
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.8
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.00003
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 512
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: docs/images/whl/demo.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
|
@ -0,0 +1,129 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
class_num: 1000
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 360
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
# model architecture
|
||||
Arch:
|
||||
name: ESNet_x1_0
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.8
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.00003
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 512
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: docs/images/whl/demo.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
Loading…
Reference in New Issue