Add PPHGNet code
parent
50c1302bce
commit
8a760fb85f
|
@ -0,0 +1,24 @@
|
||||||
|
# PP-HGNet 系列
|
||||||
|
---
|
||||||
|
## 目录
|
||||||
|
|
||||||
|
* [1. 概述](#1)
|
||||||
|
* [2. 精度、FLOPs 和参数量](#2)
|
||||||
|
|
||||||
|
<a name='1'></a>
|
||||||
|
|
||||||
|
## 1. 概述
|
||||||
|
|
||||||
|
PP-HGNet 是百度自研的一个在 GPU 端上高性能的网络,该网络在 VOVNet 的基础上融合了 ResNet_vd、PPLCNet 的优点,使用了可学习的下采样层,组合成了一个在 GPU 设备上速度快、精度高的网络,超越其他 GPU 端 SOTA 模型。
|
||||||
|
|
||||||
|
<a name='2'></a>
|
||||||
|
|
||||||
|
## 2. 精度、FLOPs 和参数量
|
||||||
|
|
||||||
|
| Models | Top1 | Top5 | FLOPs<br>(G) | Params<br/>(M) |
|
||||||
|
|:--:|:--:|:--:|:--:|:--:|
|
||||||
|
| PPHGNet_tiny | 79.83 | 95.04 | 4.54 | 14.75 |
|
||||||
|
| PPHGNet_tiny_ssld | 81.95 | 96.12 | 4.54 | 14.75 |
|
||||||
|
| PPHGNet_small | 81.51 | 95.82 | 8.53 | 24.38 |
|
||||||
|
|
||||||
|
关于 Inference speed 等信息,敬请期待。
|
|
@ -23,6 +23,7 @@ from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3
|
||||||
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
|
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
|
||||||
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
|
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
|
||||||
from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
|
from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
|
||||||
|
from ppcls.arch.backbone.legendary_models.pp_hgnet import PPHGNet_tiny, PPHGNet_small, PPHGNet_base
|
||||||
|
|
||||||
from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc
|
from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc
|
||||||
from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d
|
from ppcls.arch.backbone.model_zoo.resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d
|
||||||
|
|
|
@ -0,0 +1,372 @@
|
||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle.nn.initializer import KaimingNormal, Constant
|
||||||
|
from paddle.nn import Conv2D, BatchNorm2D, ReLU, AdaptiveAvgPool2D, MaxPool2D
|
||||||
|
from paddle.regularizer import L2Decay
|
||||||
|
from paddle import ParamAttr
|
||||||
|
|
||||||
|
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||||
|
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||||
|
|
||||||
|
# Download locations of the pretrained weights, keyed by factory function name.
# The keys also define the module's public API through ``__all__`` below.
MODEL_URLS = {
    "PPHGNet_tiny":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams",
    "PPHGNet_small":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams",
    # PPHGNet_base looks up this key, so it must exist or the factory raises
    # KeyError. No download URL for the base variant is listed in this file,
    # hence the empty placeholder.
    "PPHGNet_base": "",
}

__all__ = list(MODEL_URLS.keys())
|
||||||
|
|
||||||
|
# Shared initializer instances; each is called directly on a parameter
# (e.g. ``zeros_(layer.bias)``) to (re)initialize it.
kaiming_normal_ = KaimingNormal()
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNAct(TheseusLayer):
    """
    Bias-free Conv2D followed by BatchNorm2D and an optional ReLU.
    Args:
        in_channels: int. Number of input channels of the convolution.
        out_channels: int. Number of output channels of the convolution.
        kernel_size: int. Kernel size; the padding is derived from it as
            ``(kernel_size - 1) // 2``.
        stride: int. Stride of the convolution.
        groups: int=1. Number of convolution groups.
        use_act: bool=True. Whether ReLU is applied after batch norm.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 use_act=True):
        super().__init__()
        self.use_act = use_act

        half_kernel = (kernel_size - 1) // 2
        self.conv = Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
            groups=groups,
            bias_attr=False)

        # Batch-norm scale and shift are excluded from weight decay
        # (L2 coefficient 0.0).
        bn_weight_attr = ParamAttr(regularizer=L2Decay(0.0))
        bn_bias_attr = ParamAttr(regularizer=L2Decay(0.0))
        self.bn = BatchNorm2D(
            out_channels, weight_attr=bn_weight_attr, bias_attr=bn_bias_attr)

        # The activation sublayer only exists when it is actually used.
        if self.use_act:
            self.act = ReLU()

    def forward(self, x):
        out = self.bn(self.conv(x))
        if not self.use_act:
            return out
        return self.act(out)
|
||||||
|
|
||||||
|
|
||||||
|
class ESEModule(TheseusLayer):
    """
    Channel gate: adaptive average pool to 1x1 -> 1x1 conv -> sigmoid, with
    the result multiplied onto the input.
    Args:
        channels: int. Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.avg_pool = AdaptiveAvgPool2D(1)
        self.conv = Conv2D(
            channels, channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Per-channel gate computed from the pooled input.
        gate = self.sigmoid(self.conv(self.avg_pool(x)))
        return paddle.multiply(x=x, y=gate)
|
||||||
|
|
||||||
|
|
||||||
|
class _HG_Block(TheseusLayer):
    """
    Stack of 3x3 ConvBNAct layers whose input and every intermediate output
    are concatenated, fused by a 1x1 ConvBNAct and gated by an ESEModule.
    Args:
        in_channels: int. Channels of the block input.
        mid_channels: int. Output channels of every 3x3 conv in the stack.
        out_channels: int. Channels produced by the 1x1 aggregation conv.
        layer_num: int. Number of 3x3 conv layers in the stack.
        identity: bool=False. Whether the block input is added to the output.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 layer_num,
                 identity=False):
        super().__init__()
        self.identity = identity

        # The first conv maps in_channels -> mid_channels; the remaining
        # (layer_num - 1) convs map mid_channels -> mid_channels.
        conv_inputs = [in_channels] + [mid_channels] * (layer_num - 1)
        self.layers = nn.LayerList()
        for conv_in in conv_inputs:
            self.layers.append(
                ConvBNAct(
                    in_channels=conv_in,
                    out_channels=mid_channels,
                    kernel_size=3,
                    stride=1))

        # feature aggregation: block input plus the output of every conv
        total_channels = in_channels + layer_num * mid_channels
        self.aggregation_conv = ConvBNAct(
            in_channels=total_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1)
        self.att = ESEModule(out_channels)

    def forward(self, x):
        shortcut = x
        features = [x]
        for conv in self.layers:
            x = conv(x)
            features.append(x)
        x = paddle.concat(features, axis=1)
        x = self.att(self.aggregation_conv(x))
        if self.identity:
            x += shortcut
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class _HG_Stage(TheseusLayer):
    """
    Optional stride-2 depthwise ConvBNAct followed by a sequence of _HG_Block.
    Args:
        in_channels: int. Channels of the stage input.
        mid_channels: int. Mid channels passed to every _HG_Block.
        out_channels: int. Output channels of every _HG_Block.
        block_num: int. Number of _HG_Block in the stage.
        layer_num: int. Number of conv layers inside each _HG_Block.
        downsample: bool=True. Whether the stage starts with the stride-2 conv.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 block_num,
                 layer_num,
                 downsample=True):
        super().__init__()
        if downsample:
            # Depthwise (groups == channels) 3x3 conv with stride 2 and no
            # activation; the attribute then holds the layer, not the flag.
            self.downsample = ConvBNAct(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=3,
                stride=2,
                groups=in_channels,
                use_act=False)
        else:
            self.downsample = downsample

        # Only the first block changes the channel count, so it cannot use
        # the identity shortcut; the following blocks can.
        first_block = _HG_Block(
            in_channels, mid_channels, out_channels, layer_num,
            identity=False)
        later_blocks = [
            _HG_Block(
                out_channels,
                mid_channels,
                out_channels,
                layer_num,
                identity=True) for _ in range(block_num - 1)
        ]
        self.blocks = nn.Sequential(first_block, *later_blocks)

    def forward(self, x):
        if self.downsample:
            x = self.downsample(x)
        return self.blocks(x)
|
||||||
|
|
||||||
|
|
||||||
|
class PPHGNet(TheseusLayer):
    """
    PPHGNet
    Args:
        stem_channels: list. Output channels of each stem conv layer. The
            3-channel image input is prepended internally; the passed
            sequence is not modified.
        stage_config: dict. The configuration of each stage of PPHGNet. Each
            value is [in_channels, mid_channels, out_channels, block_num,
            downsample]; stages are built in the dict's iteration order.
        layer_num: int. Number of layers of HG_Block.
        use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
        class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
        dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
        class_num: int=1000. The number of classes.
    Returns:
        model: nn.Layer. Specific PPHGNet model depends on args.
    Raises:
        ValueError: if `stage_config` is empty.
    """

    def __init__(self,
                 stem_channels,
                 stage_config,
                 layer_num,
                 use_last_conv=True,
                 class_expand=2048,
                 dropout_prob=0.0,
                 class_num=1000):
        super().__init__()
        if not stage_config:
            # The classifier width is taken from the last stage, so at least
            # one stage is required (previously this surfaced as a NameError).
            raise ValueError("stage_config must contain at least one stage.")
        self.use_last_conv = use_last_conv
        self.class_expand = class_expand

        # stem
        # Build a new list rather than insert() into the argument: mutating
        # the caller's list would prepend another 3 each time the same list
        # is reused to construct a model.
        stem_channels = [3] + list(stem_channels)
        self.stem = nn.Sequential(* [
            ConvBNAct(
                in_channels=stem_channels[i],
                out_channels=stem_channels[i + 1],
                kernel_size=3,
                # only the first stem conv downsamples
                stride=2 if i == 0 else 1)
            for i in range(len(stem_channels) - 1)
        ])
        self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        # stages
        self.stages = nn.LayerList()
        for k in stage_config:
            (in_channels, mid_channels, out_channels, block_num,
             downsample) = stage_config[k]
            self.stages.append(
                _HG_Stage(in_channels, mid_channels, out_channels, block_num,
                          layer_num, downsample))

        # From here on `out_channels` is the output width of the last stage.
        self.avg_pool = AdaptiveAvgPool2D(1)
        if self.use_last_conv:
            self.last_conv = Conv2D(
                in_channels=out_channels,
                out_channels=self.class_expand,
                kernel_size=1,
                stride=1,
                padding=0,
                bias_attr=False)
            self.act = nn.ReLU()
            self.dropout = nn.Dropout(
                p=dropout_prob, mode="downscale_in_infer")

        self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
        self.fc = nn.Linear(self.class_expand
                            if self.use_last_conv else out_channels, class_num)

        self._init_weights()

    def _init_weights(self):
        """Re-initialize conv weights, batch-norm scale/shift and linear bias."""
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2D)):
                ones_(m.weight)
                zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                zeros_(m.bias)

    def forward(self, x):
        x = self.stem(x)
        x = self.pool(x)

        for stage in self.stages:
            x = stage(x)

        x = self.avg_pool(x)
        if self.use_last_conv:
            x = self.last_conv(x)
            x = self.act(x)
            x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def _load_pretrained(pretrained, model, model_url, use_ssld):
|
||||||
|
if pretrained is False:
|
||||||
|
pass
|
||||||
|
elif pretrained is True:
|
||||||
|
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||||
|
elif isinstance(pretrained, str):
|
||||||
|
load_dygraph_pretrain(model, pretrained)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
    """
    Build the tiny PPHGNet variant (stem 48-48-96, 5 layers per HG block).
    Args:
        pretrained: bool=False or str. `True` loads the weights from
            MODEL_URLS, `False` loads nothing, a str is the path of the
            pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
        **kwargs: forwarded to the `PPHGNet` constructor.
    Returns:
        model: nn.Layer. Specific `PPHGNet_tiny` model depends on args.
    """
    # per stage: in_channels, mid_channels, out_channels, blocks, downsample
    tiny_stages = {
        "stage1": [96, 96, 224, 1, False],
        "stage2": [224, 128, 448, 1, True],
        "stage3": [448, 160, 512, 2, True],
        "stage4": [512, 192, 768, 1, True],
    }
    net = PPHGNet(
        stem_channels=[48, 48, 96],
        stage_config=tiny_stages,
        layer_num=5,
        **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["PPHGNet_tiny"], use_ssld)
    return net
|
||||||
|
|
||||||
|
|
||||||
|
def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
    """
    Build the small PPHGNet variant (stem 64-64-128, 6 layers per HG block).
    Args:
        pretrained: bool=False or str. `True` loads the weights from
            MODEL_URLS, `False` loads nothing, a str is the path of the
            pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
        **kwargs: forwarded to the `PPHGNet` constructor.
    Returns:
        model: nn.Layer. Specific `PPHGNet_small` model depends on args.
    """
    # per stage: in_channels, mid_channels, out_channels, blocks, downsample
    small_stages = {
        "stage1": [128, 128, 256, 1, False],
        "stage2": [256, 160, 512, 1, True],
        "stage3": [512, 192, 768, 2, True],
        "stage4": [768, 224, 1024, 1, True],
    }
    net = PPHGNet(
        stem_channels=[64, 64, 128],
        stage_config=small_stages,
        layer_num=6,
        **kwargs)
    _load_pretrained(pretrained, net, MODEL_URLS["PPHGNet_small"], use_ssld)
    return net
|
||||||
|
|
||||||
|
|
||||||
|
def PPHGNet_base(pretrained=False, use_ssld=False, **kwargs):
    """
    PPHGNet_base
    Args:
        pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
                    If str, means the path of the pretrained model.
        use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
        **kwargs: forwarded to the `PPHGNet` constructor.
    Returns:
        model: nn.Layer. Specific `PPHGNet_base` model depends on args.
    Raises:
        RuntimeError: if `pretrained` is `True` but MODEL_URLS holds no URL
            for `PPHGNet_base`, or if `pretrained` has an unsupported type.
    """
    stage_config = {
        # in_channels, mid_channels, out_channels, blocks, downsample
        "stage1": [160, 192, 320, 1, False],
        "stage2": [320, 224, 640, 2, True],
        "stage3": [640, 256, 960, 3, True],
        "stage4": [960, 288, 1280, 2, True],
    }

    model = PPHGNet(
        stem_channels=[96, 96, 160],
        stage_config=stage_config,
        layer_num=7,
        dropout_prob=0.2,
        **kwargs)

    # Indexing MODEL_URLS["PPHGNet_base"] directly raised KeyError on every
    # call when the key is absent, even with pretrained=False. Look it up
    # defensively and only fail when a download was actually requested.
    model_url = MODEL_URLS.get("PPHGNet_base", "")
    if pretrained is True and not model_url:
        raise RuntimeError(
            "No pretrained weights URL is available for PPHGNet_base. "
            "Please pass the path of a local pretrained model as `pretrained`."
        )
    _load_pretrained(pretrained, model, model_url, use_ssld)
    return model
|
|
@ -0,0 +1,164 @@
|
||||||
|
# global configs
|
||||||
|
Global:
|
||||||
|
checkpoints: null
|
||||||
|
pretrained_model: null
|
||||||
|
output_dir: ./output/
|
||||||
|
device: gpu
|
||||||
|
save_interval: 1
|
||||||
|
eval_during_train: True
|
||||||
|
eval_interval: 1
|
||||||
|
epochs: 600
|
||||||
|
print_batch_step: 10
|
||||||
|
use_visualdl: False
|
||||||
|
# used for static mode and model export
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
save_inference_dir: ./inference
|
||||||
|
# training model under @to_static
|
||||||
|
to_static: False
|
||||||
|
use_dali: False
|
||||||
|
|
||||||
|
# mixed precision training
|
||||||
|
AMP:
|
||||||
|
scale_loss: 128.0
|
||||||
|
use_dynamic_loss_scaling: True
|
||||||
|
# O1: mixed fp16
|
||||||
|
level: O1
|
||||||
|
|
||||||
|
# model architecture
|
||||||
|
Arch:
|
||||||
|
name: PPHGNet_small
|
||||||
|
class_num: 1000
|
||||||
|
|
||||||
|
# loss function config for training/eval process
|
||||||
|
Loss:
|
||||||
|
Train:
|
||||||
|
- CELoss:
|
||||||
|
weight: 1.0
|
||||||
|
epsilon: 0.1
|
||||||
|
Eval:
|
||||||
|
- CELoss:
|
||||||
|
weight: 1.0
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Momentum
|
||||||
|
momentum: 0.9
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.5
|
||||||
|
warmup_epoch: 5
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
coeff: 0.00004
|
||||||
|
|
||||||
|
|
||||||
|
# data loader for train and eval
|
||||||
|
DataLoader:
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: ImageNetDataset
|
||||||
|
image_root: ./dataset/ILSVRC2012/
|
||||||
|
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
interpolation: bicubic
|
||||||
|
backend: pil
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- TimmAutoAugment:
|
||||||
|
config_str: rand-m7-mstd0.5-inc1
|
||||||
|
interpolation: bicubic
|
||||||
|
img_size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- RandomErasing:
|
||||||
|
EPSILON: 0.25
|
||||||
|
sl: 0.02
|
||||||
|
sh: 1.0/3.0
|
||||||
|
r1: 0.3
|
||||||
|
attempt: 10
|
||||||
|
use_log_aspect: True
|
||||||
|
mode: pixel
|
||||||
|
batch_transform_ops:
|
||||||
|
- OpSampler:
|
||||||
|
MixupOperator:
|
||||||
|
alpha: 0.2
|
||||||
|
prob: 0.5
|
||||||
|
CutmixOperator:
|
||||||
|
alpha: 1.0
|
||||||
|
prob: 0.5
|
||||||
|
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 128
|
||||||
|
drop_last: False
|
||||||
|
shuffle: True
|
||||||
|
loader:
|
||||||
|
num_workers: 16
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: ImageNetDataset
|
||||||
|
image_root: ./dataset/ILSVRC2012/
|
||||||
|
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 236
|
||||||
|
interpolation: bicubic
|
||||||
|
backend: pil
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 128
|
||||||
|
drop_last: False
|
||||||
|
shuffle: False
|
||||||
|
loader:
|
||||||
|
num_workers: 16
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Infer:
|
||||||
|
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
|
||||||
|
batch_size: 10
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 236
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
PostProcess:
|
||||||
|
name: Topk
|
||||||
|
topk: 5
|
||||||
|
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
Train:
|
||||||
|
- TopkAcc:
|
||||||
|
topk: [1, 5]
|
||||||
|
Eval:
|
||||||
|
- TopkAcc:
|
||||||
|
topk: [1, 5]
|
|
@ -0,0 +1,164 @@
|
||||||
|
# global configs
|
||||||
|
Global:
|
||||||
|
checkpoints: null
|
||||||
|
pretrained_model: null
|
||||||
|
output_dir: ./output/
|
||||||
|
device: gpu
|
||||||
|
save_interval: 1
|
||||||
|
eval_during_train: True
|
||||||
|
eval_interval: 1
|
||||||
|
epochs: 600
|
||||||
|
print_batch_step: 10
|
||||||
|
use_visualdl: False
|
||||||
|
# used for static mode and model export
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
save_inference_dir: ./inference
|
||||||
|
# training model under @to_static
|
||||||
|
to_static: False
|
||||||
|
use_dali: False
|
||||||
|
|
||||||
|
# mixed precision training
|
||||||
|
AMP:
|
||||||
|
scale_loss: 128.0
|
||||||
|
use_dynamic_loss_scaling: True
|
||||||
|
# O1: mixed fp16
|
||||||
|
level: O1
|
||||||
|
|
||||||
|
# model architecture
|
||||||
|
Arch:
|
||||||
|
name: PPHGNet_tiny
|
||||||
|
class_num: 1000
|
||||||
|
|
||||||
|
# loss function config for training/eval process
|
||||||
|
Loss:
|
||||||
|
Train:
|
||||||
|
- CELoss:
|
||||||
|
weight: 1.0
|
||||||
|
epsilon: 0.1
|
||||||
|
Eval:
|
||||||
|
- CELoss:
|
||||||
|
weight: 1.0
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Momentum
|
||||||
|
momentum: 0.9
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.5
|
||||||
|
warmup_epoch: 5
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
coeff: 0.00004
|
||||||
|
|
||||||
|
|
||||||
|
# data loader for train and eval
|
||||||
|
DataLoader:
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: ImageNetDataset
|
||||||
|
image_root: ./dataset/ILSVRC2012/
|
||||||
|
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
interpolation: bicubic
|
||||||
|
backend: pil
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- TimmAutoAugment:
|
||||||
|
config_str: rand-m7-mstd0.5-inc1
|
||||||
|
interpolation: bicubic
|
||||||
|
img_size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- RandomErasing:
|
||||||
|
EPSILON: 0.25
|
||||||
|
sl: 0.02
|
||||||
|
sh: 1.0/3.0
|
||||||
|
r1: 0.3
|
||||||
|
attempt: 10
|
||||||
|
use_log_aspect: True
|
||||||
|
mode: pixel
|
||||||
|
batch_transform_ops:
|
||||||
|
- OpSampler:
|
||||||
|
MixupOperator:
|
||||||
|
alpha: 0.2
|
||||||
|
prob: 0.5
|
||||||
|
CutmixOperator:
|
||||||
|
alpha: 1.0
|
||||||
|
prob: 0.5
|
||||||
|
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 128
|
||||||
|
drop_last: False
|
||||||
|
shuffle: True
|
||||||
|
loader:
|
||||||
|
num_workers: 16
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: ImageNetDataset
|
||||||
|
image_root: ./dataset/ILSVRC2012/
|
||||||
|
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||||
|
transform_ops:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 232
|
||||||
|
interpolation: bicubic
|
||||||
|
backend: pil
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
sampler:
|
||||||
|
name: DistributedBatchSampler
|
||||||
|
batch_size: 128
|
||||||
|
drop_last: False
|
||||||
|
shuffle: False
|
||||||
|
loader:
|
||||||
|
num_workers: 16
|
||||||
|
use_shared_memory: True
|
||||||
|
|
||||||
|
Infer:
|
||||||
|
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
|
||||||
|
batch_size: 10
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 232
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
PostProcess:
|
||||||
|
name: Topk
|
||||||
|
topk: 5
|
||||||
|
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
Train:
|
||||||
|
- TopkAcc:
|
||||||
|
topk: [1, 5]
|
||||||
|
Eval:
|
||||||
|
- TopkAcc:
|
||||||
|
topk: [1, 5]
|
|
@ -0,0 +1,53 @@
|
||||||
|
===========================train_params===========================
|
||||||
|
model_name:PPHGNet_small
|
||||||
|
python:python3.7
|
||||||
|
gpu_list:0|0,1
|
||||||
|
-o Global.device:gpu
|
||||||
|
-o Global.auto_cast:null
|
||||||
|
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
|
||||||
|
-o Global.output_dir:./output/
|
||||||
|
-o DataLoader.Train.sampler.batch_size:8
|
||||||
|
-o Global.pretrained_model:null
|
||||||
|
train_model_name:latest
|
||||||
|
train_infer_img_dir:./dataset/ILSVRC2012/val
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
trainer:norm_train
|
||||||
|
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
|
||||||
|
pact_train:null
|
||||||
|
fpgm_train:null
|
||||||
|
distill_train:null
|
||||||
|
null:null
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================eval_params===========================
|
||||||
|
eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================infer_params==========================
|
||||||
|
-o Global.save_inference_dir:./inference
|
||||||
|
-o Global.pretrained_model:
|
||||||
|
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
|
||||||
|
quant_export:null
|
||||||
|
fpgm_export:null
|
||||||
|
distill_export:null
|
||||||
|
kl_quant:null
|
||||||
|
export2:null
|
||||||
|
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams
|
||||||
|
infer_model:../inference/
|
||||||
|
infer_export:True
|
||||||
|
infer_quant:False
|
||||||
|
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=236
|
||||||
|
-o Global.use_gpu:True|False
|
||||||
|
-o Global.enable_mkldnn:True|False
|
||||||
|
-o Global.cpu_num_threads:1|6
|
||||||
|
-o Global.batch_size:1|16
|
||||||
|
-o Global.use_tensorrt:True|False
|
||||||
|
-o Global.use_fp16:True|False
|
||||||
|
-o Global.inference_model_dir:../inference
|
||||||
|
-o Global.infer_imgs:../dataset/ILSVRC2012/val
|
||||||
|
-o Global.save_log_path:null
|
||||||
|
-o Global.benchmark:True
|
||||||
|
null:null
|
||||||
|
===========================infer_benchmark_params==========================
|
||||||
|
random_infer_input:[{float32,[3,224,224]}]
|
|
@ -0,0 +1,53 @@
|
||||||
|
===========================train_params===========================
|
||||||
|
model_name:PPHGNet_tiny
|
||||||
|
python:python3.7
|
||||||
|
gpu_list:0|0,1
|
||||||
|
-o Global.device:gpu
|
||||||
|
-o Global.auto_cast:null
|
||||||
|
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
|
||||||
|
-o Global.output_dir:./output/
|
||||||
|
-o DataLoader.Train.sampler.batch_size:8
|
||||||
|
-o Global.pretrained_model:null
|
||||||
|
train_model_name:latest
|
||||||
|
train_infer_img_dir:./dataset/ILSVRC2012/val
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
trainer:norm_train
|
||||||
|
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
|
||||||
|
pact_train:null
|
||||||
|
fpgm_train:null
|
||||||
|
distill_train:null
|
||||||
|
null:null
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================eval_params===========================
|
||||||
|
eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================infer_params==========================
|
||||||
|
-o Global.save_inference_dir:./inference
|
||||||
|
-o Global.pretrained_model:
|
||||||
|
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
|
||||||
|
quant_export:null
|
||||||
|
fpgm_export:null
|
||||||
|
distill_export:null
|
||||||
|
kl_quant:null
|
||||||
|
export2:null
|
||||||
|
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams
|
||||||
|
infer_model:../inference/
|
||||||
|
infer_export:True
|
||||||
|
infer_quant:False
|
||||||
|
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=232
|
||||||
|
-o Global.use_gpu:True|False
|
||||||
|
-o Global.enable_mkldnn:True|False
|
||||||
|
-o Global.cpu_num_threads:1|6
|
||||||
|
-o Global.batch_size:1|16
|
||||||
|
-o Global.use_tensorrt:True|False
|
||||||
|
-o Global.use_fp16:True|False
|
||||||
|
-o Global.inference_model_dir:../inference
|
||||||
|
-o Global.infer_imgs:../dataset/ILSVRC2012/val
|
||||||
|
-o Global.save_log_path:null
|
||||||
|
-o Global.benchmark:True
|
||||||
|
null:null
|
||||||
|
===========================infer_benchmark_params==========================
|
||||||
|
random_infer_input:[{float32,[3,224,224]}]
|
Loading…
Reference in New Issue