commit
eef1e879c8
|
@ -18,3 +18,32 @@
|
|||
>>
|
||||
* Q: 评估和预测时,已经指定了预训练模型所在文件夹的地址,但是仍然无法导入参数,这么为什么呢?
|
||||
* A: 加载预训练模型时,需要指定预训练模型的前缀,例如预训练模型参数所在的文件夹为`output/ResNet50_vd/19`,预训练模型参数的名称为`output/ResNet50_vd/19/ppcls.pdparams`,则`pretrained_model`参数需要指定为`output/ResNet50_vd/19/ppcls`,PaddleClas会自动补齐`.pdparams`的后缀。
|
||||
|
||||
|
||||
>>
|
||||
* Q: 在评测`EfficientNetB0_small`模型时,为什么最终的精度始终比官网的低0.3%左右?
|
||||
* A: `EfficientNet`系列的网络在进行resize的时候,是使用`cubic插值方式`(resize参数的interpolation值设置为2),而其他模型默认情况下为None,因此在训练和评估的时候需要显式地指定resiz的interpolation值。具体地,可以参考以下配置中预处理过程中ResizeImage的参数。
|
||||
```
|
||||
VALID:
|
||||
batch_size: 16
|
||||
num_workers: 4
|
||||
file_list: "./dataset/ILSVRC2012/val_list.txt"
|
||||
data_dir: "./dataset/ILSVRC2012/"
|
||||
shuffle_seed: 0
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
to_np: False
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
interpolation: 2
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
```
|
||||
|
|
|
@ -22,7 +22,6 @@ from __future__ import unicode_literals
|
|||
import six
|
||||
import math
|
||||
import random
|
||||
import functools
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
@ -38,8 +37,8 @@ class DecodeImage(object):
|
|||
|
||||
def __init__(self, to_rgb=True, to_np=False, channel_first=False):
|
||||
self.to_rgb = to_rgb
|
||||
self.to_np = to_np #to numpy
|
||||
self.channel_first = channel_first #only enabled when to_np is True
|
||||
self.to_np = to_np # to numpy
|
||||
self.channel_first = channel_first # only enabled when to_np is True
|
||||
|
||||
def __call__(self, img):
|
||||
if six.PY2:
|
||||
|
@ -64,7 +63,8 @@ class DecodeImage(object):
|
|||
class ResizeImage(object):
|
||||
""" resize image """
|
||||
|
||||
def __init__(self, size=None, resize_short=None):
|
||||
def __init__(self, size=None, resize_short=None, interpolation=-1):
|
||||
self.interpolation = interpolation if interpolation >= 0 else None
|
||||
if resize_short is not None and resize_short > 0:
|
||||
self.resize_short = resize_short
|
||||
self.w = None
|
||||
|
@ -86,8 +86,10 @@ class ResizeImage(object):
|
|||
else:
|
||||
w = self.w
|
||||
h = self.h
|
||||
|
||||
return cv2.resize(img, (w, h))
|
||||
if self.interpolation is None:
|
||||
return cv2.resize(img, (w, h))
|
||||
else:
|
||||
return cv2.resize(img, (w, h), interpolation=self.interpolation)
|
||||
|
||||
|
||||
class CropImage(object):
|
||||
|
@ -138,8 +140,7 @@ class RandCropImage(object):
|
|||
scale_max = min(scale[1], bound)
|
||||
scale_min = min(scale[0], bound)
|
||||
|
||||
target_area = img_w * img_h * random.uniform(\
|
||||
scale_min, scale_max)
|
||||
target_area = img_w * img_h * random.uniform(scale_min, scale_max)
|
||||
target_size = math.sqrt(target_area)
|
||||
w = int(target_size * w)
|
||||
h = int(target_size * h)
|
||||
|
@ -176,7 +177,8 @@ class NormalizeImage(object):
|
|||
"""
|
||||
|
||||
def __init__(self, scale=None, mean=None, std=None, order='chw'):
|
||||
if isinstance(scale, str): scale = eval(scale)
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
|
|
@ -36,7 +36,7 @@ from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseN
|
|||
from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
|
||||
from .darknet import DarkNet53
|
||||
from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl
|
||||
from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
|
||||
from .efficientnet import EfficientNet, EfficientNetB0, EfficientNetB0_small, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
|
||||
from .res2net import Res2Net50_48w_2s, Res2Net50_26w_4s, Res2Net50_14w_8s, Res2Net50_26w_6s, Res2Net50_26w_8s, Res2Net101_26w_4s, Res2Net152_26w_4s
|
||||
from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_14w_8s, Res2Net50_vd_26w_6s, Res2Net50_vd_26w_8s, Res2Net101_vd_26w_4s, Res2Net152_vd_26w_4s, Res2Net200_vd_26w_4s
|
||||
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -192,15 +192,17 @@ class EfficientNet():
|
|||
if is_test:
|
||||
return inputs
|
||||
keep_prob = 1.0 - prob
|
||||
random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like(
|
||||
inputs, [-1, 1, 1, 1], min=0., max=1.)
|
||||
random_tensor = keep_prob + \
|
||||
fluid.layers.uniform_random_batch_size_like(
|
||||
inputs, [-1, 1, 1, 1], min=0., max=1.)
|
||||
binary_tensor = fluid.layers.floor(random_tensor)
|
||||
output = inputs / keep_prob * binary_tensor
|
||||
return output
|
||||
|
||||
def _expand_conv_norm(self, inputs, block_args, is_test, name=None):
|
||||
# Expansion phase
|
||||
oup = block_args.input_filters * block_args.expand_ratio # number of output channels
|
||||
oup = block_args.input_filters * \
|
||||
block_args.expand_ratio # number of output channels
|
||||
|
||||
if block_args.expand_ratio != 1:
|
||||
conv = self.conv_bn_layer(
|
||||
|
@ -222,7 +224,8 @@ class EfficientNet():
|
|||
s = block_args.stride
|
||||
if isinstance(s, list) or isinstance(s, tuple):
|
||||
s = s[0]
|
||||
oup = block_args.input_filters * block_args.expand_ratio # number of output channels
|
||||
oup = block_args.input_filters * \
|
||||
block_args.expand_ratio # number of output channels
|
||||
|
||||
conv = self.conv_bn_layer(
|
||||
inputs,
|
||||
|
@ -285,7 +288,7 @@ class EfficientNet():
|
|||
name=conv_name,
|
||||
use_bias=use_bias)
|
||||
|
||||
if use_bn == False:
|
||||
if use_bn is False:
|
||||
return conv
|
||||
else:
|
||||
bn_name = name + bn_name
|
||||
|
@ -325,7 +328,8 @@ class EfficientNet():
|
|||
drop_connect_rate=None,
|
||||
name=None):
|
||||
# Expansion and Depthwise Convolution
|
||||
oup = block_args.input_filters * block_args.expand_ratio # number of output channels
|
||||
oup = block_args.input_filters * \
|
||||
block_args.expand_ratio # number of output channels
|
||||
has_se = self.use_se and (block_args.se_ratio is not None) and (
|
||||
0 < block_args.se_ratio <= 1)
|
||||
id_skip = block_args.id_skip # skip connection and drop connect
|
||||
|
@ -346,8 +350,11 @@ class EfficientNet():
|
|||
conv = self._project_conv_norm(conv, block_args, is_test, name)
|
||||
|
||||
# Skip connection and drop connect
|
||||
input_filters, output_filters = block_args.input_filters, block_args.output_filters
|
||||
if id_skip and block_args.stride == 1 and input_filters == output_filters:
|
||||
input_filters = block_args.input_filters
|
||||
output_filters = block_args.output_filters
|
||||
if id_skip and \
|
||||
block_args.stride == 1 and \
|
||||
input_filters == output_filters:
|
||||
if drop_connect_rate:
|
||||
conv = self._drop_connect(conv, drop_connect_rate,
|
||||
self.is_test)
|
||||
|
@ -412,7 +419,8 @@ class EfficientNet():
|
|||
num_repeat=round_repeats(block_args.num_repeat,
|
||||
self._global_params))
|
||||
|
||||
# The first block needs to take care of stride and filter size increase.
|
||||
# The first block needs to take care of stride,
|
||||
# and filter size increase.
|
||||
drop_connect_rate = self._global_params.drop_connect_rate
|
||||
if drop_connect_rate:
|
||||
drop_connect_rate *= float(idx) / block_size
|
||||
|
@ -440,7 +448,9 @@ class EfficientNet():
|
|||
|
||||
|
||||
class BlockDecoder(object):
|
||||
""" Block Decoder for readability, straight from the official TensorFlow repository """
|
||||
"""
|
||||
Block Decoder, straight from the official TensorFlow repository.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _decode_block_string(block_string):
|
||||
|
@ -456,9 +466,10 @@ class BlockDecoder(object):
|
|||
options[key] = value
|
||||
|
||||
# Check stride
|
||||
assert (
|
||||
('s' in options and len(options['s']) == 1) or
|
||||
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
|
||||
cond_1 = ('s' in options and len(options['s']) == 1)
|
||||
cond_2 = ((len(options['s']) == 2)
|
||||
and (options['s'][0] == options['s'][1]))
|
||||
assert (cond_1 or cond_2)
|
||||
|
||||
return BlockArgs(
|
||||
kernel_size=int(options['k']),
|
||||
|
@ -487,10 +498,11 @@ class BlockDecoder(object):
|
|||
@staticmethod
|
||||
def decode(string_list):
|
||||
"""
|
||||
Decodes a list of string notations to specify blocks inside the network.
|
||||
Decode a list of string notations to specify blocks in the network.
|
||||
|
||||
:param string_list: a list of strings, each string is a notation of block
|
||||
:return: a list of BlockArgs namedtuples of block args
|
||||
string_list: list of strings, each string is a notation of block
|
||||
return
|
||||
list of BlockArgs namedtuples of block args
|
||||
"""
|
||||
assert isinstance(string_list, list)
|
||||
blocks_args = []
|
||||
|
@ -525,6 +537,19 @@ def EfficientNetB0(is_test=False,
|
|||
return model
|
||||
|
||||
|
||||
def EfficientNetB0_small(is_test=False,
|
||||
padding_type='DYNAMIC',
|
||||
override_params=None,
|
||||
use_se=False):
|
||||
model = EfficientNet(
|
||||
name='b0',
|
||||
is_test=is_test,
|
||||
padding_type=padding_type,
|
||||
override_params=override_params,
|
||||
use_se=use_se)
|
||||
return model
|
||||
|
||||
|
||||
def EfficientNetB1(is_test=False,
|
||||
padding_type='SAME',
|
||||
override_params=None,
|
||||
|
|
Loading…
Reference in New Issue