add autoargument
parent
32ce68374e
commit
4ccfca291b
|
@ -0,0 +1,84 @@
|
||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: "EfficientNetB0"
|
||||||
|
drop_connect_rate: 0.1
|
||||||
|
padding_type : "SAME"
|
||||||
|
pretrained_model: ""
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 1000
|
||||||
|
total_images: 1281167
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 360
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
use_ema: True
|
||||||
|
ema_decay: 0.9999
|
||||||
|
use_aa: True
|
||||||
|
ls_epsilon: 0.1
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'ExponentialWarmup'
|
||||||
|
params:
|
||||||
|
lr: 0.032
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'RMSProp'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
rho: 0.9
|
||||||
|
epsilon: 0.001
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.00001
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 512
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/ILSVRC2012/train_list.txt"
|
||||||
|
data_dir: "./dataset/ILSVRC2012/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- AA:
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/ILSVRC2012/val_list.txt"
|
||||||
|
data_dir: "./dataset/ILSVRC2012/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
interpolation: 2
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ import random
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from autoargument import ImageNetPolicy
|
||||||
|
|
||||||
class OperatorParamError(ValueError):
|
class OperatorParamError(ValueError):
|
||||||
""" OperatorParamError
|
""" OperatorParamError
|
||||||
|
@ -171,6 +172,18 @@ class RandFlipImage(object):
|
||||||
else:
|
else:
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
class AA(object):
    """ AutoAugment transform built on ImageNetPolicy

    The operator is callable: it takes an image array, runs it through
    the policy as a PIL image and hands back an array again.
    """

    def __init__(self):
        self.policy = ImageNetPolicy()

    def __call__(self, img):
        """ apply the policy to one image

        Args:
            img: array-like image accepted by ``PIL.Image.fromarray``
                # NOTE(review): presumably HWC uint8, since the config
                # places this operator before NormalizeImage - confirm

        Returns:
            the augmented image as a numpy array
        """
        # PIL is imported lazily so the module can be loaded without it
        # when this operator is not used.
        from PIL import Image

        # Image.fromarray needs a contiguous buffer.
        img = np.ascontiguousarray(img)
        img = Image.fromarray(img)
        img = self.policy(img)
        img = np.asarray(img)
        # The result must be returned, otherwise the caller receives None
        # instead of the augmented image.
        return img
|
||||||
|
|
||||||
|
|
||||||
class NormalizeImage(object):
|
class NormalizeImage(object):
|
||||||
""" normalize image such as substract mean, divide std
|
""" normalize image such as substract mean, divide std
|
||||||
|
|
|
@ -145,6 +145,59 @@ class CosineWarmup(object):
|
||||||
return learning_rate
|
return learning_rate
|
||||||
|
|
||||||
|
|
||||||
|
class ExponentialWarmup(object):
    """
    Exponential learning rate decay with warmup
    [0, warmup_epoch): linear warmup
    [warmup_epoch, epochs): Exponential decay

    Args:
        lr(float): initial learning rate
        step_each_epoch(int): steps each epoch
        decay_epochs(float): decay epochs
        decay_rate(float): decay rate
        warmup_epoch(int): epoch num of warmup
        **kwargs: accepted and ignored
    """

    def __init__(self,
                 lr,
                 step_each_epoch,
                 decay_epochs=2.4,
                 decay_rate=0.97,
                 warmup_epoch=5,
                 **kwargs):
        # Must name this class: passing CosineWarmup here raises TypeError
        # because ExponentialWarmup does not derive from it.
        super(ExponentialWarmup, self).__init__()
        self.lr = lr
        self.step_each_epoch = step_each_epoch
        # Stored in steps, not epochs: the decay interval is compared
        # against the global step counter in __call__.
        self.decay_epochs = decay_epochs * self.step_each_epoch
        self.decay_rate = decay_rate
        # Kept as a float32 tensor so it can be compared with the epoch
        # tensor inside the Switch below.
        self.warmup_epoch = fluid.layers.fill_constant(
            shape=[1],
            value=float(warmup_epoch),
            dtype='float32',
            force_cpu=True)

    def __call__(self):
        """
        Build the learning rate variable

        Returns:
            the global variable named "learning_rate" that receives the
            warmup or decayed value
        """
        global_step = _decay_step_counter()
        learning_rate = fluid.layers.tensor.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name="learning_rate")

        epoch = ops.floor(global_step / self.step_each_epoch)
        with fluid.layers.control_flow.Switch() as switch:
            with switch.case(epoch < self.warmup_epoch):
                # Linear ramp from 0 to lr over the warmup steps.
                decayed_lr = self.lr * \
                    (global_step / (self.step_each_epoch * self.warmup_epoch))
                fluid.layers.tensor.assign(
                    input=decayed_lr, output=learning_rate)
            with switch.default():
                # Steps elapsed since the end of warmup.
                rest_step = global_step - self.warmup_epoch * self.step_each_epoch
                # Number of completed decay intervals (staircase decay).
                div_res = ops.floor(rest_step / self.decay_epochs)

                decayed_lr = self.lr * (self.decay_rate**div_res)
                fluid.layers.tensor.assign(
                    input=decayed_lr, output=learning_rate)

        return learning_rate
|
||||||
|
|
||||||
class LearningRateBuilder():
|
class LearningRateBuilder():
|
||||||
"""
|
"""
|
||||||
Build learning rate variable
|
Build learning rate variable
|
||||||
|
|
Loading…
Reference in New Issue