# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
from paddle import nn
import numpy as np
import random

from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.amp import AutoCast, build_scaler
from ppcls.utils.ema import ExponentialMovingAverage
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load, save_predict_result
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from ppcls.engine import train as train_method
from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead

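# Typical usage (a sketch; the actual entry points live in the PaddleClas
# tools/ scripts and may wire things up differently):
#     engine = Engine(config, mode="train")
#     engine.train()   # or engine.eval() / engine.infer() / engine.export()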
class Engine(object):
    def __init__(self, config, mode="train"):
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_mode = self.config["Global"].get("train_mode", None)
        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
                                                                    False):
            self.is_rec = True
        else:
            self.is_rec = False

        # set seed
        seed = self.config["Global"].get("seed", False)
        if seed or seed == 0:
            assert isinstance(seed, int), "The 'seed' must be an integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # init logger
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func
        assert self.eval_mode in [
            "classification", "retrieval", "adaface"
        ], logger.error("Invalid eval mode: {}".format(self.eval_mode))
        if self.train_mode is None:
            self.train_epoch_func = train_method.train_epoch
        else:
            self.train_epoch_func = getattr(train_method,
                                            "train_epoch_" + self.train_mode)
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = self.output_dir
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"]["device"] in [
            "cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps"
        ]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        # TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})
        self.config["DataLoader"].update({
            "epochs": self.config["Global"]["epochs"]
        })
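
        # Config sketch for the class_num handling above (assumed YAML shape;
        # Global.class_num is deprecated in favour of Arch.class_num):
        #
        #     Arch:
        #       name: ResNet50
        #       class_num: 1000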

        # build dataloader
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
            if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
                self.unlabel_train_dataloader = build_dataloader(
                    self.config["DataLoader"], "UnLabelTrain", self.device,
                    self.use_dali)
            else:
                self.unlabel_train_dataloader = None

            self.iter_per_epoch = len(
                self.train_dataloader) - 1 if platform.system(
                ) == "Windows" else len(self.train_dataloader)
            if self.config["Global"].get("iter_per_epoch", None):
                # set max iteration per epoch manually, when training by iteration(s), such as XBM, FixMatch.
                self.iter_per_epoch = self.config["Global"].get(
                    "iter_per_epoch")
            if self.iter_per_epoch < self.update_freq:
                logger.warning(
                    "The arg Global.update_freq greater than iter_per_epoch and has been set to 1. This may be caused by too few of batches."
                )
                self.update_freq = 1
            self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq
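            # NOTE: the line above rounds iter_per_epoch down to a multiple of
            # update_freq, so every optimizer step is backed by a full
            # gradient-accumulation window; the trailing partial window of
            # batches in each epoch is simply skipped.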

        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            if self.eval_mode in ["classification", "adaface"]:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)
            elif self.eval_mode == "retrieval":
                self.gallery_query_dataloader = None
                if len(self.config["DataLoader"]["Eval"].keys()) == 1:
                    key = list(self.config["DataLoader"]["Eval"].keys())[0]
                    self.gallery_query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], key, self.device,
                        self.use_dali)
                else:
                    self.gallery_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Gallery",
                        self.device, self.use_dali)
                    self.query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Query",
                        self.device, self.use_dali)
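            # Retrieval evaluation, assumed config layout: when DataLoader.Eval
            # has a single sub-section it is used as a combined gallery+query
            # loader; otherwise separate "Gallery" and "Query" sub-sections are
            # expected and one loader is built from each.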

        # build loss
        if self.mode == "train":
            label_loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(label_loss_info)
            unlabel_loss_info = self.config.get("UnLabelLoss", {}).get("Train",
                                                                       None)
            self.unlabel_train_loss_func = build_loss(unlabel_loss_info)
        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
                else:
                    self.eval_loss_func = None
            else:
                self.eval_loss_func = None

        # build metric
        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
                "Metric"] and self.config["Metric"]["Train"]:
            metric_config = self.config["Metric"]["Train"]
            if hasattr(self.train_dataloader, "collate_fn"
                       ) and self.train_dataloader.collate_fn is not None:
                for m_idx, m in enumerate(metric_config):
                    if "TopkAcc" in m:
                        msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
                        logger.warning(msg)
                        metric_config.pop(m_idx)
            self.train_metric_func = build_metrics(metric_config)
        else:
            self.train_metric_func = None

        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            if self.eval_mode == "classification":
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    self.eval_metric_func = build_metrics(self.config["Metric"]
                                                          ["Eval"])
                else:
                    self.eval_metric_func = None
            elif self.eval_mode == "retrieval":
                if "Metric" in self.config and "Eval" in self.config["Metric"]:
                    metric_config = self.config["Metric"]["Eval"]
                else:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                self.eval_metric_func = build_metrics(metric_config)
        else:
            self.eval_metric_func = None
2021-08-22 23:10:23 +08:00
# build model
2022-05-23 19:04:27 +08:00
self . model = build_model ( self . config , self . mode )
2023-03-14 16:16:40 +08:00
# set @to_static for benchmark, skip this by default.
apply_to_static ( self . config , self . model )
2021-10-18 18:07:14 +08:00
2023-03-14 16:16:40 +08:00
# load_pretrain
2023-03-14 16:16:40 +08:00
if self . config [ " Global " ] [ " pretrained_model " ] is not None :
2023-10-30 13:44:16 +08:00
load_dygraph_pretrain (
[ self . model , getattr ( self , ' train_loss_func ' , None ) ] ,
self . config [ " Global " ] [ " pretrained_model " ] )
2023-03-14 16:16:40 +08:00
2023-03-14 16:16:40 +08:00
# build optimizer
2023-03-14 16:16:40 +08:00
if self . mode == ' train ' :
self . optimizer , self . lr_sch = build_optimizer (
self . config [ " Optimizer " ] , self . config [ " Global " ] [ " epochs " ] ,
self . iter_per_epoch / / self . update_freq ,
[ self . model , self . train_loss_func ] )
2023-05-23 11:01:25 +08:00
# amp
self . _init_amp ( )

        # build EMA model
        self.ema = "EMA" in self.config and self.mode == "train"
        if self.ema:
            self.model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))

        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the strategy (learning rate, batch size and so on) if use this config to train."
                logger.warning(msg)

        # for distributed
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)

            # set different seed in different GPU manually in distributed environment
            if seed is None:
                logger.warning(
                    "The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
                )
                self.config["Global"]["seed"] = seed = 42
            logger.info(
                f"Set random seed to ({int(seed)} + $PADDLE_TRAINER_ID) for different trainer"
            )
            paddle.seed(int(seed) + dist.get_rank())
            np.random.seed(int(seed) + dist.get_rank())
            random.seed(int(seed) + dist.get_rank())

        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])

    def train(self):
        assert self.mode == "train"
        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]
        best_metric = {
            "metric": -1.0,
            "epoch": 0,
        }
        ema_module = None
        if self.ema:
            best_metric_ema = 0.0
            ema_module = self.model_ema.module
        # key: metric name; val: AverageMeter for that metric
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        if self.config.Global.checkpoints is not None:
            metric_info = init_model(self.config.Global, self.model,
                                     self.optimizer, self.train_loss_func,
                                     ema_module)
            if metric_info is not None:
                best_metric.update(metric_info)
            if hasattr(self.train_dataloader.batch_sampler, "set_epoch"):
                self.train_dataloader.batch_sampler.set_epoch(best_metric[
                    "epoch"])

        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)
            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join(
                [self.output_info[key].avg_info for key in self.output_info])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            start_eval_epoch = self.config["Global"].get("start_eval_epoch",
                                                         0) - 1
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0 and epoch_id > start_eval_epoch:
                acc = self.eval(epoch_id)

                # step lr (by epoch) according to given metric, such as acc
                for i in range(len(self.lr_sch)):
                    if getattr(self.lr_sch[i], "by_epoch", False) and \
                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
                        self.lr_sch[i].step(acc)

                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        ema=ema_module,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model",
                        loss=self.train_loss_func,
                        save_student_model=True)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

                if self.ema:
                    ori_model, self.model = self.model, ema_module
                    acc_ema = self.eval(epoch_id)
                    self.model = ori_model
                    ema_module.eval()

                    if acc_ema > best_metric_ema:
                        best_metric_ema = acc_ema
                        save_load.save_model(
                            self.model,
                            self.optimizer,
                            {"metric": acc_ema,
                             "epoch": epoch_id},
                            self.output_dir,
                            ema=ema_module,
                            model_name=self.config["Arch"]["name"],
                            prefix="best_model_ema",
                            loss=self.train_loss_func)
                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
                        epoch_id, best_metric_ema))
                    logger.scaler(
                        name="eval_acc_ema",
                        value=acc_ema,
                        step=epoch_id,
                        writer=self.vdl_writer)

            # save model
            if save_interval > 0 and epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
                    ema=ema_module,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
            # save the latest model
            save_load.save_model(
                self.model,
                self.optimizer, {"metric": acc,
                                 "epoch": epoch_id},
                self.output_dir,
                ema=ema_module,
                model_name=self.config["Arch"]["name"],
                prefix="latest",
                loss=self.train_loss_func)

        if self.vdl_writer is not None:
            self.vdl_writer.close()
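
    # Checkpoints written by train() above, all under output_dir:
    #   "best_model"     - best metric so far (saved with save_student_model=True)
    #   "best_model_ema" - best metric of the EMA weights (only if EMA is enabled)
    #   "epoch_<N>"      - periodic snapshot every save_interval epochs
    #   "latest"         - snapshot of the most recent epoch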

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        assert self.mode in ["train", "eval"]
        self.model.eval()
        eval_result = self.eval_func(self, epoch_id)
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def infer(self):
        assert self.mode == "infer" and self.eval_mode == "classification"
        results = []
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        infer_imgs = self.config["Infer"]["infer_imgs"]
        infer_list = self.config["Infer"].get("infer_list", None)
        image_list = get_image_list(infer_imgs, infer_list=infer_list)
        # data split
        image_list = image_list[local_rank::total_trainer]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []
        save_path = self.config["Infer"].get("save_dir", None)
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)

                with self.auto_cast(is_eval=True):
                    out = self.model(batch_tensor)

                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "Student" in out:
                    out = out["Student"]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]

                result = self.postprocess_func(out, image_file_list)
                if not save_path:
                    logger.info(result)
                results.extend(result)
                batch_data.clear()
                image_file_list.clear()

        if save_path:
            save_predict_result(save_path, results)
        return results
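
    # One plausible "Infer" section consumed by infer() above; the key names
    # come from the lookups in this file, the values are only illustrative:
    #
    #     Infer:
    #       infer_imgs: path/to/image_or_dir
    #       infer_list: null
    #       batch_size: 10
    #       save_dir: null
    #       transforms: [...]
    #       PostProcess:
    #         name: Topk
    #         topk: 5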

    def export(self):
        assert self.mode == "export"
        use_multilabel = self.config["Global"].get(
            "use_multilabel",
            False) or "ATTRMetric" in self.config["Metric"]["Eval"][0]
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(model.base_model,
                                  self.config["Global"]["pretrained_model"])
        model.eval()

        # for re-parameterization nets
        for layer in self.model.sublayers():
            if hasattr(layer, "re_parameterize") and not getattr(layer,
                                                                 "is_repped"):
                layer.re_parameterize()

        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")
        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + self.config["Global"]["image_shape"],
                    dtype='float32')
            ])
        if hasattr(model.base_model,
                   "quanter") and model.base_model.quanter is not None:
            model.base_model.quanter.save_quantized_model(model,
                                                          save_path + "_int8")
        else:
            paddle.jit.save(model, save_path)
        if self.config["Global"].get("export_for_fd", False):
            src_path = self.config["Global"]["infer_config_path"]
            dst_path = os.path.join(
                self.config["Global"]["save_inference_dir"], 'inference.yml')
            shutil.copy(src_path, dst_path)
        logger.info(
            f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )
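
    # Global keys consumed by export() above (names taken from the code; the
    # values are only an example of one plausible configuration):
    #
    #     Global:
    #       pretrained_model: output/ResNet50/best_model
    #       save_inference_dir: ./inference
    #       image_shape: [3, 224, 224]
    #       export_for_fd: False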

    def _init_amp(self):
        if self.mode == "export":
            return

        amp_config = self.config.get("AMP", None)
        use_amp = True if amp_config and amp_config.get("use_amp",
                                                        True) else False

        if not use_amp:
            self.auto_cast = AutoCast(use_amp)
            self.scaler = build_scaler(use_amp)
        else:
            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                AMP_RELATED_FLAGS_SETTING.update({
                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
                })
            paddle.set_flags(AMP_RELATED_FLAGS_SETTING)

            use_promote = amp_config.get("use_promote", False)
            amp_level = amp_config.get("level", "O1")
            if amp_level not in ["O1", "O2"]:
                msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                logger.warning(msg)
                amp_level = amp_config["level"] = "O1"

            amp_eval = self.config["AMP"].get("use_fp16_test", False)
            # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMP O2
            if self.mode == "train" and self.config["Global"].get(
                    "eval_during_train",
                    True) and amp_level == "O2" and amp_eval == False:
                msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now."
                logger.warning(msg)
                self.config["AMP"]["use_fp16_test"] = True
                amp_eval = True

            self.auto_cast = AutoCast(
                use_amp,
                amp_level=amp_level,
                use_promote=use_promote,
                amp_eval=amp_eval)

            scale_loss = amp_config.get("scale_loss", 1.0)
            use_dynamic_loss_scaling = amp_config.get(
                "use_dynamic_loss_scaling", False)
            self.scaler = build_scaler(
                use_amp,
                scale_loss=scale_loss,
                use_dynamic_loss_scaling=use_dynamic_loss_scaling)

            if self.mode == "train":
                self.model, self.optimizer = paddle.amp.decorate(
                    models=self.model,
                    optimizers=self.optimizer,
                    level=amp_level,
                    save_dtype='float32')
            elif amp_eval:
                self.model = paddle.amp.decorate(
                    models=self.model, level=amp_level, save_dtype='float32')

            if self.mode == "train" and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.amp.decorate(
                    models=self.train_loss_func,
                    level=amp_level,
                    save_dtype='float32')
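
    # An "AMP" section matching the lookups in _init_amp (key names from the
    # code above; the values shown are just one reasonable combination):
    #
    #     AMP:
    #       use_amp: True
    #       level: O1                  # or O2
    #       use_fp16_test: False
    #       use_promote: False
    #       scale_loss: 128.0
    #       use_dynamic_loss_scaling: True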


class ExportModel(TheseusLayer):
    """
    ExportModel: add softmax (or sigmoid for multi-label) onto the model
    """

    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model
        # we should choose a final model to export
        if isinstance(self.base_model, DistillationModel):
            self.infer_model_name = config["infer_model_name"]
        else:
            self.infer_model_name = None

        self.infer_output_key = config.get("infer_output_key", None)
        if self.infer_output_key == "features" and isinstance(self.base_model,
                                                              RecModel):
            self.base_model.head = IdentityHead()
        if use_multilabel:
            self.out_act = nn.Sigmoid()
        else:
            if config.get("infer_add_softmax", True):
                self.out_act = nn.Softmax(axis=-1)
            else:
                self.out_act = None

    def eval(self):
        self.training = False
        for layer in self.sublayers():
            layer.training = False
            layer.eval()

    def forward(self, x):
        x = self.base_model(x)
        if isinstance(x, list):
            x = x[0]
        if self.infer_model_name is not None:
            x = x[self.infer_model_name]
        if self.infer_output_key is not None:
            x = x[self.infer_output_key]
        if self.out_act is not None:
            if isinstance(x, dict):
                x = x["logits"]
            x = self.out_act(x)
        return x