feat(model): add EfficientNetV2 code and fix AttrDict BUG
parent
7a0c79655d
commit
3a8b56801d
|
@ -33,7 +33,7 @@ class AttrDict(dict):
|
|||
self[key] = value
|
||||
|
||||
def __deepcopy__(self, content):
|
||||
return copy.deepcopy(dict(self))
|
||||
return AttrDict(copy.deepcopy(dict(self)))
|
||||
|
||||
|
||||
def create_attr_dict(yaml_config):
|
||||
|
|
|
@ -38,6 +38,7 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131
|
|||
from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base
|
||||
from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
|
||||
from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small
|
||||
from .model_zoo.efficientnet_v2 import EfficientNetV2_S
|
||||
from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269
|
||||
from .model_zoo.googlenet import GoogLeNet
|
||||
from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
|
||||
|
|
|
@ -0,0 +1,991 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Code was based on https://github.com/lukemelas/EfficientNet-PyTorch
|
||||
# reference: https://arxiv.org/abs/1905.11946
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
from paddle.nn.initializer import Constant, Normal, Uniform
|
||||
from paddle.regularizer import L2Decay
|
||||
|
||||
from ppcls.utils.config import AttrDict
|
||||
|
||||
from ....utils.save_load import (load_dygraph_pretrain,
|
||||
load_dygraph_pretrain_from_url)
|
||||
|
||||
MODEL_URLS = {
|
||||
"EfficientNetV2_S":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_S_pretrained.pdparams",
|
||||
"EfficientNetV2_M":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_M_pretrained.pdparams",
|
||||
"EfficientNetV2_L":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_L_pretrained.pdparams",
|
||||
"EfficientNetV2_XL":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_XL_pretrained.pdparams",
|
||||
}
|
||||
|
||||
__all__ = list(MODEL_URLS.keys())
|
||||
|
||||
inp_shape = {
|
||||
"efficientnetv2-s": [384, 192, 192, 96, 48, 24, 24, 12],
|
||||
"efficientnetv2-m": [384, 192, 192, 96, 48, 24, 24, 12],
|
||||
"efficientnetv2-l": [384, 192, 192, 96, 48, 24, 24, 12],
|
||||
"efficientnetv2-xl": [384, 192, 192, 96, 48, 24, 24, 12],
|
||||
}
|
||||
|
||||
|
||||
def cal_padding(img_size, stride, kernel_size):
|
||||
"""Calculate padding size."""
|
||||
if img_size % stride == 0:
|
||||
out_size = max(kernel_size - stride, 0)
|
||||
else:
|
||||
out_size = max(kernel_size - (img_size % stride), 0)
|
||||
return out_size // 2, out_size - out_size // 2
|
||||
|
||||
|
||||
class Conv2ds(nn.Layer):
|
||||
"""Customed Conv2D with tensorflow's padding style
|
||||
|
||||
Args:
|
||||
input_channels (int): input channels
|
||||
output_channels (int): output channels
|
||||
kernel_size (int): filter size
|
||||
stride (int, optional): stride. Defaults to 1.
|
||||
padding (int, optional): padding. Defaults to 0.
|
||||
groups (int, optional): groups. Defaults to None.
|
||||
act (str, optional): act. Defaults to None.
|
||||
use_bias (bool, optional): use_bias. Defaults to None.
|
||||
padding_type (str, optional): padding_type. Defaults to None.
|
||||
model_name (str, optional): model name. Defaults to None.
|
||||
cur_stage (int, optional): current stage. Defaults to None.
|
||||
|
||||
Returns:
|
||||
nn.Layer: Customed Conv2D instance
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_channels: int,
|
||||
output_channels: int,
|
||||
kernel_size: int,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=None,
|
||||
act=None,
|
||||
use_bias=None,
|
||||
padding_type=None,
|
||||
model_name=None,
|
||||
cur_stage=None):
|
||||
super(Conv2ds, self).__init__()
|
||||
assert act in [None, "swish", "sigmoid"]
|
||||
self._act = act
|
||||
|
||||
def get_padding(kernel_size, stride=1, dilation=1):
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
return padding
|
||||
|
||||
inps = inp_shape[model_name][cur_stage]
|
||||
self.need_crop = False
|
||||
if padding_type == "SAME":
|
||||
top_padding, bottom_padding = cal_padding(inps, stride,
|
||||
kernel_size)
|
||||
left_padding, right_padding = cal_padding(inps, stride,
|
||||
kernel_size)
|
||||
height_padding = bottom_padding
|
||||
width_padding = right_padding
|
||||
if top_padding != bottom_padding or left_padding != right_padding:
|
||||
height_padding = top_padding + stride
|
||||
width_padding = left_padding + stride
|
||||
self.need_crop = True
|
||||
padding = [height_padding, width_padding]
|
||||
elif padding_type == "VALID":
|
||||
height_padding = 0
|
||||
width_padding = 0
|
||||
padding = [height_padding, width_padding]
|
||||
elif padding_type == "DYNAMIC":
|
||||
padding = get_padding(kernel_size, stride)
|
||||
else:
|
||||
padding = padding_type
|
||||
|
||||
groups = 1 if groups is None else groups
|
||||
self._conv = nn.Conv2D(
|
||||
input_channels,
|
||||
output_channels,
|
||||
kernel_size,
|
||||
groups=groups,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
weight_attr=None,
|
||||
bias_attr=use_bias
|
||||
if not use_bias else ParamAttr(regularizer=L2Decay(0.0)))
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self._conv(inputs)
|
||||
if self._act == "swish":
|
||||
x = F.swish(x)
|
||||
elif self._act == "sigmoid":
|
||||
x = F.sigmoid(x)
|
||||
|
||||
if self.need_crop:
|
||||
x = x[:, :, 1:, 1:]
|
||||
return x
|
||||
|
||||
|
||||
class BlockDecoder(object):
|
||||
"""Block Decoder for readability."""
|
||||
|
||||
def _decode_block_string(self, block_string):
|
||||
"""Gets a block through a string notation of arguments."""
|
||||
assert isinstance(block_string, str)
|
||||
ops = block_string.split('_')
|
||||
options = AttrDict()
|
||||
for op in ops:
|
||||
splits = re.split(r'(\d.*)', op)
|
||||
if len(splits) >= 2:
|
||||
key, value = splits[:2]
|
||||
options[key] = value
|
||||
|
||||
t = AttrDict(
|
||||
kernel_size=int(options['k']),
|
||||
num_repeat=int(options['r']),
|
||||
in_channels=int(options['i']),
|
||||
out_channels=int(options['o']),
|
||||
expand_ratio=int(options['e']),
|
||||
se_ratio=float(options['se']) if 'se' in options else None,
|
||||
strides=int(options['s']),
|
||||
conv_type=int(options['c']) if 'c' in options else 0, )
|
||||
return t
|
||||
|
||||
def _encode_block_string(self, block):
|
||||
"""Encodes a block to a string."""
|
||||
args = [
|
||||
'r%d' % block.num_repeat,
|
||||
'k%d' % block.kernel_size,
|
||||
's%d' % block.strides,
|
||||
'e%s' % block.expand_ratio,
|
||||
'i%d' % block.in_channels,
|
||||
'o%d' % block.out_channels,
|
||||
'c%d' % block.conv_type,
|
||||
'f%d' % block.fused_conv,
|
||||
]
|
||||
if block.se_ratio > 0 and block.se_ratio <= 1:
|
||||
args.append('se%s' % block.se_ratio)
|
||||
return '_'.join(args)
|
||||
|
||||
def decode(self, string_list):
|
||||
"""Decodes a list of string notations to specify blocks inside the network.
|
||||
|
||||
Args:
|
||||
string_list: a list of strings, each string is a notation of block.
|
||||
|
||||
Returns:
|
||||
A list of namedtuples to represent blocks arguments.
|
||||
"""
|
||||
assert isinstance(string_list, list)
|
||||
blocks_args = []
|
||||
for block_string in string_list:
|
||||
blocks_args.append(self._decode_block_string(block_string))
|
||||
return blocks_args
|
||||
|
||||
def encode(self, blocks_args):
|
||||
"""Encodes a list of Blocks to a list of strings.
|
||||
|
||||
Args:
|
||||
blocks_args: A list of namedtuples to represent blocks arguments.
|
||||
Returns:
|
||||
a list of strings, each string is a notation of block.
|
||||
"""
|
||||
block_strings = []
|
||||
for block in blocks_args:
|
||||
block_strings.append(self._encode_block_string(block))
|
||||
return block_strings
|
||||
|
||||
|
||||
#################### EfficientNet V2 configs ####################
|
||||
v2_base_block = [ # The baseline config for v2 models.
|
||||
"r1_k3_s1_e1_i32_o16_c1",
|
||||
"r2_k3_s2_e4_i16_o32_c1",
|
||||
"r2_k3_s2_e4_i32_o48_c1",
|
||||
"r3_k3_s2_e4_i48_o96_se0.25",
|
||||
"r5_k3_s1_e6_i96_o112_se0.25",
|
||||
"r8_k3_s2_e6_i112_o192_se0.25",
|
||||
]
|
||||
|
||||
v2_s_block = [ # about base * (width1.4, depth1.8)
|
||||
"r2_k3_s1_e1_i24_o24_c1",
|
||||
"r4_k3_s2_e4_i24_o48_c1",
|
||||
"r4_k3_s2_e4_i48_o64_c1",
|
||||
"r6_k3_s2_e4_i64_o128_se0.25",
|
||||
"r9_k3_s1_e6_i128_o160_se0.25",
|
||||
"r15_k3_s2_e6_i160_o256_se0.25",
|
||||
]
|
||||
|
||||
v2_m_block = [ # about base * (width1.6, depth2.2)
|
||||
"r3_k3_s1_e1_i24_o24_c1",
|
||||
"r5_k3_s2_e4_i24_o48_c1",
|
||||
"r5_k3_s2_e4_i48_o80_c1",
|
||||
"r7_k3_s2_e4_i80_o160_se0.25",
|
||||
"r14_k3_s1_e6_i160_o176_se0.25",
|
||||
"r18_k3_s2_e6_i176_o304_se0.25",
|
||||
"r5_k3_s1_e6_i304_o512_se0.25",
|
||||
]
|
||||
|
||||
v2_l_block = [ # about base * (width2.0, depth3.1)
|
||||
"r4_k3_s1_e1_i32_o32_c1",
|
||||
"r7_k3_s2_e4_i32_o64_c1",
|
||||
"r7_k3_s2_e4_i64_o96_c1",
|
||||
"r10_k3_s2_e4_i96_o192_se0.25",
|
||||
"r19_k3_s1_e6_i192_o224_se0.25",
|
||||
"r25_k3_s2_e6_i224_o384_se0.25",
|
||||
"r7_k3_s1_e6_i384_o640_se0.25",
|
||||
]
|
||||
|
||||
v2_xl_block = [ # only for 21k pretraining.
|
||||
"r4_k3_s1_e1_i32_o32_c1",
|
||||
"r8_k3_s2_e4_i32_o64_c1",
|
||||
"r8_k3_s2_e4_i64_o96_c1",
|
||||
"r16_k3_s2_e4_i96_o192_se0.25",
|
||||
"r24_k3_s1_e6_i192_o256_se0.25",
|
||||
"r32_k3_s2_e6_i256_o512_se0.25",
|
||||
"r8_k3_s1_e6_i512_o640_se0.25",
|
||||
]
|
||||
efficientnetv2_params = {
|
||||
# params: (block, width, depth, dropout)
|
||||
"efficientnetv2-s": (v2_s_block, 1.0, 1.0, 0.2),
|
||||
"efficientnetv2-m": (v2_m_block, 1.0, 1.0, 0.3),
|
||||
"efficientnetv2-l": (v2_l_block, 1.0, 1.0, 0.4),
|
||||
"efficientnetv2-xl": (v2_xl_block, 1.0, 1.0, 0.4),
|
||||
}
|
||||
|
||||
|
||||
def efficientnetv2_config(model_name: str):
|
||||
"""EfficientNetV2 model config."""
|
||||
block, width, depth, dropout = efficientnetv2_params[model_name]
|
||||
|
||||
cfg = AttrDict(model=AttrDict(
|
||||
model_name=model_name,
|
||||
blocks_args=BlockDecoder().decode(block),
|
||||
width_coefficient=width,
|
||||
depth_coefficient=depth,
|
||||
dropout_rate=dropout,
|
||||
feature_size=1280,
|
||||
bn_momentum=0.9,
|
||||
bn_epsilon=1e-3,
|
||||
depth_divisor=8,
|
||||
min_depth=8,
|
||||
act_fn="silu",
|
||||
survival_prob=0.8,
|
||||
local_pooling=False,
|
||||
conv_dropout=None,
|
||||
num_classes=1000))
|
||||
return cfg
|
||||
|
||||
|
||||
def get_model_config(model_name: str):
|
||||
"""Main entry for model name to config."""
|
||||
if model_name.startswith("efficientnetv2-"):
|
||||
return efficientnetv2_config(model_name)
|
||||
raise ValueError(f"Unknown model_name {model_name}")
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
|
||||
def round_filters(filters,
|
||||
width_coefficient,
|
||||
depth_divisor,
|
||||
min_depth,
|
||||
skip=False):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
multiplier = width_coefficient
|
||||
divisor = depth_divisor
|
||||
min_depth = min_depth
|
||||
if skip or not multiplier:
|
||||
return filters
|
||||
|
||||
filters *= multiplier
|
||||
min_depth = min_depth or divisor
|
||||
new_filters = max(min_depth,
|
||||
int(filters + divisor / 2) // divisor * divisor)
|
||||
return int(new_filters)
|
||||
|
||||
|
||||
def round_repeats(repeats, multiplier, skip=False):
|
||||
"""Round number of filters based on depth multiplier."""
|
||||
if skip or not multiplier:
|
||||
return repeats
|
||||
return int(math.ceil(multiplier * repeats))
|
||||
|
||||
|
||||
def activation_fn(act_fn: str):
|
||||
"""Customized non-linear activation type."""
|
||||
if not act_fn:
|
||||
return nn.Silu()
|
||||
elif act_fn in ("silu", "swish"):
|
||||
return nn.Swish()
|
||||
elif act_fn == "relu":
|
||||
return nn.ReLU()
|
||||
elif act_fn == "relu6":
|
||||
return nn.ReLU6()
|
||||
elif act_fn == "elu":
|
||||
return nn.ELU()
|
||||
elif act_fn == "leaky_relu":
|
||||
return nn.LeakyReLU()
|
||||
elif act_fn == "selu":
|
||||
return nn.SELU()
|
||||
elif act_fn == "mish":
|
||||
return nn.Mish()
|
||||
else:
|
||||
raise ValueError("Unsupported act_fn {}".format(act_fn))
|
||||
|
||||
|
||||
def drop_path(x, training=False, survival_prob=1.0):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if not training:
|
||||
return x
|
||||
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
|
||||
keep_prob = paddle.to_tensor(survival_prob)
|
||||
random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
|
||||
random_tensor = paddle.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class SE(nn.Layer):
|
||||
"""Squeeze-and-excitation layer.
|
||||
|
||||
Args:
|
||||
local_pooling (bool): local_pooling
|
||||
act_fn (str): act_fn
|
||||
in_channels (int): in_channels
|
||||
se_channels (int): se_channels
|
||||
out_channels (int): out_channels
|
||||
cur_stage (int): cur_stage
|
||||
padding_type (str): padding_type
|
||||
model_name (str): model_name
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
local_pooling: bool,
|
||||
act_fn: str,
|
||||
in_channels: int,
|
||||
se_channels: int,
|
||||
out_channels: int,
|
||||
cur_stage: int,
|
||||
padding_type: str,
|
||||
model_name: str):
|
||||
super(SE, self).__init__()
|
||||
|
||||
self._local_pooling = local_pooling
|
||||
self._act = activation_fn(act_fn)
|
||||
|
||||
# Squeeze and Excitation layer.
|
||||
self._se_reduce = Conv2ds(
|
||||
in_channels,
|
||||
se_channels,
|
||||
1,
|
||||
stride=1,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
self._se_expand = Conv2ds(
|
||||
se_channels,
|
||||
out_channels,
|
||||
1,
|
||||
stride=1,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
|
||||
def forward(self, x):
|
||||
if self._local_pooling:
|
||||
se_tensor = F.adaptive_avg_pool2d(x, output_size=1)
|
||||
else:
|
||||
se_tensor = paddle.mean(x, axis=[2, 3], keepdim=True)
|
||||
se_tensor = self._se_expand(self._act(self._se_reduce(se_tensor)))
|
||||
return F.sigmoid(se_tensor) * x
|
||||
|
||||
|
||||
class MBConvBlock(nn.Layer):
|
||||
"""A class of MBConv: Mobile Inverted Residual Bottleneck.
|
||||
|
||||
Args:
|
||||
se_ratio (int): se_ratio
|
||||
in_channels (int): in_channels
|
||||
expand_ratio (int): expand_ratio
|
||||
kernel_size (int): kernel_size
|
||||
strides (int): strides
|
||||
out_channels (int): out_channels
|
||||
bn_momentum (float): bn_momentum
|
||||
bn_epsilon (float): bn_epsilon
|
||||
local_pooling (bool): local_pooling
|
||||
conv_dropout (float): conv_dropout
|
||||
cur_stage (int): cur_stage
|
||||
padding_type (str): padding_type
|
||||
model_name (str): model_name
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
se_ratio: int,
|
||||
in_channels: int,
|
||||
expand_ratio: int,
|
||||
kernel_size: int,
|
||||
strides: int,
|
||||
out_channels: int,
|
||||
bn_momentum: float,
|
||||
bn_epsilon: float,
|
||||
local_pooling: bool,
|
||||
conv_dropout: float,
|
||||
cur_stage: int,
|
||||
padding_type: str,
|
||||
model_name: str):
|
||||
super(MBConvBlock, self).__init__()
|
||||
|
||||
self.se_ratio = se_ratio
|
||||
self.in_channels = in_channels
|
||||
self.expand_ratio = expand_ratio
|
||||
self.kernel_size = kernel_size
|
||||
self.strides = strides
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.bn_momentum = bn_momentum
|
||||
self.bn_epsilon = bn_epsilon
|
||||
|
||||
self._local_pooling = local_pooling
|
||||
self.act_fn = None
|
||||
self.conv_dropout = conv_dropout
|
||||
|
||||
self._act = activation_fn(None)
|
||||
self._has_se = (self.se_ratio is not None and 0 < self.se_ratio <= 1)
|
||||
"""Builds block according to the arguments."""
|
||||
expand_channels = self.in_channels * self.expand_ratio
|
||||
kernel_size = self.kernel_size
|
||||
|
||||
# Expansion phase. Called if not using fused convolutions and expansion
|
||||
# phase is necessary.
|
||||
if self.expand_ratio != 1:
|
||||
self._expand_conv = Conv2ds(
|
||||
self.in_channels,
|
||||
expand_channels,
|
||||
1,
|
||||
stride=1,
|
||||
use_bias=False,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
self._norm0 = nn.BatchNorm2D(
|
||||
expand_channels,
|
||||
self.bn_momentum,
|
||||
self.bn_epsilon,
|
||||
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
|
||||
# Depth-wise convolution phase. Called if not using fused convolutions.
|
||||
self._depthwise_conv = Conv2ds(
|
||||
expand_channels,
|
||||
expand_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
stride=self.strides,
|
||||
groups=expand_channels,
|
||||
use_bias=False,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
|
||||
self._norm1 = nn.BatchNorm2D(
|
||||
expand_channels,
|
||||
self.bn_momentum,
|
||||
self.bn_epsilon,
|
||||
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
|
||||
if self._has_se:
|
||||
num_reduced_filters = max(1, int(self.in_channels * self.se_ratio))
|
||||
self._se = SE(self._local_pooling, None, expand_channels,
|
||||
num_reduced_filters, expand_channels, cur_stage,
|
||||
padding_type, model_name)
|
||||
else:
|
||||
self._se = None
|
||||
|
||||
# Output phase.
|
||||
self._project_conv = Conv2ds(
|
||||
expand_channels,
|
||||
self.out_channels,
|
||||
1,
|
||||
stride=1,
|
||||
use_bias=False,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
self._norm2 = nn.BatchNorm2D(
|
||||
self.out_channels,
|
||||
self.bn_momentum,
|
||||
self.bn_epsilon,
|
||||
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
self.drop_out = nn.Dropout(self.conv_dropout)
|
||||
|
||||
def residual(self, inputs, x, survival_prob):
|
||||
if (self.strides == 1 and self.in_channels == self.out_channels):
|
||||
# Apply only if skip connection presents.
|
||||
if survival_prob:
|
||||
x = drop_path(x, self.training, survival_prob)
|
||||
x = paddle.add(x, inputs)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, inputs, survival_prob=None):
|
||||
"""Implementation of call().
|
||||
|
||||
Args:
|
||||
inputs: the inputs tensor.
|
||||
survival_prob: float, between 0 to 1, drop connect rate.
|
||||
|
||||
Returns:
|
||||
A output tensor.
|
||||
"""
|
||||
x = inputs
|
||||
if self.expand_ratio != 1:
|
||||
x = self._act(self._norm0(self._expand_conv(x)))
|
||||
|
||||
x = self._act(self._norm1(self._depthwise_conv(x)))
|
||||
|
||||
if self.conv_dropout and self.expand_ratio > 1:
|
||||
x = self.drop_out(x)
|
||||
|
||||
if self._se:
|
||||
x = self._se(x)
|
||||
|
||||
x = self._norm2(self._project_conv(x))
|
||||
x = self.residual(inputs, x, survival_prob)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FusedMBConvBlock(MBConvBlock):
|
||||
"""Fusing the proj conv1x1 and depthwise_conv into a conv2d."""
|
||||
|
||||
def __init__(self, se_ratio, in_channels, expand_ratio, kernel_size,
|
||||
strides, out_channels, bn_momentum, bn_epsilon, local_pooling,
|
||||
conv_dropout, cur_stage, padding_type, model_name):
|
||||
"""Builds block according to the arguments."""
|
||||
super(MBConvBlock, self).__init__()
|
||||
self.se_ratio = se_ratio
|
||||
self.in_channels = in_channels
|
||||
self.expand_ratio = expand_ratio
|
||||
self.kernel_size = kernel_size
|
||||
self.strides = strides
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.bn_momentum = bn_momentum
|
||||
self.bn_epsilon = bn_epsilon
|
||||
|
||||
self._local_pooling = local_pooling
|
||||
self.act_fn = None
|
||||
self.conv_dropout = conv_dropout
|
||||
|
||||
self._act = activation_fn(None)
|
||||
self._has_se = (self.se_ratio is not None and 0 < self.se_ratio <= 1)
|
||||
|
||||
expand_channels = self.in_channels * self.expand_ratio
|
||||
kernel_size = self.kernel_size
|
||||
if self.expand_ratio != 1:
|
||||
# Expansion phase:
|
||||
self._expand_conv = Conv2ds(
|
||||
self.in_channels,
|
||||
expand_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
stride=self.strides,
|
||||
use_bias=False,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
self._norm0 = nn.BatchNorm2D(
|
||||
expand_channels,
|
||||
self.bn_momentum,
|
||||
self.bn_epsilon,
|
||||
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
|
||||
if self._has_se:
|
||||
num_reduced_filters = max(1, int(self.in_channels * self.se_ratio))
|
||||
self._se = SE(self._local_pooling, None, expand_channels,
|
||||
num_reduced_filters, expand_channels, cur_stage,
|
||||
padding_type, model_name)
|
||||
else:
|
||||
self._se = None
|
||||
|
||||
# Output phase:
|
||||
self._project_conv = Conv2ds(
|
||||
expand_channels,
|
||||
self.out_channels,
|
||||
1 if (self.expand_ratio != 1) else kernel_size,
|
||||
padding=(1 if (self.expand_ratio != 1) else kernel_size) // 2,
|
||||
stride=1 if (self.expand_ratio != 1) else self.strides,
|
||||
use_bias=False,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
self._norm1 = nn.BatchNorm2D(
|
||||
self.out_channels,
|
||||
self.bn_momentum,
|
||||
self.bn_epsilon,
|
||||
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
self.drop_out = nn.Dropout(conv_dropout)
|
||||
|
||||
def forward(self, inputs, survival_prob=None):
|
||||
"""Implementation of call().
|
||||
|
||||
Args:
|
||||
inputs: the inputs tensor.
|
||||
training: boolean, whether the model is constructed for training.
|
||||
survival_prob: float, between 0 to 1, drop connect rate.
|
||||
|
||||
Returns:
|
||||
A output tensor.
|
||||
"""
|
||||
x = inputs
|
||||
if self.expand_ratio != 1:
|
||||
x = self._act(self._norm0(self._expand_conv(x)))
|
||||
|
||||
if self.conv_dropout and self.expand_ratio > 1:
|
||||
x = self.drop_out(x)
|
||||
|
||||
if self._se:
|
||||
x = self._se(x)
|
||||
|
||||
x = self._norm1(self._project_conv(x))
|
||||
if self.expand_ratio == 1:
|
||||
x = self._act(x) # add act if no expansion.
|
||||
|
||||
x = self.residual(inputs, x, survival_prob)
|
||||
return x
|
||||
|
||||
|
||||
class Stem(nn.Layer):
|
||||
"""Stem layer at the begining of the network."""
|
||||
|
||||
def __init__(self, width_coefficient, depth_divisor, min_depth, skip,
|
||||
bn_momentum, bn_epsilon, act_fn, stem_channels, cur_stage,
|
||||
padding_type, model_name):
|
||||
super(Stem, self).__init__()
|
||||
self._conv_stem = Conv2ds(
|
||||
3,
|
||||
round_filters(stem_channels, width_coefficient, depth_divisor,
|
||||
min_depth, skip),
|
||||
3,
|
||||
padding=1,
|
||||
stride=2,
|
||||
use_bias=False,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name,
|
||||
cur_stage=cur_stage)
|
||||
self._norm = nn.BatchNorm2D(
|
||||
round_filters(stem_channels, width_coefficient, depth_divisor,
|
||||
min_depth, skip),
|
||||
bn_momentum,
|
||||
bn_epsilon,
|
||||
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
self._act = activation_fn(act_fn)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self._act(self._norm(self._conv_stem(inputs)))
|
||||
|
||||
|
||||
class Head(nn.Layer):
|
||||
"""Head layer for network outputs."""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
feature_size,
|
||||
bn_momentum,
|
||||
bn_epsilon,
|
||||
act_fn,
|
||||
dropout_rate,
|
||||
local_pooling,
|
||||
width_coefficient,
|
||||
depth_divisor,
|
||||
min_depth,
|
||||
skip=False):
|
||||
super(Head, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.feature_size = feature_size
|
||||
self.bn_momentum = bn_momentum
|
||||
self.bn_epsilon = bn_epsilon
|
||||
self.dropout_rate = dropout_rate
|
||||
self._local_pooling = local_pooling
|
||||
self._conv_head = nn.Conv2D(
|
||||
in_channels,
|
||||
round_filters(self.feature_size or 1280, width_coefficient,
|
||||
depth_divisor, min_depth, skip),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias_attr=False)
|
||||
self._norm = nn.BatchNorm2D(
|
||||
round_filters(self.feature_size or 1280, width_coefficient,
|
||||
depth_divisor, min_depth, skip),
|
||||
self.bn_momentum,
|
||||
self.bn_epsilon,
|
||||
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
self._act = activation_fn(act_fn)
|
||||
|
||||
self._avg_pooling = nn.AdaptiveAvgPool2D(output_size=1)
|
||||
|
||||
if self.dropout_rate > 0:
|
||||
self._dropout = nn.Dropout(self.dropout_rate)
|
||||
else:
|
||||
self._dropout = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Call the layer."""
|
||||
outputs = self._act(self._norm(self._conv_head(x)))
|
||||
|
||||
if self._local_pooling:
|
||||
outputs = F.adaptive_avg_pool2d(outputs, output_size=1)
|
||||
if self._dropout:
|
||||
outputs = self._dropout(outputs)
|
||||
if self._fc:
|
||||
outputs = paddle.squeeze(outputs, axis=[2, 3])
|
||||
outputs = self._fc(outputs)
|
||||
else:
|
||||
outputs = self._avg_pooling(outputs)
|
||||
if self._dropout:
|
||||
outputs = self._dropout(outputs)
|
||||
return paddle.flatten(outputs, start_axis=1)
|
||||
|
||||
|
||||
class EfficientNetV2(nn.Layer):
|
||||
"""A class implements tf.keras.Model.
|
||||
|
||||
Reference: https://arxiv.org/abs/1807.11626
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_name,
|
||||
blocks_args=None,
|
||||
mconfig=None,
|
||||
include_top=True,
|
||||
class_num=1000,
|
||||
padding_type="SAME"):
|
||||
"""Initializes an `Model` instance.
|
||||
|
||||
Args:
|
||||
model_name: A string of model name.
|
||||
model_config: A dict of model configurations or a string of hparams.
|
||||
Raises:
|
||||
ValueError: when blocks_args is not specified as a list.
|
||||
"""
|
||||
super(EfficientNetV2, self).__init__()
|
||||
self.blocks_args = blocks_args
|
||||
self.mconfig = mconfig
|
||||
"""Builds a model."""
|
||||
self._blocks = nn.LayerList()
|
||||
|
||||
cur_stage = 0
|
||||
# Stem part.
|
||||
self._stem = Stem(
|
||||
self.mconfig.width_coefficient,
|
||||
self.mconfig.depth_divisor,
|
||||
self.mconfig.min_depth,
|
||||
False,
|
||||
self.mconfig.bn_momentum,
|
||||
self.mconfig.bn_epsilon,
|
||||
self.mconfig.act_fn,
|
||||
stem_channels=self.blocks_args[0].in_channels,
|
||||
cur_stage=cur_stage,
|
||||
padding_type=padding_type,
|
||||
model_name=model_name)
|
||||
cur_stage += 1
|
||||
|
||||
# Builds blocks.
|
||||
for block_args in self.blocks_args:
|
||||
assert block_args.num_repeat > 0
|
||||
# Update block input and output filters based on depth multiplier.
|
||||
in_channels = round_filters(
|
||||
block_args.in_channels, self.mconfig.width_coefficient,
|
||||
self.mconfig.depth_divisor, self.mconfig.min_depth, False)
|
||||
out_channels = round_filters(
|
||||
block_args.out_channels, self.mconfig.width_coefficient,
|
||||
self.mconfig.depth_divisor, self.mconfig.min_depth, False)
|
||||
|
||||
repeats = round_repeats(block_args.num_repeat,
|
||||
self.mconfig.depth_coefficient)
|
||||
block_args.update(
|
||||
dict(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
num_repeat=repeats))
|
||||
|
||||
# The first block needs to take care of stride and filter size increase.
|
||||
conv_block = {
|
||||
0: MBConvBlock,
|
||||
1: FusedMBConvBlock
|
||||
}[block_args.conv_type]
|
||||
self._blocks.append(
|
||||
conv_block(block_args.se_ratio, block_args.in_channels,
|
||||
block_args.expand_ratio, block_args.kernel_size,
|
||||
block_args.strides, block_args.out_channels,
|
||||
self.mconfig.bn_momentum, self.mconfig.bn_epsilon,
|
||||
self.mconfig.local_pooling, self.mconfig.
|
||||
conv_dropout, cur_stage, padding_type, model_name))
|
||||
if block_args.num_repeat > 1: # rest of blocks with the same block_arg
|
||||
block_args.in_channels = block_args.out_channels
|
||||
block_args.strides = 1
|
||||
for _ in range(block_args.num_repeat - 1):
|
||||
self._blocks.append(
|
||||
conv_block(
|
||||
block_args.se_ratio, block_args.in_channels,
|
||||
block_args.expand_ratio, block_args.kernel_size,
|
||||
block_args.strides, block_args.out_channels,
|
||||
self.mconfig.bn_momentum, self.mconfig.bn_epsilon,
|
||||
self.mconfig.local_pooling, self.mconfig.conv_dropout,
|
||||
cur_stage, padding_type, model_name))
|
||||
cur_stage += 1
|
||||
|
||||
# Head part.
|
||||
self._head = Head(
|
||||
self.blocks_args[-1].out_channels, self.mconfig.feature_size,
|
||||
self.mconfig.bn_momentum, self.mconfig.bn_epsilon,
|
||||
self.mconfig.act_fn, self.mconfig.dropout_rate,
|
||||
self.mconfig.local_pooling, self.mconfig.width_coefficient,
|
||||
self.mconfig.depth_divisor, self.mconfig.min_depth, False)
|
||||
|
||||
# top part for classification
|
||||
if include_top and class_num:
|
||||
self._fc = nn.Linear(
|
||||
self.mconfig.feature_size,
|
||||
class_num,
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
else:
|
||||
self._fc = None
|
||||
|
||||
# initialize weight
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Conv2D):
|
||||
out_filters, in_channels, kernel_height, kernel_width = m.weight.shape
|
||||
if in_channels == 1 and out_filters > in_channels:
|
||||
out_filters = in_channels
|
||||
fan_out = int(kernel_height * kernel_width * out_filters)
|
||||
Normal(mean=0.0, std=np.sqrt(2.0 / fan_out))(m.weight)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init_range = 1.0 / np.sqrt(m.weight.shape[1])
|
||||
Uniform(-init_range, init_range)(m.weight)
|
||||
Constant(0.0)(m.bias)
|
||||
|
||||
self.apply(_init_weights)
|
||||
|
||||
def forward(self, inputs):
|
||||
# Calls Stem layers
|
||||
outputs = self._stem(inputs)
|
||||
# print(f"stem: {outputs.mean().item():.10f}")
|
||||
|
||||
# Calls blocks.
|
||||
for idx, block in enumerate(self._blocks):
|
||||
survival_prob = self.mconfig.survival_prob
|
||||
if survival_prob:
|
||||
drop_rate = 1.0 - survival_prob
|
||||
survival_prob = 1.0 - drop_rate * float(idx) / len(
|
||||
self._blocks)
|
||||
outputs = block(outputs, survival_prob=survival_prob)
|
||||
|
||||
# Head to obtain the final feature.
|
||||
outputs = self._head(outputs)
|
||||
# Calls final dense layers and returns logits.
|
||||
if self._fc:
|
||||
outputs = self._fc(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
|
||||
if pretrained is False:
|
||||
pass
|
||||
elif pretrained is True:
|
||||
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
|
||||
elif isinstance(pretrained, str):
|
||||
load_dygraph_pretrain(model, pretrained)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"pretrained type is not available. Please use `string` or `boolean` type."
|
||||
)
|
||||
|
||||
|
||||
def EfficientNetV2_S(include_top=True, pretrained=False, **kwargs):
|
||||
"""Get a V2 model instance.
|
||||
|
||||
Returns:
|
||||
nn.Layer: A single model instantce
|
||||
"""
|
||||
model_name = "efficientnetv2-s"
|
||||
model_config = efficientnetv2_config(model_name)
|
||||
model = EfficientNetV2(model_name, model_config.model.blocks_args,
|
||||
model_config.model, include_top, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_S"])
|
||||
return model
|
||||
|
||||
|
||||
def EfficientNetV2_M(include_top=True, pretrained=False, **kwargs):
|
||||
"""Get a V2 model instance.
|
||||
|
||||
Returns:
|
||||
nn.Layer: A single model instantce
|
||||
"""
|
||||
model_name = "efficientnetv2-m"
|
||||
model_config = efficientnetv2_config(model_name)
|
||||
model = EfficientNetV2(model_name, model_config.model.blocks_args,
|
||||
model_config.model, include_top, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_M"])
|
||||
return model
|
||||
|
||||
|
||||
def EfficientNetV2_L(include_top=True, pretrained=False, **kwargs):
|
||||
"""Get a V2 model instance.
|
||||
|
||||
Returns:
|
||||
nn.Layer: A single model instantce
|
||||
"""
|
||||
model_name = "efficientnetv2-l"
|
||||
model_config = efficientnetv2_config(model_name)
|
||||
model = EfficientNetV2(model_name, model_config.model.blocks_args,
|
||||
model_config.model, include_top, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_L"])
|
||||
return model
|
||||
|
||||
|
||||
def EfficientNetV2_XL(include_top=True, pretrained=False, **kwargs):
|
||||
"""Get a V2 model instance.
|
||||
|
||||
Returns:
|
||||
nn.Layer: A single model instantce
|
||||
"""
|
||||
model_name = "efficientnetv2-xl"
|
||||
model_config = efficientnetv2_config(model_name)
|
||||
model = EfficientNetV2(model_name, model_config.model.blocks_args,
|
||||
model_config.model, include_top, **kwargs)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_XL"])
|
||||
return model
|
|
@ -0,0 +1,142 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
save_interval: 100
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 350
|
||||
print_batch_step: 20
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 384, 384]
|
||||
save_inference_dir: ./inference
|
||||
train_mode: efficientnetv2 # progressive training
|
||||
|
||||
AMP:
|
||||
scale_loss: 65536
|
||||
use_dynamic_loss_scaling: True
|
||||
# O1: mixed fp16
|
||||
level: O1
|
||||
|
||||
EMA:
|
||||
decay: 0.9999
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: EfficientNetV2_S
|
||||
class_num: 1000
|
||||
use_sync_bn: True
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.65 # 8gpux128bs
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: L2
|
||||
coeff: 0.00001
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
scale: [0.05, 1.0]
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- RandAugmentV2:
|
||||
num_layers: 2
|
||||
magnitude: 5
|
||||
- NormalizeImage:
|
||||
scale: 1.0
|
||||
mean: [128.0, 128.0, 128.0]
|
||||
std: [128.0, 128.0, 128.0]
|
||||
order: ""
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 128
|
||||
drop_last: True
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 8
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- CropImageAtRatio:
|
||||
size: 384
|
||||
pad: 32
|
||||
interpolation: bilinear
|
||||
- NormalizeImage:
|
||||
scale: 1.0
|
||||
mean: [128.0, 128.0, 128.0]
|
||||
std: [128.0, 128.0, 128.0]
|
||||
order: ""
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 128
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 8
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- CropImageAtRatio:
|
||||
size: 384
|
||||
pad: 32
|
||||
interpolation: bilinear
|
||||
- NormalizeImage:
|
||||
scale: 1.0
|
||||
mean: [128.0, 128.0, 128.0]
|
||||
std: [128.0, 128.0, 128.0]
|
||||
order: ""
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
|
@ -15,6 +15,7 @@
|
|||
from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
|
||||
from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
|
||||
from ppcls.data.preprocess.ops.randaugment import RandomApply
|
||||
from ppcls.data.preprocess.ops.randaugment import RandAugmentV2 as RawRandAugmentV2
|
||||
from ppcls.data.preprocess.ops.timm_autoaugment import RawTimmAutoAugment
|
||||
from ppcls.data.preprocess.ops.cutout import Cutout
|
||||
|
||||
|
@ -25,6 +26,7 @@ from ppcls.data.preprocess.ops.grid import GridMask
|
|||
from ppcls.data.preprocess.ops.operators import DecodeImage
|
||||
from ppcls.data.preprocess.ops.operators import ResizeImage
|
||||
from ppcls.data.preprocess.ops.operators import CropImage
|
||||
from ppcls.data.preprocess.ops.operators import CropImageAtRatio
|
||||
from ppcls.data.preprocess.ops.operators import CenterCrop, Resize
|
||||
from ppcls.data.preprocess.ops.operators import RandCropImage
|
||||
from ppcls.data.preprocess.ops.operators import RandCropImageV2
|
||||
|
@ -101,6 +103,13 @@ class RandAugment(RawRandAugment):
|
|||
return img
|
||||
|
||||
|
||||
class RandAugmentV2(RawRandAugmentV2):
|
||||
""" RandAugmentV2 wrapper to auto fit different img types """
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class TimmAutoAugment(RawTimmAutoAugment):
|
||||
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
|
||||
|
||||
|
|
|
@ -319,6 +319,25 @@ class CropImage(object):
|
|||
return img[h_start:h_end, w_start:w_end, :]
|
||||
|
||||
|
||||
class CropImageAtRatio(object):
|
||||
""" crop image with specified size and padding"""
|
||||
|
||||
def __init__(self, size: int, pad: int, interpolation="bilinear"):
|
||||
self.size = size
|
||||
self.ratio = size / (size + pad)
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img):
|
||||
height, width = img.shape[:2]
|
||||
crop_size = int(self.ratio * min(height, width))
|
||||
|
||||
y = (height - crop_size) // 2
|
||||
x = (width - crop_size) // 2
|
||||
|
||||
crop_img = img[y:y + crop_size, x:x + crop_size, :]
|
||||
return F.resize(crop_img, [self.size, self.size], self.interpolation)
|
||||
|
||||
|
||||
class Padv2(object):
|
||||
def __init__(self,
|
||||
size=None,
|
||||
|
|
|
@ -15,12 +15,60 @@
|
|||
# This code is based on https://github.com/heartInsert/randaugment
|
||||
# reference: https://arxiv.org/abs/1909.13719
|
||||
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
import numpy as np
|
||||
import random
|
||||
from .operators import RawColorJitter
|
||||
from paddle.vision.transforms import transforms as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
|
||||
|
||||
def solarize_add(img, add, thresh=128, **__):
|
||||
lut = []
|
||||
for i in range(256):
|
||||
if i < thresh:
|
||||
lut.append(min(255, i + add))
|
||||
else:
|
||||
lut.append(i)
|
||||
if img.mode in ("L", "RGB"):
|
||||
if img.mode == "RGB" and len(lut) == 256:
|
||||
lut = lut + lut + lut
|
||||
return img.point(lut)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def cutout(image, pad_size, replace=0):
|
||||
image_np = np.array(image)
|
||||
image_height, image_width, _ = image_np.shape
|
||||
|
||||
# Sample the center location in the image where the zero mask will be applied.
|
||||
cutout_center_height = np.random.randint(0, image_height + 1)
|
||||
cutout_center_width = np.random.randint(0, image_width + 1)
|
||||
|
||||
lower_pad = np.maximum(0, cutout_center_height - pad_size)
|
||||
upper_pad = np.maximum(0, image_height - cutout_center_height - pad_size)
|
||||
left_pad = np.maximum(0, cutout_center_width - pad_size)
|
||||
right_pad = np.maximum(0, image_width - cutout_center_width - pad_size)
|
||||
|
||||
cutout_shape = [
|
||||
image_height - (lower_pad + upper_pad),
|
||||
image_width - (left_pad + right_pad)
|
||||
]
|
||||
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
|
||||
mask = np.pad(np.zeros(
|
||||
cutout_shape, dtype=image_np.dtype),
|
||||
padding_dims,
|
||||
constant_values=1)
|
||||
mask = np.expand_dims(mask, -1)
|
||||
mask = np.tile(mask, [1, 1, 3])
|
||||
image_np = np.where(
|
||||
np.equal(mask, 0),
|
||||
np.full_like(
|
||||
image_np, fill_value=replace, dtype=image_np.dtype),
|
||||
image_np)
|
||||
return Image.fromarray(image_np)
|
||||
|
||||
|
||||
class RandAugment(object):
|
||||
def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
|
||||
|
@ -95,10 +143,10 @@ class RandAugment(object):
|
|||
"brightness": lambda img, magnitude:
|
||||
ImageEnhance.Brightness(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"autocontrast": lambda img, magnitude:
|
||||
"autocontrast": lambda img, _:
|
||||
ImageOps.autocontrast(img),
|
||||
"equalize": lambda img, magnitude: ImageOps.equalize(img),
|
||||
"invert": lambda img, magnitude: ImageOps.invert(img)
|
||||
"equalize": lambda img, _: ImageOps.equalize(img),
|
||||
"invert": lambda img, _: ImageOps.invert(img)
|
||||
}
|
||||
|
||||
def __call__(self, img):
|
||||
|
@ -121,4 +169,85 @@ class RandomApply(object):
|
|||
|
||||
def __call__(self, img):
|
||||
timg = self.trans(img)
|
||||
return timg
|
||||
return timg
|
||||
|
||||
|
||||
## RandAugment_EfficientNetV2 code below ##
|
||||
class RandAugmentV2(RandAugment):
|
||||
"""Customed RandAugment for EfficientNetV2"""
|
||||
|
||||
def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
|
||||
super().__init__(num_layers, magnitude, fillcolor)
|
||||
abso_level = self.magnitude / self.max_level # [5.0~10.0/10.0]=[0.5, 1.0]
|
||||
self.level_map = {
|
||||
"shearX": 0.3 * abso_level,
|
||||
"shearY": 0.3 * abso_level,
|
||||
"translateX": 100.0 * abso_level,
|
||||
"translateY": 100.0 * abso_level,
|
||||
"rotate": 30 * abso_level,
|
||||
"color": 1.8 * abso_level + 0.1,
|
||||
"posterize": int(4.0 * abso_level),
|
||||
"solarize": int(256.0 * abso_level),
|
||||
"solarize_add": int(110.0 * abso_level),
|
||||
"contrast": 1.8 * abso_level + 0.1,
|
||||
"sharpness": 1.8 * abso_level + 0.1,
|
||||
"brightness": 1.8 * abso_level + 0.1,
|
||||
"autocontrast": 0,
|
||||
"equalize": 0,
|
||||
"invert": 0,
|
||||
"cutout": int(40 * abso_level)
|
||||
}
|
||||
|
||||
def rotate_with_fill(img, magnitude):
|
||||
rot = img.convert("RGBA").rotate(magnitude)
|
||||
return Image.composite(rot,
|
||||
Image.new("RGBA", rot.size, (128, ) * 4),
|
||||
rot).convert(img.mode)
|
||||
|
||||
rnd_ch_op = random.choice
|
||||
|
||||
self.func = {
|
||||
"shearX": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
|
||||
Image.NEAREST,
|
||||
fillcolor=fillcolor),
|
||||
"shearY": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
|
||||
Image.NEAREST,
|
||||
fillcolor=fillcolor),
|
||||
"translateX": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, magnitude * rnd_ch_op([-1, 1]), 0, 1, 0),
|
||||
Image.NEAREST,
|
||||
fillcolor=fillcolor),
|
||||
"translateY": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, 0, 1, magnitude * rnd_ch_op([-1, 1])),
|
||||
Image.NEAREST,
|
||||
fillcolor=fillcolor),
|
||||
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude * rnd_ch_op([-1, 1])),
|
||||
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(magnitude),
|
||||
"posterize": lambda img, magnitude:
|
||||
ImageOps.posterize(img, magnitude),
|
||||
"solarize": lambda img, magnitude:
|
||||
ImageOps.solarize(img, magnitude),
|
||||
"solarize_add": lambda img, magnitude:
|
||||
solarize_add(img, magnitude),
|
||||
"contrast": lambda img, magnitude:
|
||||
ImageEnhance.Contrast(img).enhance(magnitude),
|
||||
"sharpness": lambda img, magnitude:
|
||||
ImageEnhance.Sharpness(img).enhance(magnitude),
|
||||
"brightness": lambda img, magnitude:
|
||||
ImageEnhance.Brightness(img).enhance(magnitude),
|
||||
"autocontrast": lambda img, _:
|
||||
ImageOps.autocontrast(img),
|
||||
"equalize": lambda img, _: ImageOps.equalize(img),
|
||||
"invert": lambda img, _: ImageOps.invert(img),
|
||||
"cutout": lambda img, magnitude: cutout(img, magnitude, replace=fillcolor[0])
|
||||
}
|
||||
|
|
|
@ -12,5 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ppcls.engine.train.train import train_epoch
|
||||
from ppcls.engine.train.train_efficientnetv2 import train_epoch_efficientnetv2
|
||||
from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
|
||||
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
|
||||
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ppcls.data import build_dataloader
|
||||
from ppcls.utils import logger
|
||||
|
||||
from .train import train_epoch
|
||||
|
||||
|
||||
def train_epoch_efficientnetv2(engine, epoch_id, print_batch_step):
|
||||
# 1. Build training hyper-parameters for different training stage
|
||||
num_stage = 4
|
||||
ratio_list = [(i + 1) / num_stage for i in range(num_stage)]
|
||||
ram_list = np.linspace(5, 10, num_stage)
|
||||
# dropout_rate_list = np.linspace(0.0, 0.2, num_stage)
|
||||
stones = [
|
||||
int(engine.config["Global"]["epochs"] * ratio_list[i])
|
||||
for i in range(num_stage)
|
||||
]
|
||||
image_size_list = [
|
||||
int(128 + (300 - 128) * ratio_list[i]) for i in range(num_stage)
|
||||
]
|
||||
stage_id = 0
|
||||
for i in range(num_stage):
|
||||
if epoch_id > stones[i]:
|
||||
stage_id = i + 1
|
||||
|
||||
# 2. Adjust training hyper-parameters for different training stage
|
||||
if not hasattr(engine, 'last_stage') or engine.last_stage < stage_id:
|
||||
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][1][
|
||||
"RandCropImage"]["size"] = image_size_list[stage_id]
|
||||
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][3][
|
||||
"RandAugment"]["magnitude"] = ram_list[stage_id]
|
||||
engine.train_dataloader = build_dataloader(
|
||||
engine.config["DataLoader"],
|
||||
"Train",
|
||||
engine.device,
|
||||
engine.use_dali,
|
||||
seed=epoch_id)
|
||||
engine.train_dataloader_iter = iter(engine.train_dataloader)
|
||||
engine.last_stage = stage_id
|
||||
logger.info(
|
||||
f"Training stage: [{stage_id+1}/{num_stage}](random_aug_magnitude={ram_list[stage_id]}, train_image_size={image_size_list[stage_id]})"
|
||||
)
|
||||
|
||||
# 3. Train one epoch as usual at current stage
|
||||
train_epoch(engine, epoch_id, print_batch_step)
|
|
@ -33,7 +33,7 @@ class AttrDict(dict):
|
|||
self[key] = value
|
||||
|
||||
def __deepcopy__(self, content):
|
||||
return copy.deepcopy(dict(self))
|
||||
return AttrDict(copy.deepcopy(dict(self)))
|
||||
|
||||
|
||||
def create_attr_dict(yaml_config):
|
||||
|
|
Loading…
Reference in New Issue