Merge pull request #40 from littletomatodonkey/add_dis_finetune
add finetune interface based on distillation models
commit
2ee646eba2
@ -0,0 +1,74 @@
mode: 'train'
ARCHITECTURE:
    name: 'ResNet50_vd'
    params:
        lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
model_save_dir: "./output/"
classes_num: 102
total_images: 1020
save_interval: 1
validate: True
valid_interval: 1
epochs: 40
topk: 5
image_shape: [3, 224, 224]

ls_epsilon: 0.1

LEARNING_RATE:
    function: 'Cosine'
    params:
        lr: 0.00375

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.000001

TRAIN:
    batch_size: 32
    num_workers: 4
    file_list: "./dataset/flowers102/train_list.txt"
    data_dir: "./dataset/flowers102/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 32
    num_workers: 4
    file_list: "./dataset/flowers102/val_list.txt"
    data_dir: "./dataset/flowers102/"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
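
The new ARCHITECTURE section carries a nested params dict that is forwarded verbatim to the model constructor (see the create_model() change further down in this diff). A minimal sketch of how such a config could be consumed, assuming the file is saved as ResNet50_vd_finetune.yaml (the file name is hypothetical):

    import yaml

    # Hedged sketch, not code from this PR: read the config and pull out the
    # architecture name and its constructor params, mirroring create_model() below.
    with open("ResNet50_vd_finetune.yaml") as f:      # hypothetical file name
        config = yaml.safe_load(f)

    arch = config["ARCHITECTURE"]
    name = arch["name"]                 # 'ResNet50_vd'
    params = arch.get("params", {})     # {'lr_mult_list': [0.1, 0.1, 0.2, 0.2, 0.3]}
    # model = architectures.__dict__[name](**params)  # as done in tools/program.py below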
@ -7,4 +7,9 @@

* Q: During multi-card evaluation, why are the accuracy metrics printed by each card different?
* A: PaddleClas currently runs multi-card training and evaluation on top of the fleet api. During multi-card evaluation, each card reads only its own part of the data, so different cards compute their metrics over different images and the final numbers show small differences. If you need an exact evaluation metric, run the evaluation on a single card.
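
A toy illustration of why the per-card numbers differ (not PaddleClas code): each card averages correctness over its own shard, so the shard means generally differ from each other and from the exact single-process mean.

    import numpy as np

    correct = np.array([1, 1, 1, 0, 0, 0, 1, 1])   # per-image correctness for 8 images
    card0, card1 = np.array_split(correct, 2)      # each "card" sees only its own part
    print(card0.mean(), card1.mean())              # 0.75 and 0.5 -- per-card metrics differ
    print(correct.mean())                          # 0.625 -- exact metric from a single card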
* Q: The `mix` transforms are configured under the `TRAIN` field of the config file, so why does the `mixup` data augmentation preprocessing not take effect?
* A: Using mixup requires changes to both the data preprocessing and the model input, so `use_mix: True` must also be set explicitly in the config file before `mixup` takes effect.
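
As a reminder of why both sides must change, here is a minimal, framework-agnostic mixup sketch (not the PaddleClas implementation): the batch is blended during preprocessing, and the loss then needs both labels plus the mixing weight, which is why the model input side has to change as well.

    import numpy as np

    def mixup_batch(images, labels, alpha=0.2):
        # Blend pairs of samples with a Beta-sampled weight; the caller must also
        # mix the loss: lam * loss(pred, labels) + (1 - lam) * loss(pred, labels[idx]).
        lam = np.random.beta(alpha, alpha)
        idx = np.random.permutation(images.shape[0])
        mixed = lam * images + (1.0 - lam) * images[idx]
        return mixed, labels, labels[idx], lam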
@ -29,9 +29,20 @@ __all__ = [

class ResNet():
    def __init__(self, layers=50, is_3x3=False):
    def __init__(self,
                 layers=50,
                 is_3x3=False,
                 postfix_name="",
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
        self.layers = layers
        self.is_3x3 = is_3x3
        self.postfix_name = "" if postfix_name is None else postfix_name
        self.lr_mult_list = lr_mult_list
        assert len(
            self.lr_mult_list
        ) == 5, "lr_mult_list length in ResNet must be 5 but got {}!!".format(
            len(self.lr_mult_list))
        self.curr_stage = 0

    def net(self, input, class_dim=1000):
        is_3x3 = self.is_3x3
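
A hedged usage sketch of the two new constructor arguments (the values and the "_student" suffix are illustrative, not taken from this PR): lr_mult_list scales the learning rate of each of the five ResNet stages for finetuning, and postfix_name is appended to every parameter name so that two sub-networks (e.g. teacher and student in a distillation program) can coexist without name clashes.

    import paddle.fluid as fluid

    image = fluid.data(name="image", shape=[None, 3, 224, 224], dtype="float32")
    student = ResNet(layers=50, is_3x3=True,
                     postfix_name="_student",                   # hypothetical suffix
                     lr_mult_list=[0.1, 0.1, 0.2, 0.2, 0.3])    # per-stage LR multipliers
    logits = student.net(input=image, class_dim=102)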
@ -90,6 +101,7 @@ class ResNet():
        if layers >= 50:
            for block in range(len(depth)):
                self.curr_stage += 1
                for i in range(depth[block]):
                    if layers in [101, 152, 200] and block == 2:
                        if i == 0:
@ -106,6 +118,7 @@ class ResNet():
                        name=conv_name)
        else:
            for block in range(len(depth)):
                self.curr_stage += 1
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    conv = self.basic_block(
@ -123,9 +136,9 @@ class ResNet():
            input=pool,
            size=class_dim,
            param_attr=fluid.param_attr.ParamAttr(
                name="fc_0.w_0",
                name="fc_0.w_0" + self.postfix_name,
                initializer=fluid.initializer.Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name="fc_0.b_0"))
            bias_attr=ParamAttr(name="fc_0.b_0" + self.postfix_name))

        return out
@ -137,6 +150,7 @@ class ResNet():
                      groups=1,
                      act=None,
                      name=None):
        lr_mult = self.lr_mult_list[self.curr_stage]
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
@ -145,7 +159,7 @@ class ResNet():
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            param_attr=ParamAttr(name=name + "_weights" + self.postfix_name),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
@ -154,10 +168,10 @@ class ResNet():
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')
            param_attr=ParamAttr(name=bn_name + '_scale' + self.postfix_name),
            bias_attr=ParamAttr(bn_name + '_offset' + self.postfix_name),
            moving_mean_name=bn_name + '_mean' + self.postfix_name,
            moving_variance_name=bn_name + '_variance' + self.postfix_name)

    def conv_bn_layer_new(self,
                          input,
@ -167,6 +181,7 @@ class ResNet():
                          groups=1,
                          act=None,
                          name=None):
        lr_mult = self.lr_mult_list[self.curr_stage]
        pool = fluid.layers.pool2d(
            input=input,
            pool_size=2,
@ -183,7 +198,9 @@ class ResNet():
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            param_attr=ParamAttr(
                name=name + "_weights" + self.postfix_name,
                learning_rate=lr_mult),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
@ -192,10 +209,14 @@ class ResNet():
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')
            param_attr=ParamAttr(
                name=bn_name + '_scale' + self.postfix_name,
                learning_rate=lr_mult),
            bias_attr=ParamAttr(
                bn_name + '_offset' + self.postfix_name,
                learning_rate=lr_mult),
            moving_mean_name=bn_name + '_mean' + self.postfix_name,
            moving_variance_name=bn_name + '_variance' + self.postfix_name)

    def shortcut(self, input, ch_out, stride, name, if_first=False):
        ch_in = input.shape[1]
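
For reference, a small sketch of what passing learning_rate=lr_mult into ParamAttr is intended to do (standard fluid behaviour: a parameter's effective learning rate is the optimizer's global rate times this per-parameter multiplier); the values come from the finetune config at the top of this diff.

    base_lr = 0.00375                          # LEARNING_RATE.params.lr
    lr_mult_list = [0.1, 0.1, 0.2, 0.2, 0.3]   # ARCHITECTURE.params.lr_mult_list
    stage_lrs = [base_lr * m for m in lr_mult_list]
    # -> [0.000375, 0.000375, 0.00075, 0.00075, 0.001125]: earlier stages move more slowly.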
@ -273,8 +294,8 @@ def ResNet34_vd():
    return model


def ResNet50_vd():
    model = ResNet(layers=50, is_3x3=True)
def ResNet50_vd(**args):
    model = ResNet(layers=50, is_3x3=True, **args)
    return model
@ -59,15 +59,18 @@ def check_architecture(architecture):
    """
    check architecture and recommend similar architectures
    """
    assert isinstance(architecture, str), \
        ("the type of architecture({}) should be str". format(architecture))
    similar_names = similar_architectures(architecture, get_architectures())
    assert isinstance(architecture, dict), \
        ("the type of architecture({}) should be dict". format(architecture))
    assert "name" in architecture, \
        ("name must be in the architecture keys, just contains: {}". format(architecture.keys()))

    similar_names = similar_architectures(architecture["name"],
                                          get_architectures())
    model_list = ', '.join(similar_names)
    err = "{} is not exist! Maybe you want: [{}]" \
          "".format(architecture, model_list)
          "".format(architecture["name"], model_list)
    try:
        assert architecture in similar_names
        assert architecture["name"] in similar_names
    except AssertionError:
        logger.error(err)
        sys.exit(1)
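
With this change the checker receives the whole ARCHITECTURE dict instead of a bare name string. A hedged example call (values copied from the config above):

    check_architecture({"name": "ResNet50_vd",
                        "params": {"lr_mult_list": [0.1, 0.1, 0.2, 0.2, 0.3]}})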
@ -80,7 +83,7 @@ def check_mix(architecture, use_mix=False):
    err = "Cannot use mix processing in GoogLeNet, " \
          "please set use_mix = False."
    try:
        if architecture == "GoogLeNet": assert use_mix == False
        if architecture["name"] == "GoogLeNet": assert use_mix == False
    except AssertionError:
        logger.error(err)
        sys.exit(1)
@ -19,7 +19,7 @@ from ppcls.utils import logger

__all__ = ['get_config']

CONFIG_SECS = ['TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']
CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']


class AttrDict(dict):
@ -110,7 +110,7 @@ def check_config(config):
    mode = config.get('mode', 'train')
    check.check_gpu()

    architecture = config.get('architecture')
    architecture = config.get('ARCHITECTURE')
    check.check_architecture(architecture)

    use_mix = config.get('use_mix')
@ -106,7 +106,7 @@ def load_params(exe, prog, path, ignore_params=[]):
    fluid.io.set_program_state(prog, state)


def init_model(config, program, exe, prefix="ppcls"):
def init_model(config, program, exe, prefix=""):
    """
    load model from checkpoint or pretrained_model
    """
@ -88,19 +88,21 @@ def create_dataloader(feeds):
    return dataloader


def create_model(name, image, classes_num):
def create_model(architecture, image, classes_num):
    """
    Create a model

    Args:
        name(str): model name, such as ResNet50
        architecture(dict): architecture information, name(such as ResNet50) is needed
        image(variable): model input variable
        classes_num(int): num of classes

    Returns:
        out(variable): model output variable
    """
    model = architectures.__dict__[name]()
    name = architecture["name"]
    params = architecture.get("params", {})
    model = architectures.__dict__[name](**params)
    out = model.net(input=image, class_dim=classes_num)
    return out
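
A short usage sketch (illustrative values): any extra constructor arguments placed under "params" in the config, such as lr_mult_list or postfix_name, are forwarded unchanged to the architecture class.

    arch = {"name": "ResNet50_vd",
            "params": {"lr_mult_list": [0.1, 0.1, 0.2, 0.2, 0.3]}}
    out = create_model(arch, feeds['image'], classes_num=102)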
@ -122,7 +124,7 @@ def create_loss(out,
    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables
        architecture(str): model name, such as ResNet50
        architecture(dict): architecture information, name(such as ResNet50) is needed
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        mix(bool): whether to use mix(include mixup, cutmix, fmix)
@ -130,7 +132,7 @@ def create_loss(out,
    Returns:
        loss(variable): loss variable
    """
    if architecture == "GoogLeNet":
    if architecture["name"] == "GoogLeNet":
        assert len(out) == 3, "GoogLeNet should have 3 outputs"
        loss = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
        target = feeds['label']
@ -188,7 +190,7 @@ def create_fetchs(out,
    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables(included label)
        architecture(str): model name, such as ResNet50
        architecture(dict): architecture information, name(such as ResNet50) is needed
        topk(int): usually top5
        classes_num(int): num of classes
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
@ -293,12 +295,12 @@ def build(config, main_prog, startup_prog, is_train=True):
        use_mix = config.get('use_mix') and is_train
        feeds = create_feeds(config.image_shape, mix=use_mix)
        dataloader = create_dataloader(feeds.values())
        out = create_model(config.architecture, feeds['image'],
        out = create_model(config.ARCHITECTURE, feeds['image'],
                           config.classes_num)
        fetchs = create_fetchs(
            out,
            feeds,
            config.architecture,
            config.ARCHITECTURE,
            config.topk,
            config.classes_num,
            epsilon=config.get('ls_epsilon'),
@ -96,7 +96,7 @@ def main(args):

        if epoch_id % config.save_interval == 0:
            model_path = os.path.join(config.model_save_dir,
                                      config.architecture)
                                      config.ARCHITECTURE["name"])
            save_model(train_prog, model_path, epoch_id)
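
Since ARCHITECTURE is now a dict, the checkpoint directory is derived from its "name" entry; with the config above this resolves as follows (illustrative):

    import os
    model_path = os.path.join("./output/", "ResNet50_vd")   # -> "./output/ResNet50_vd"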