commit 1ef7dbc93b

@@ -24,6 +24,7 @@
 ```bash
 git clone https://github.com/PaddlePaddle/PaddleSlim.git
 cd PaddleSlim
+git checkout develop
 python3 setup.py install
 ```

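A quick way to confirm the `develop` install worked is to import the package. This is a minimal sanity check, not part of the commit; it only assumes PaddleSlim exposes `__version__` and the dygraph `QAT` class used by the quantization script further down.

```python
# Minimal sanity check (not part of the commit): verify that the develop
# branch of PaddleSlim installed and that the dygraph QAT API is importable.
import paddleslim
from paddleslim.dygraph.quant import QAT

print(paddleslim.__version__)
```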
@@ -48,14 +49,14 @@ python3 setup.py install

 Enter the PaddleOCR root directory and run sensitivity analysis training on the model with the following command:
 ```bash
-python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model="your trained model"
+python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model="your trained model" Global.save_model_dir=./output/prune_model/
 ```

 ### 4. Export the model and deploy it for inference

 After obtaining the model saved by pruning training, we can export it as an inference_model:
 ```bash
-python3.7 deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./output/det_db/best_accuracy Global.save_inference_dir=inference_model
+python3.7 deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./output/det_db/best_accuracy Global.save_inference_dir=./prune/prune_inference_model
 ```

 For prediction and deployment of the inference model, refer to:

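For context, `sensitivity_anal.py` drives PaddleSlim's filter-pruning sensitivity analysis. The sketch below shows the underlying dygraph API under stated assumptions: the model and `eval_fn` are stand-ins, since the real script builds the DB detector and its validation metric from the `-c`/`-o` options.

```python
import paddle
from paddleslim.dygraph import FPGMFilterPruner

# Stand-in network; sensitivity_anal.py builds the DB detector from the config.
model = paddle.vision.models.mobilenet_v1(num_classes=10)

def eval_fn():
    # Hypothetical stub: the real eval function runs validation and
    # returns the detection metric (hmean).
    return 0.5

# Analyze filter importance at the detector's input resolution.
pruner = FPGMFilterPruner(model, [1, 3, 640, 640])

# Measure how the metric degrades as each conv layer is pruned; results
# are cached in sen_file so repeated runs skip the analysis.
pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")

# Prune to a target FLOPs reduction, guided by the measured sensitivities.
plan = pruner.sensitive_prune(pruned_flops=0.2)
```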
@@ -54,7 +54,7 @@ Enter the PaddleOCR root directory,perform sensitivity analysis on the model w

 ```bash

-python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model="your trained model"
+python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model="your trained model" Global.save_model_dir=./output/prune_model/

 ```

@@ -63,7 +63,7 @@ python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_

 We can export the pruned model as inference_model for deployment:
 ```bash
-python deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./output/det_db/best_accuracy Global.save_inference_dir=inference_model
+python deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./output/det_db/best_accuracy Global.save_inference_dir=./prune/prune_inference_model
 ```

 Reference for prediction and deployment of inference model:

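Once exported, the pruned model loads through the standard Paddle Inference API. A minimal loading sketch, assuming the default export file names (`inference.pdmodel` / `inference.pdiparams`) inside the directory created by the export command above:

```python
from paddle.inference import Config, create_predictor

# Paths assume the Global.save_inference_dir used in the export command.
config = Config("./prune/prune_inference_model/inference.pdmodel",
                "./prune/prune_inference_model/inference.pdiparams")
config.disable_gpu()  # CPU inference; use enable_use_gpu(200, 0) for GPU

predictor = create_predictor(config)
print(predictor.get_input_names())
```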
@@ -112,10 +112,6 @@ def main(config, device, logger, vdl_writer):
         config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])

-    # prepare to quant
-    quanter = QAT(config=quant_config, act_preprocess=PACT)
-    quanter.quantize(model)
-
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)

@@ -136,31 +132,15 @@ def main(config, device, logger, vdl_writer):

     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                 format(len(train_dataloader), len(valid_dataloader)))
+    quanter = QAT(config=quant_config, act_preprocess=PACT)
+    quanter.quantize(model)

     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
                   eval_class, pre_best_model_dict, logger, vdl_writer)


-def test_reader(config, device, logger):
-    loader = build_dataloader(config, 'Train', device, logger)
-    import time
-    starttime = time.time()
-    count = 0
-    try:
-        for data in loader():
-            count += 1
-            if count % 1 == 0:
-                batch_time = time.time() - starttime
-                starttime = time.time()
-                logger.info("reader: {}, {}, {}".format(
-                    count, len(data[0]), batch_time))
-    except Exception as e:
-        logger.info(e)
-    logger.info("finish reader: {}, Success!".format(count))
-
 if __name__ == '__main__':
     config, device, logger, vdl_writer = program.preprocess(is_train=True)
     main(config, device, logger, vdl_writer)
-    # test_reader(config, device, logger)
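The diff above moves the QAT wrapping in the quantization training script from right after `build_model` to just before training starts, and drops the unused `test_reader` debugging helper. For reference, the end-to-end dygraph QAT flow looks roughly like this; `quant_config` mirrors the dict defined at the top of the script, the model is a stand-in, and the script additionally passes its locally defined `PACT` activation-clipping layer via `act_preprocess`.

```python
import paddle
from paddleslim.dygraph.quant import QAT

# Mirrors the quant_config defined at the top of the training script.
quant_config = {
    'weight_preprocess_type': None,      # no weight preprocessing
    'activation_preprocess_type': None,  # the script uses act_preprocess=PACT instead
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'weight_bits': 8,
    'activation_bits': 8,
    'dtype': 'int8',
    'window_size': 10000,                # for 'range_abs_max' quantization
    'moving_rate': 0.9,                  # decay of the moving average
    'quantizable_layer_type': ['Conv2D', 'Linear'],
}

model = paddle.vision.models.mobilenet_v1(num_classes=10)  # stand-in model

# Insert fake-quant ops in place; training then proceeds as usual.
quanter = QAT(config=quant_config)
quanter.quantize(model)

# ... run quantization-aware training here ...

# After training, export the quantized inference model.
quanter.save_quantized_model(
    model,
    './quant_inference_model/inference',
    input_spec=[paddle.static.InputSpec([None, 3, 224, 224], 'float32')])
```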