mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
add quick start demo
This commit is contained in:
parent
5736d85b96
commit
42158c1867
70
configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
Normal file
70
configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: 'MobileNetV3_large_x1_0'
|
||||||
|
pretrained_model: "./pretrained/MobileNetV3_large_x1_0_pretrained"
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 102
|
||||||
|
total_images: 1020
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 20
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'Cosine'
|
||||||
|
params:
|
||||||
|
lr: 0.00375
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'Momentum'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.000001
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/train_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 20
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/val_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
75
configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
Normal file
75
configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: 'ResNet50_vd_distill_MobileNetV3_large_x1_0'
|
||||||
|
|
||||||
|
pretrained_model:
|
||||||
|
- "./pretrain/flowers102_R50_vd_final/ppcls"
|
||||||
|
- "./pretrained/MobileNetV3_large_x1_0_pretrained/"
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 102
|
||||||
|
total_images: 7169
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 20
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
|
||||||
|
use_distillation: True
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'Cosine'
|
||||||
|
params:
|
||||||
|
lr: 0.0125
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'Momentum'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.00007
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/train_test_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 20
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/val_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
70
configs/quick_start/ResNet50_vd.yaml
Normal file
70
configs/quick_start/ResNet50_vd.yaml
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: 'ResNet50_vd'
|
||||||
|
pretrained_model: ""
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 102
|
||||||
|
total_images: 1020
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 20
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'Cosine'
|
||||||
|
params:
|
||||||
|
lr: 0.0125
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'Momentum'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.00001
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/train_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 20
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/val_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
70
configs/quick_start/ResNet50_vd_finetune.yaml
Normal file
70
configs/quick_start/ResNet50_vd_finetune.yaml
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: 'ResNet50_vd'
|
||||||
|
pretrained_model: "./pretrained/ResNet50_vd_pretrained"
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 102
|
||||||
|
total_images: 1020
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 20
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'Cosine'
|
||||||
|
params:
|
||||||
|
lr: 0.00375
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'Momentum'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.000001
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/train_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 20
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/val_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
72
configs/quick_start/ResNet50_vd_ssld_finetune.yaml
Normal file
72
configs/quick_start/ResNet50_vd_ssld_finetune.yaml
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: 'ResNet50_vd'
|
||||||
|
params:
|
||||||
|
lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
|
||||||
|
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 102
|
||||||
|
total_images: 1020
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 20
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'Cosine'
|
||||||
|
params:
|
||||||
|
lr: 0.00375
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'Momentum'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.000001
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/train_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 20
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/val_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
@ -0,0 +1,74 @@
|
|||||||
|
mode: 'train'
|
||||||
|
ARCHITECTURE:
|
||||||
|
name: 'ResNet50_vd'
|
||||||
|
params:
|
||||||
|
lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
|
||||||
|
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
|
||||||
|
model_save_dir: "./output/"
|
||||||
|
classes_num: 102
|
||||||
|
total_images: 1020
|
||||||
|
save_interval: 1
|
||||||
|
validate: True
|
||||||
|
valid_interval: 1
|
||||||
|
epochs: 20
|
||||||
|
topk: 5
|
||||||
|
image_shape: [3, 224, 224]
|
||||||
|
|
||||||
|
LEARNING_RATE:
|
||||||
|
function: 'Cosine'
|
||||||
|
params:
|
||||||
|
lr: 0.00375
|
||||||
|
|
||||||
|
OPTIMIZER:
|
||||||
|
function: 'Momentum'
|
||||||
|
params:
|
||||||
|
momentum: 0.9
|
||||||
|
regularizer:
|
||||||
|
function: 'L2'
|
||||||
|
factor: 0.000001
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/train_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- RandCropImage:
|
||||||
|
size: 224
|
||||||
|
- RandFlipImage:
|
||||||
|
flip_code: 1
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- RandomErasing:
|
||||||
|
EPSILON: 0.5
|
||||||
|
- ToCHWImage:
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
batch_size: 20
|
||||||
|
num_workers: 4
|
||||||
|
file_list: "./dataset/flowers102/val_list.txt"
|
||||||
|
data_dir: "./dataset/flowers102/"
|
||||||
|
shuffle_seed: 0
|
||||||
|
transforms:
|
||||||
|
- DecodeImage:
|
||||||
|
to_rgb: True
|
||||||
|
to_np: False
|
||||||
|
channel_first: False
|
||||||
|
- ResizeImage:
|
||||||
|
resize_short: 256
|
||||||
|
- CropImage:
|
||||||
|
size: 224
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: ''
|
||||||
|
- ToCHWImage:
|
@ -44,4 +44,4 @@ from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
|
|||||||
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
|
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
|
||||||
|
|
||||||
# distillation model
|
# distillation model
|
||||||
from .distillation_models import ResNet50_vd_distill_MobileNetV3_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
|
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
|
||||||
|
@ -27,12 +27,12 @@ from .mobilenet_v3 import MobileNetV3_large_x1_0
|
|||||||
from .resnext101_wsl import ResNeXt101_32x16d_wsl
|
from .resnext101_wsl import ResNeXt101_32x16d_wsl
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ResNet50_vd_distill_MobileNetV3_x1_0',
|
'ResNet50_vd_distill_MobileNetV3_large_x1_0',
|
||||||
'ResNeXt101_32x16d_wsl_distill_ResNet50_vd'
|
'ResNeXt101_32x16d_wsl_distill_ResNet50_vd'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ResNet50_vd_distill_MobileNetV3_x1_0():
|
class ResNet50_vd_distill_MobileNetV3_large_x1_0():
|
||||||
def net(self, input, class_dim=1000):
|
def net(self, input, class_dim=1000):
|
||||||
# student
|
# student
|
||||||
student = MobileNetV3_large_x1_0()
|
student = MobileNetV3_large_x1_0()
|
||||||
|
@ -118,7 +118,10 @@ def init_model(config, program, exe):
|
|||||||
|
|
||||||
pretrained_model = config.get('pretrained_model')
|
pretrained_model = config.get('pretrained_model')
|
||||||
if pretrained_model:
|
if pretrained_model:
|
||||||
load_params(exe, program, pretrained_model)
|
if not isinstance(pretrained_model, list):
|
||||||
|
pretrained_model = [pretrained_model]
|
||||||
|
for pretrain in pretrained_model:
|
||||||
|
load_params(exe, program, pretrain)
|
||||||
logger.info("Finish initing model from {}".format(pretrained_model))
|
logger.info("Finish initing model from {}".format(pretrained_model))
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user