# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import platform
import paddle
import paddle.distributed as dist
from visualdl import LogWriter
from paddle import nn
import numpy as np
import random
from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from ppcls.engine.train import train_epoch
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
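
# A minimal usage sketch for the Engine class below. The config path and the
# get_config helper are assumptions based on how tools/train.py typically
# drives this class, not guarantees made by this file:
#
#     from ppcls.utils import config as config_util
#     cfg = config_util.get_config("ppcls/configs/ImageNet/ResNet/ResNet50.yaml",
#                                  overrides=None, show=False)
#     engine = Engine(cfg, mode="train")
#     engine.train()  # use mode="eval"/"infer"/"export" with eval()/infer()/export()
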
class Engine(object):
    def __init__(self, config, mode="train"):
        assert mode in ["train", "eval", "infer", "export"]
        self.mode = mode
        self.config = config
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
                                                                    False):
            self.is_rec = True
        else:
            self.is_rec = False

        # set seed
        seed = self.config["Global"].get("seed", False)
        if seed or seed == 0:
            assert isinstance(seed, int), "The 'seed' must be an integer!"
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        # init logger
        self.output_dir = self.config['Global']['output_dir']
        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(log_file=log_file)
        print_config(config)

        # init train_func and eval_func
        assert self.eval_mode in ["classification", "retrieval"], logger.error(
            "Invalid eval mode: {}".format(self.eval_mode))
        self.train_epoch_func = train_epoch
        self.eval_func = getattr(evaluation, self.eval_mode + "_eval")

        self.use_dali = self.config['Global'].get("use_dali", False)

        # for visualdl
        self.vdl_writer = None
        if self.config['Global'][
                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)

        # set device
        assert self.config["Global"][
            "device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        if "class_num" in config["Global"]:
            global_class_num = config["Global"]["class_num"]
            if "class_num" not in config["Arch"]:
                config["Arch"]["class_num"] = global_class_num
                msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
            else:
                msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
            logger.warning(msg)
        # TODO(gaotingquan): support rec
        class_num = config["Arch"].get("class_num", None)
        self.config["DataLoader"].update({"class_num": class_num})

        # build dataloader
        if self.mode == 'train':
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)
        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            if self.eval_mode == "classification":
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)
            elif self.eval_mode == "retrieval":
                self.gallery_query_dataloader = None
                if len(self.config["DataLoader"]["Eval"].keys()) == 1:
                    key = list(self.config["DataLoader"]["Eval"].keys())[0]
                    self.gallery_query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], key, self.device,
                        self.use_dali)
                else:
                    self.gallery_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Gallery",
                        self.device, self.use_dali)
                    self.query_dataloader = build_dataloader(
                        self.config["DataLoader"]["Eval"], "Query",
                        self.device, self.use_dali)

        # build loss
        if self.mode == "train":
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
                else:
                    self.eval_loss_func = None
            else:
                self.eval_loss_func = None

        # build metric
        if self.mode == 'train':
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    if hasattr(
                            self.train_dataloader, "collate_fn"
                    ) and self.train_dataloader.collate_fn is not None:
                        for m_idx, m in enumerate(metric_config):
                            if "TopkAcc" in m:
                                msg = "'TopkAcc' metric cannot be used when setting 'batch_transform_ops' in config. The 'TopkAcc' metric has been removed."
                                logger.warning(msg)
                                break
                        metric_config.pop(m_idx)
                    self.train_metric_func = build_metrics(metric_config)
                else:
                    self.train_metric_func = None
            else:
                self.train_metric_func = None

        if self.mode == "eval" or (self.mode == "train" and
                                   self.config["Global"]["eval_during_train"]):
            metric_config = self.config.get("Metric")
            if self.eval_mode == "classification":
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)
            elif self.eval_mode == "retrieval":
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_config = metric_config["Eval"]
                self.eval_metric_func = build_metrics(metric_config)
        else:
            self.eval_metric_func = None

        # build model
        self.model = build_model(self.config)
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        # load_pretrain
        if self.config["Global"]["pretrained_model"] is not None:
            if self.config["Global"]["pretrained_model"].startswith("http"):
                load_dygraph_pretrain_from_url(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])
            else:
                load_dygraph_pretrain(
                    [self.model, getattr(self, 'train_loss_func', None)],
                    self.config["Global"]["pretrained_model"])

        # build optimizer
        if self.mode == 'train':
            self.optimizer, self.lr_sch = build_optimizer(
                self.config["Optimizer"], self.config["Global"]["epochs"],
                len(self.train_dataloader),
                [self.model, self.train_loss_func])
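        # AMP config: a hedged example of the optional "AMP" section consumed by
        # the branch that follows; only keys actually looked up there are shown,
        # and the values are illustrative rather than recommended defaults:
        #
        #   AMP:
        #     scale_loss: 128.0
        #     use_dynamic_loss_scaling: True
        #     level: O1            # "O1" or "O2"
        #     use_fp16_test: False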
        # AMP training and evaluating
        self.amp = "AMP" in self.config and self.config["AMP"] is not None
        self.amp_eval = False
        # for amp
        if self.amp:
            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
            if paddle.is_compiled_with_cuda():
                AMP_RELATED_FLAGS_SETTING.update({
                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
                })
            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)

            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
            self.use_dynamic_loss_scaling = self.config["AMP"].get(
                "use_dynamic_loss_scaling", False)
            self.scaler = paddle.amp.GradScaler(
                init_loss_scaling=self.scale_loss,
                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

            self.amp_level = self.config['AMP'].get("level", "O1")
            if self.amp_level not in ["O1", "O2"]:
                msg = "[Parameter Error]: The optimize level of AMP only supports 'O1' and 'O2'. The level has been set to 'O1'."
                logger.warning(msg)
                self.config['AMP']["level"] = "O1"
                self.amp_level = "O1"

            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
            # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMP O2
            if self.mode == "train" and self.config["Global"].get(
                    "eval_during_train",
                    True) and self.amp_level == "O2" and self.amp_eval == False:
                msg = "PaddlePaddle only supports FP16 evaluation when training with AMP O2 now."
                logger.warning(msg)
                self.config["AMP"]["use_fp16_test"] = True
                self.amp_eval = True

            # TODO(gaotingquan): to be compatible with different versions of Paddle
            paddle_version = paddle.__version__[:3]
            # paddle version < 2.3.0 and not develop
            if paddle_version not in ["2.3", "0.0"]:
                if self.mode == "train":
                    self.model, self.optimizer = paddle.amp.decorate(
                        models=self.model,
                        optimizers=self.optimizer,
                        level=self.amp_level,
                        save_dtype='float32')
                elif self.amp_eval:
                    if self.amp_level == "O2":
                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
                        logger.warning(msg)
                        self.amp_eval = False
                    else:
                        self.model, self.optimizer = paddle.amp.decorate(
                            models=self.model,
                            level=self.amp_level,
                            save_dtype='float32')
            # paddle version >= 2.3.0 or develop
            else:
                if self.mode == "train" or self.amp_eval:
                    self.model = paddle.amp.decorate(
                        models=self.model,
                        level=self.amp_level,
                        save_dtype='float32')

            if self.mode == "train" and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.amp.decorate(
                    models=self.train_loss_func,
                    level=self.amp_level,
                    save_dtype='float32')
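        # Multi-GPU runs are expected to be launched with
        # "python -m paddle.distributed.launch ...", so the world size observed
        # here reflects the number of started trainer processes.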
        # check the gpu num
        world_size = dist.get_world_size()
        self.config["Global"]["distributed"] = world_size != 1
        if self.mode == "train":
            std_gpu_num = 8 if isinstance(
                self.config["Optimizer"],
                dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
            if world_size != std_gpu_num:
                msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpus is {world_size} in the current training. Please modify the strategy (learning rate, batch size and so on) if you use this config to train."
                logger.warning(msg)

        # for distributed
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()
            self.model = paddle.DataParallel(self.model)
            if self.mode == 'train' and len(self.train_loss_func.parameters(
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)
        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
                "transforms"])
            self.postprocess_func = build_postprocess(self.config["Infer"][
                "PostProcess"])

    def train(self):
        assert self.mode == "train"
        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]
        best_metric = {
            "metric": -1.0,
            "epoch": 0,
        }
        # key: metric name, val: AverageMeter that tracks the metric values
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        if self.config.Global.checkpoints is not None:
            metric_info = init_model(self.config.Global, self.model,
                                     self.optimizer, self.train_loss_func)
            if metric_info is not None:
                best_metric.update(metric_info)

        self.max_iter = len(self.train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.train_dataloader)
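        # Main loop: run one training epoch, log the averaged metrics, evaluate
        # every "eval_interval" epochs when enabled, and keep "best_model",
        # "epoch_{id}" and "latest" checkpoints via save_load.save_model.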
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join(
                [self.output_info[key].avg_info for key in self.output_info])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            self.output_info.clear()

            # eval model and save model if possible
            start_eval_epoch = self.config["Global"].get("start_eval_epoch",
                                                         0) - 1
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0 and epoch_id > start_eval_epoch:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model",
                        loss=self.train_loss_func,
                        save_student_model=True)
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)
                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    self.optimizer, {"metric": acc,
                                     "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id),
                    loss=self.train_loss_func)
            # save the latest model
            save_load.save_model(
                self.model,
                self.optimizer, {"metric": acc,
                                 "epoch": epoch_id},
                self.output_dir,
                model_name=self.config["Arch"]["name"],
                prefix="latest",
                loss=self.train_loss_func)

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        assert self.mode in ["train", "eval"]
        self.model.eval()
        eval_result = self.eval_func(self, epoch_id)
        self.model.train()
        return eval_result
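    # infer() consumes the "Infer" section of the config. A hedged sketch of the
    # keys it relies on (values are placeholders, not prescriptive):
    #
    #   Infer:
    #     infer_imgs: <image file or directory>
    #     batch_size: 10
    #     transforms: <list of preprocessing operators>
    #     PostProcess: <postprocess config, e.g. a Topk-style operator>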
    @paddle.no_grad()
    def infer(self):
        assert self.mode == "infer" and self.eval_mode == "classification"
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in self.preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)

                if self.amp and self.amp_eval:
                    with paddle.amp.auto_cast(
                            custom_black_list={
                                "flatten_contiguous_range", "greater_than"
                            },
                            level=self.amp_level):
                        out = self.model(batch_tensor)
                else:
                    out = self.model(batch_tensor)

                if isinstance(out, list):
                    out = out[0]
                if isinstance(out, dict) and "logits" in out:
                    out = out["logits"]
                if isinstance(out, dict) and "output" in out:
                    out = out["output"]
                result = self.postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()

    def export(self):
        assert self.mode == "export"
        use_multilabel = self.config["Global"].get("use_multilabel", False)
        model = ExportModel(self.config["Arch"], self.model, use_multilabel)
        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(model.base_model,
                                  self.config["Global"]["pretrained_model"])
        model.eval()

        # for rep nets
        for layer in self.model.sublayers():
            if hasattr(layer, "rep"):
                layer.rep()

        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")
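        # Both branches below write the static-graph inference model under
        # Global.save_inference_dir with the "inference" prefix (e.g.
        # inference.pdmodel / inference.pdiparams for paddle.jit.save).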
        if model.quanter:
            model.quanter.save_quantized_model(
                model.base_model,
                save_path,
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None] + self.config["Global"]["image_shape"],
                        dtype='float32')
                ])
        else:
            model = paddle.jit.to_static(
                model,
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None] + self.config["Global"]["image_shape"],
                        dtype='float32')
                ])
            paddle.jit.save(model, save_path)
        logger.info(
            f"Export succeeded! The exported inference model has been saved in \"{self.config['Global']['save_inference_dir']}\"."
        )


class ExportModel(TheseusLayer):
    """
    ExportModel: add the output activation (softmax, or sigmoid for multilabel) onto the model
    """

    def __init__(self, config, model, use_multilabel):
        super().__init__()
        self.base_model = model
        # we should choose a final model to export
        if isinstance(self.base_model, DistillationModel):
            self.infer_model_name = config["infer_model_name"]
        else:
            self.infer_model_name = None

        self.infer_output_key = config.get("infer_output_key", None)
        if self.infer_output_key == "features" and isinstance(self.base_model,
                                                              RecModel):
            self.base_model.head = IdentityHead()
        if use_multilabel:
            self.out_act = nn.Sigmoid()
        else:
            if config.get("infer_add_softmax", True):
                self.out_act = nn.Softmax(axis=-1)
            else:
                self.out_act = None

    def eval(self):
        self.training = False
        for layer in self.sublayers():
            layer.training = False
            layer.eval()

    def forward(self, x):
        x = self.base_model(x)
        if isinstance(x, list):
            x = x[0]
        if self.infer_model_name is not None:
            x = x[self.infer_model_name]
        if self.infer_output_key is not None:
            x = x[self.infer_output_key]
        if self.out_act is not None:
            if isinstance(x, dict):
                x = x["logits"]
            x = self.out_act(x)
        return x