475 lines
16 KiB
Python
475 lines
16 KiB
Python
|
import paddle
|
||
|
import numpy as np
|
||
|
import os
|
||
|
import paddle.nn as nn
|
||
|
import paddle.distributed as dist
|
||
|
# NOTE(review): executed at import time, before the local project imports
# below. The world-size query result is discarded — presumably a warm-up /
# availability check; confirm whether it can be removed.
dist.get_world_size()
# Initialize the distributed environment so paddle.DataParallel works later.
dist.init_parallel_env()
|
||
|
|
||
|
from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
|
||
|
from optimizer import create_optimizer
|
||
|
from data_loader import build_dataloader
|
||
|
from metric import create_metric
|
||
|
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
|
||
|
from config import preprocess
|
||
|
import time
|
||
|
|
||
|
from paddleslim.dygraph.quant import QAT
|
||
|
from slim.slim_quant import PACT, quant_config
|
||
|
from slim.slim_fpgm import prune_model
|
||
|
from utils import load_model
|
||
|
|
||
|
|
||
|
def _mkdir_if_not_exist(path, logger):
|
||
|
"""
|
||
|
mkdir if not exists, ignore the exception when multiprocess mkdir together
|
||
|
"""
|
||
|
if not os.path.exists(path):
|
||
|
try:
|
||
|
os.makedirs(path)
|
||
|
except OSError as e:
|
||
|
if e.errno == errno.EEXIST and os.path.isdir(path):
|
||
|
logger.warning(
|
||
|
'be happy if some process has already created {}'.format(
|
||
|
path))
|
||
|
else:
|
||
|
raise OSError('Failed to mkdir {}'.format(path))
|
||
|
|
||
|
|
||
|
def save_model(model,
               optimizer,
               model_path,
               logger,
               is_best=False,
               prefix='ppocr',
               **kwargs):
    """
    Persist model parameters (and optimizer state) under ``model_path``.

    When ``optimizer`` is a list, the first state dict goes to
    ``<prefix>.pdopt`` and the second to ``<prefix>_1.pdopt``; otherwise a
    single ``<prefix>.pdopt`` is written. ``kwargs`` is accepted for caller
    compatibility but currently unused (metric/state pickling is disabled).
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)

    paddle.save(model.state_dict(), model_prefix + '.pdparams')

    if type(optimizer) is list:
        first, second = optimizer[0], optimizer[1]
        paddle.save(first.state_dict(), model_prefix + '.pdopt')
        paddle.save(second.state_dict(), model_prefix + '_1.pdopt')
    else:
        paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

    # # save metric and config
    # with open(model_prefix + '.states', 'wb') as f:
    #     pickle.dump(kwargs, f, protocol=2)
    if is_best:
        message = 'save best model is to {}'.format(model_prefix)
    else:
        message = "save model in {}".format(model_prefix)
    logger.info(message)
|
||
|
|
||
|
|
||
|
def amp_scaler(config):
    """Build a ``paddle.amp.GradScaler`` when AMP is enabled in ``config``.

    Returns:
        A configured GradScaler when ``config['AMP']['use_amp']`` is True,
        otherwise None.
    """
    # Guard clause: no AMP section, or AMP explicitly disabled.
    if 'AMP' not in config or config['AMP']['use_amp'] is not True:
        return None

    # Performance-related framework flags recommended for AMP runs.
    paddle.fluid.set_flags({
        'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
        'FLAGS_max_inplace_grad_add': 8,
    })

    amp_cfg = config["AMP"]
    return paddle.amp.GradScaler(
        init_loss_scaling=amp_cfg.get("scale_loss", 1.0),
        use_dynamic_loss_scaling=amp_cfg.get("use_dynamic_loss_scaling",
                                             False))
|
||
|
|
||
|
|
||
|
def set_seed(seed):
    """Seed the paddle and numpy global RNGs for reproducible training."""
    paddle.seed(seed)
    np.random.seed(seed)
|
||
|
|
||
|
|
||
|
def train(config, scaler=None):
    """Train a single classification model described by ``config``.

    Args:
        config: parsed configuration dict (epoch, topk, TRAIN/VALID sections,
            quant/prune flags, save_model_dir, ...).
        scaler: optional ``paddle.amp.GradScaler``; when given, forward runs
            under autocast and backward uses scale/minimize instead of the
            plain step/clear_grad path.

    NOTE(review): uses the module-level ``logger`` created in ``__main__``.
    """
    EPOCH = config['epoch']
    topk = config['topk']  # NOTE(review): read but never used in this function

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # build model
    # model = MobileNetV3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build_optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load model (resume / pretrained weights; returns prior best metrics)
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metric of loaded metric as follows {}'.format(', '.join(
            ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    # about slim prune and quant (mutually exclusive; quant wins if both set)
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # distribution
    model.train()
    model = paddle.DataParallel(model)
    # build loss function
    loss_func = build_loss(config)

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            # assumes loader yields NHWC images; model expects NCHW — TODO confirm
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # cal metric
            acc = metric_func(outs, label)

            # cal loss
            avg_loss = loss_func(outs, label)

            if scaler is None:
                # backward
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            # float lr_scheduler means a constant learning rate — no step
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        # evaluate + checkpoint every 10 epochs; best tracked by top5
        if epoch % 10 == 0:
            acc = eval(config, model)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                optimizer,
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls")
|
||
|
|
||
|
|
||
|
def train_distill(config, scaler=None):
    """Train a two-student distillation model with a single shared optimizer.

    Loss is the sum of both students' distill losses plus the DML (mutual
    learning) term between the two student branches. Metrics are computed on
    the ``student`` branch only.

    Args:
        config: parsed configuration dict (same shape as for ``train``).
        scaler: optional ``paddle.amp.GradScaler`` enabling the AMP path.

    NOTE(review): uses the module-level ``logger`` created in ``__main__``.
    """
    EPOCH = config['epoch']
    topk = config['topk']  # NOTE(review): read but never used in this function

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # pact quant train
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # build_optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metric of loaded metric as follows {}'.format(', '.join(
            ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    model.train()
    model = paddle.DataParallel(model)

    # build loss function
    loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
    loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
    loss_func_js = KLJSLoss(mode='js')  # NOTE(review): created but never used

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            # assumes loader yields NHWC images; model expects NCHW — TODO confirm
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)
            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # cal metric (first student branch only)
            acc = metric_func(outs['student'], label)

            # cal loss: both students' distill losses plus the mutual DML term
            avg_loss = loss_func_distill(outs, label)['student'] + \
                loss_func_distill(outs, label)['student1'] + \
                loss_func_dml(outs, label)['student_student1']

            # backward
            if scaler is None:
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            # float lr_scheduler means a constant learning rate — no step
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        # evaluate the bare student sub-network (unwrap DataParallel) and
        # checkpoint every 10 epochs; best tracked by top5
        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )

            save_model(
                model,
                optimizer,
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls_distill")
|
||
|
|
||
|
|
||
|
def train_distill_multiopt(config, scaler=None):
    """Train a two-student distillation model with one optimizer per student.

    Each student has its own optimizer/scheduler over its own parameters;
    the first backward uses ``retain_graph=True`` because the DML term is
    shared between both losses.

    Args:
        config: parsed configuration dict (same shape as for ``train``).
        scaler: optional ``paddle.amp.GradScaler`` enabling the AMP path.

    NOTE(review): uses the module-level ``logger`` created in ``__main__``.
    """
    EPOCH = config['epoch']
    topk = config['topk']  # NOTE(review): read but never used in this function

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build_optimizer: one optimizer/scheduler per student branch
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.student.parameters())
    optimizer1, lr_scheduler1 = create_optimizer(
        config, parameter_list=model.student1.parameters())

    # load model
    # NOTE(review): only the first optimizer's state is restored — confirm
    # whether optimizer1 should be resumed as well.
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metric of loaded metric as follows {}'.format(', '.join(
            ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    # quant train
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    model.train()

    model = paddle.DataParallel(model)

    # build loss function
    loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
    loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
    loss_func_js = KLJSLoss(mode='js')  # NOTE(review): created but never used

    data_num = len(train_loader)
    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            # assumes loader yields NHWC images; model expects NCHW — TODO confirm
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # cal metric (first student branch only)
            acc = metric_func(outs['student'], label)

            # cal loss: per-student distill loss plus the shared DML term
            avg_loss = loss_func_distill(outs,
                                         label)['student'] + loss_func_dml(
                                             outs, label)['student_student1']
            avg_loss1 = loss_func_distill(outs,
                                          label)['student1'] + loss_func_dml(
                                              outs, label)['student_student1']

            if scaler is None:
                # backward: retain the graph so avg_loss1 can reuse the
                # shared DML sub-graph afterwards
                avg_loss.backward(retain_graph=True)
                optimizer.step()
                optimizer.clear_grad()

                avg_loss1.backward()
                optimizer1.step()
                optimizer1.clear_grad()
            else:
                # NOTE(review): unlike the plain path, the first backward here
                # has no retain_graph=True even though the second backward
                # reuses the shared DML sub-graph — confirm this is intended.
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

                scaled_avg_loss = scaler.scale(avg_loss1)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer1, scaled_avg_loss)

            # float scheduler means a constant learning rate — no step
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()
            if not isinstance(lr_scheduler1, float):
                lr_scheduler1.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}, loss1: {avg_loss1.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        # evaluate the bare student sub-network (unwrap DataParallel) and
        # checkpoint every 10 epochs; best tracked by top5
        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model, [optimizer, optimizer1],
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls_distill_multiopt")
|
||
|
|
||
|
|
||
|
def eval(config, model):
    """Run ``model`` over the test split and return the metric dict.

    Accumulates all batch outputs and labels, concatenates them, and computes
    top-1/top-5 accuracy over the whole split. Shadows the ``eval`` builtin,
    matching the existing callers in this file.
    """
    valid_loader = build_dataloader(
        'test',
        batch_size=config['VALID']['batch_size'],
        num_workers=config['VALID']['num_workers'])

    # build metric
    metric_func = create_metric

    collected_outs = []
    collected_labels = []
    for batch in valid_loader:
        img_batch, label = batch
        img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
        collected_outs.append(model(img_batch))
        collected_labels.append(paddle.unsqueeze(label, -1))

    acc = metric_func(
        paddle.concat(collected_outs, axis=0),
        paddle.concat(collected_labels, axis=0))

    logger.info(
        f"The metric are as follows: acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
    )
    return acc
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":

    # NOTE(review): is_train=False looks odd for a training entry point —
    # confirm against config.preprocess.
    config, logger = preprocess(is_train=False)

    # AMP scaler (None when AMP is disabled or absent in the config)
    scaler = amp_scaler(config)

    model_type = config['model_type']

    # Fix: the scaler was built but never handed to the train functions,
    # which silently disabled AMP even when the config enabled it.
    if model_type == "cls":
        train(config, scaler)
    elif model_type == "cls_distill":
        train_distill(config, scaler)
    elif model_type == "cls_distill_multiopt":
        train_distill_multiopt(config, scaler)
    else:
        # Fix: the original message listed [''] instead of the valid options.
        raise ValueError(
            "model_type should be one of "
            "['cls', 'cls_distill', 'cls_distill_multiopt']")
|