fix slim bugs
parent
eafcc86457
commit
3703f63f04
|
@ -1,7 +1,7 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
# pretrained_model: ./output/ResNet50_vd/epoch_29
|
||||
pretrained_model: ./output/ResNet50_vd/best_model
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
|
@ -15,19 +15,16 @@ Global:
|
|||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
|
||||
# for paddleslim
|
||||
# for quantalization or prune model
|
||||
Slim:
|
||||
# for quantalization
|
||||
# quant:
|
||||
# name: pact
|
||||
## for prune
|
||||
prune:
|
||||
name: fpgm
|
||||
pruned_ratio: 0.3
|
||||
name: fpgm
|
||||
pruned_ratio: 0.3
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: MobileNetV3_large_x1_0
|
||||
name: ResNet50_vd
|
||||
class_num: 1000
|
||||
|
||||
# loss function config for traing/eval process
|
||||
|
@ -58,7 +55,7 @@ DataLoader:
|
|||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/train.txt
|
||||
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
|
@ -89,7 +86,7 @@ DataLoader:
|
|||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/val.txt
|
||||
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
|
@ -0,0 +1,137 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: ./output/ResNet50_vd/best_model
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 30
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
|
||||
# for quantalization or prune model
|
||||
Slim:
|
||||
## for quantalization
|
||||
quant:
|
||||
name: pact
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: ResNet50_vd
|
||||
class_num: 1000
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- MixCELoss:
|
||||
weight: 1.0
|
||||
epsilon: 0.1
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.1
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.00007
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
batch_transform_ops:
|
||||
- MixupOperator:
|
||||
alpha: 0.2
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/ILSVRC2012/
|
||||
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: docs/images/whl/demo.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
|
@ -0,0 +1,106 @@
|
|||
|
||||
## 介绍
|
||||
复杂的模型有利于提高模型的性能,但也导致模型中存在一定冗余,模型量化将全精度缩减到定点数减少这种冗余,达到减少模型计算复杂度,提高模型推理性能的目的。
|
||||
模型量化可以在基本不损失模型的精度的情况下,将FP32精度的模型参数转换为Int8精度,减小模型参数大小并加速计算,使用量化后的模型在移动端等部署时更具备速度优势。
|
||||
|
||||
本教程将介绍如何使用飞桨模型压缩库PaddleSlim做PaddleClas模型的压缩。
|
||||
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) 集成了模型剪枝、量化(包括量化训练和离线量化)、蒸馏和神经网络搜索等多种业界常用且领先的模型压缩功能,如果您感兴趣,可以关注并了解。
|
||||
|
||||
在开始本教程之前,建议先了解[PaddleClas模型的训练方法](../../../docs/zh_CN/tutorials/quick_start.md)以及[PaddleSlim](https://paddleslim.readthedocs.io/zh_CN/latest/index.html)
|
||||
|
||||
|
||||
## 快速开始
|
||||
量化多适用于轻量模型在移动端的部署,当训练出一个模型后,如果希望进一步的压缩模型大小并加速预测,可使用量化的方法压缩模型。
|
||||
|
||||
模型量化主要包括五个步骤:
|
||||
1. 安装 PaddleSlim
|
||||
2. 准备训练好的模型
|
||||
3. 量化训练
|
||||
4. 导出量化推理模型
|
||||
5. 量化模型预测部署
|
||||
|
||||
### 1. 安装PaddleSlim
|
||||
|
||||
* 可以通过pip install的方式进行安装。
|
||||
|
||||
```bash
|
||||
pip3.7 install paddleslim==2.0.0
|
||||
```
|
||||
|
||||
* 如果获取PaddleSlim的最新特性,可以从源码安装。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/PaddlePaddle/PaddleSlim.git
|
||||
cd Paddleslim
|
||||
python3.7 setup.py install
|
||||
```
|
||||
|
||||
### 2. 准备训练好的模型
|
||||
|
||||
PaddleClas提供了一系列训练好的[模型](../../../docs/zh_CN/models/models_intro.md),如果待量化的模型不在列表中,需要按照[常规训练](../../../docs/zh_CN/tutorials/getting_started.md)方法得到训练好的模型。
|
||||
|
||||
### 3. 量化训练
|
||||
量化训练包括离线量化训练和在线量化训练,在线量化训练效果更好,需加载预训练模型,在定义好量化策略后即可对模型进行量化。
|
||||
|
||||
|
||||
量化训练的代码位于`deploy/slim/quant/quant.py` 中,训练指令如下:
|
||||
|
||||
* CPU/单机单卡启动
|
||||
|
||||
```bash
|
||||
python3.7 deploy/slim/quant/quant.py \
|
||||
-c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
|
||||
-o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
|
||||
```
|
||||
|
||||
* 单机单卡/单机多卡/多机多卡启动
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
python3.7 -m paddle.distributed.launch \
|
||||
--gpus="0,1,2,3,4,5,6,7" \
|
||||
deploy/slim/quant/quant.py \
|
||||
-c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
|
||||
-o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
|
||||
```
|
||||
|
||||
|
||||
* 下面是量化`MobileNetV3_large_x1_0`模型的训练示例脚本。
|
||||
|
||||
```bash
|
||||
# 下载预训练模型
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x1_0_pretrained.pdparams
|
||||
# 启动训练,这里如果因为显存限制,batch size无法设置过大,可以将batch size和learning rate同比例缩小。
|
||||
python3.7 -m paddle.distributed.launch \
|
||||
--gpus="0,1,2,3,4,5,6,7" \
|
||||
deploy/slim/quant/quant.py \
|
||||
-c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
|
||||
-o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
|
||||
-o LEARNING_RATE.params.lr=0.13 \
|
||||
-o epochs=100
|
||||
```
|
||||
|
||||
### 4. 导出模型
|
||||
|
||||
在得到量化训练保存的模型后,可以将其导出为inference model,用于预测部署:
|
||||
|
||||
```bash
|
||||
python3.7 deploy/slim/quant/export_model.py \
|
||||
-m MobileNetV3_large_x1_0 \
|
||||
-p output/MobileNetV3_large_x1_0/best_model/ppcls \
|
||||
-o ./MobileNetV3_large_x1_0_infer/ \
|
||||
--img_size=224 \
|
||||
--class_dim=1000
|
||||
```
|
||||
|
||||
|
||||
### 5. 量化模型部署
|
||||
|
||||
上述步骤导出的量化模型,参数精度仍然是FP32,但是参数的数值范围是int8,导出的模型可以通过PaddleLite的opt模型转换工具完成模型转换。
|
||||
量化模型部署的可参考 [移动端模型部署](../../lite/readme.md)
|
||||
|
||||
|
||||
## 量化训练超参数建议
|
||||
|
||||
* 量化训练时,建议加载常规训练得到的预训练模型,加速量化训练收敛。
|
||||
* 量化训练时,建议初始学习率修改为常规训练的`1/20~1/10`,同时将训练epoch数修改为常规训练的`1/5~1/2`,学习率策略方面,加上Warmup,其他配置信息不建议修改。
|
|
@ -0,0 +1,112 @@
|
|||
|
||||
## Introduction
|
||||
|
||||
Generally, a more complex model would achive better performance in the task, but it also leads to some redundancy in the model.
|
||||
Quantization is a technique that reduces this redundancy by reducing the full precision data to a fixed number,
|
||||
so as to reduce model calculation complexity and improve model inference performance.
|
||||
|
||||
This example uses PaddleSlim provided [APIs of Quantization](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/) to compress the PaddleClas models.
|
||||
|
||||
It is recommended that you could understand following pages before reading this example:
|
||||
- [The training strategy of PaddleClas models](../../../docs/en/tutorials/quick_start_en.md)
|
||||
- [PaddleSlim Document](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)
|
||||
|
||||
## Quick Start
|
||||
Quantization is mostly suitable for the deployment of lightweight models on mobile terminals.
|
||||
After training, if you want to further compress the model size and accelerate the prediction, you can use quantization methods to compress the model according to the following steps.
|
||||
|
||||
1. Install PaddleSlim
|
||||
2. Prepare trained model
|
||||
3. Quantization-Aware Training
|
||||
4. Export inference model
|
||||
5. Deploy quantization inference model
|
||||
|
||||
|
||||
### 1. Install PaddleSlim
|
||||
|
||||
* Install by pip.
|
||||
|
||||
```bash
|
||||
pip3.7 install paddleslim==2.0.0
|
||||
```
|
||||
|
||||
* Install from source code to get the lastest features.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/PaddlePaddle/PaddleSlim.git
|
||||
cd Paddleslim
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
|
||||
### 2. Download Pretrain Model
|
||||
PaddleClas provides a series of trained [models](../../../docs/en/models/models_intro_en.md).
|
||||
If the model to be quantified is not in the list, you need to follow the [Regular Training](../../../docs/en/tutorials/getting_started_en.md) method to get the trained model.
|
||||
|
||||
|
||||
### 3. Quant-Aware Training
|
||||
Quantization training includes offline quantization training and online quantization training.
|
||||
Online quantization training is more effective. It is necessary to load the pre-trained model.
|
||||
After the quantization strategy is defined, the model can be quantified.
|
||||
|
||||
The code for quantization training is located in `deploy/slim/quant/quant.py`. The training command is as follow:
|
||||
|
||||
* CPU/Single GPU training
|
||||
|
||||
```bash
|
||||
python3.7 deploy/slim/quant/quant.py \
|
||||
-c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
|
||||
-o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
|
||||
```
|
||||
|
||||
* Distributed training
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
python3.7 -m paddle.distributed.launch \
|
||||
--gpus="0,1,2,3,4,5,6,7" \
|
||||
deploy/slim/quant/quant.py \
|
||||
-c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
|
||||
-o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
|
||||
```
|
||||
|
||||
* The command of quantizing `MobileNetV3_large_x1_0` model is as follow:
|
||||
|
||||
```bash
|
||||
# download pre-trained model
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x1_0_pretrained.pdparams
|
||||
|
||||
# run training
|
||||
python3.7 -m paddle.distributed.launch \
|
||||
--gpus="0,1,2,3,4,5,6,7" \
|
||||
deploy/slim/quant/quant.py \
|
||||
-c configs/MobileNetV3/MobileNetV3_large_x1_0.yaml \
|
||||
-o pretrained_model="./MobileNetV3_large_x1_0_pretrained"
|
||||
-o LEARNING_RATE.params.lr=0.13 \
|
||||
-o epochs=100
|
||||
```
|
||||
|
||||
|
||||
### 4. Export inference model
|
||||
|
||||
After getting the model quantization aware trained, we can export it as inference model for predictive deployment:
|
||||
|
||||
```bash
|
||||
python3.7 deploy/slim/quant/export_model.py \
|
||||
-m MobileNetV3_large_x1_0 \
|
||||
-p output/MobileNetV3_large_x1_0/best_model/ppcls \
|
||||
-o ./MobileNetV3_large_x1_0_infer/ \
|
||||
--img_size=224 \
|
||||
--class_dim=1000
|
||||
```
|
||||
|
||||
### 5. Deploy
|
||||
The type of quantized model's parameters derived from the above steps is still FP32, but the numerical range of the parameters is int8.
|
||||
The derived model can be converted through the `opt tool` of PaddleLite.
|
||||
|
||||
For quantitative model deployment, please refer to [Mobile terminal model deployment](../../lite/readme_en.md)
|
||||
|
||||
## Notes:
|
||||
|
||||
* In quantitative training, it is suggested to load the pre-trained model obtained from conventional training to accelerate the convergence of quantitative training.
|
||||
* In quantitative training, it is suggested that the initial learning rate should be changed to `1 / 20 ~ 1 / 10` of the conventional training, and the training epoch number should be changed to `1 / 5 ~ 1 / 2` of the conventional training. In terms of learning rate strategy, it's better to train with warmup, other configuration information is not recommended to be changed.
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddleslim
|
||||
from paddle.jit import to_static
|
||||
from paddleslim.analysis import dygraph_flops as flops
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
|
||||
from paddleslim.dygraph.quant import QAT
|
||||
|
||||
from ppcls.data import build_dataloader
|
||||
from ppcls.utils import config as conf
|
||||
from ppcls.utils.logger import init_logger
|
||||
|
||||
|
||||
def main():
|
||||
args = conf.parse_args()
|
||||
config = conf.get_config(args.config, overrides=args.override, show=False)
|
||||
|
||||
assert os.path.exists(
|
||||
os.path.join(config["Global"]["save_inference_dir"],
|
||||
'inference.pdmodel')) and os.path.exists(
|
||||
os.path.join(config["Global"]["save_inference_dir"],
|
||||
'inference.pdiparams'))
|
||||
config["DataLoader"]["Train"]["sampler"]["batch_size"] = 1
|
||||
config["DataLoader"]["Train"]["loader"]["num_workers"] = 0
|
||||
init_logger()
|
||||
device = paddle.set_device("cpu")
|
||||
train_dataloader = build_dataloader(config["DataLoader"], "Train", device,
|
||||
False)
|
||||
|
||||
def sample_generator(loader):
|
||||
def __reader__():
|
||||
for indx, data in enumerate(loader):
|
||||
images = np.array(data[0])
|
||||
yield images
|
||||
|
||||
return __reader__
|
||||
|
||||
paddle.enable_static()
|
||||
place = paddle.CPUPlace()
|
||||
exe = paddle.static.Executor(place)
|
||||
paddleslim.quant.quant_post_static(
|
||||
executor=exe,
|
||||
model_dir=config["Global"]["save_inference_dir"],
|
||||
model_filename='inference.pdmodel',
|
||||
params_filename='inference.pdiparams',
|
||||
quantize_model_path=os.path.join(
|
||||
config["Global"]["save_inference_dir"], "quant_post_static_model"),
|
||||
sample_generator=sample_generator(train_dataloader),
|
||||
batch_nums=5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -18,9 +18,11 @@ import os
|
|||
import sys
|
||||
|
||||
import paddle
|
||||
import numpy as np
|
||||
import paddleslim
|
||||
from paddle.jit import to_static
|
||||
from paddleslim.analysis import dygraph_flops as flops
|
||||
import argparse
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
|
||||
|
@ -29,6 +31,7 @@ from paddleslim.dygraph.quant import QAT
|
|||
from ppcls.engine.trainer import Trainer
|
||||
from ppcls.utils import config, logger
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||
from ppcls.data import build_dataloader
|
||||
|
||||
quant_config = {
|
||||
# weight preprocess type, default is None and no preprocessing is performed.
|
||||
|
@ -79,7 +82,7 @@ class Trainer_slim(Trainer):
|
|||
else:
|
||||
logger.info("FLOPs before pruning: {}GFLOPs".format(
|
||||
flops(self.model, [1] + self.config["Global"][
|
||||
"image_shape"]) / 1000000))
|
||||
"image_shape"]) / 1e9))
|
||||
self.model.eval()
|
||||
|
||||
if prune_config["name"].lower() == "fpgm":
|
||||
|
@ -96,11 +99,6 @@ class Trainer_slim(Trainer):
|
|||
if self.quanter is None and self.pruner is None:
|
||||
logger.info("Training without slim")
|
||||
|
||||
def train(self):
|
||||
super().train()
|
||||
if self.config["Global"].get("save_inference_dir", None):
|
||||
self.export_inference_model()
|
||||
|
||||
def export_inference_model(self):
|
||||
if os.path.exists(
|
||||
os.path.join(self.output_dir, self.config["Arch"]["name"],
|
||||
|
@ -153,7 +151,7 @@ class Trainer_slim(Trainer):
|
|||
|
||||
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
|
||||
flops(self.model, [1] + self.config["Global"]["image_shape"]) /
|
||||
1000000, plan.pruned_flops))
|
||||
1e9, plan.pruned_flops))
|
||||
|
||||
for param in self.model.parameters():
|
||||
if "conv2d" in param.name:
|
||||
|
@ -162,9 +160,46 @@ class Trainer_slim(Trainer):
|
|||
self.model.train()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
"generic-image-rec slim script, for train, eval and export inference model"
|
||||
)
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--config',
|
||||
type=str,
|
||||
default='configs/config.yaml',
|
||||
help='config file path')
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--override',
|
||||
action='append',
|
||||
default=[],
|
||||
help='config options to be overridden')
|
||||
parser.add_argument(
|
||||
'-m',
|
||||
'--mode',
|
||||
type=str,
|
||||
default='train',
|
||||
choices=['train', 'eval', 'infer', 'export'],
|
||||
help='the different function')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = config.parse_args()
|
||||
args = parse_args()
|
||||
config = config.get_config(
|
||||
args.config, overrides=args.override, show=False)
|
||||
trainer = Trainer_slim(config, mode="train")
|
||||
trainer.train()
|
||||
if args.mode == 'train':
|
||||
trainer = Trainer_slim(config, mode="train")
|
||||
trainer.train()
|
||||
elif args.mode == 'eval':
|
||||
trainer = Trainer_slim(config, mode="eval")
|
||||
trainer.eval()
|
||||
elif args.mode == 'infer':
|
||||
trainer = Trainer_slim(config, mode="infer")
|
||||
trainer.infer()
|
||||
else:
|
||||
trainer = Trainer_slim(config, mode="train")
|
||||
trainer.export_inference_model()
|
Loading…
Reference in New Issue