add slim support

pull/1093/head
dongshuilong 2021-07-26 12:31:18 +00:00
parent 4452565bdf
commit eafcc86457
2 changed files with 83 additions and 20 deletions


@@ -18,12 +18,12 @@ Global:
# for paddleslim
Slim:
  # for quantization
  quant:
    name: pact
  # quant:
  #   name: pact
  ## for pruning
  # prune:
  #   name: fpgm
  #   prune_ratio: 0.3
  prune:
    name: fpgm
    pruned_ratio: 0.3

# model architecture
Arch:
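Note: the Slim section above is exactly what the trainer below inspects via self.config["Slim"].get("quant", ...) and .get("prune", ...). A minimal sketch of that selection logic (not part of this commit; the plain dict simply stands in for the loaded config):

    # Sketch only: decide which slim method a Slim section like the one above enables.
    config = {
        "Slim": {
            "prune": {"name": "fpgm", "pruned_ratio": 0.3},
            # "quant": {"name": "pact"},
        }
    }

    slim_cfg = config.get("Slim") or {}
    quant_cfg = slim_cfg.get("quant", False)
    prune_cfg = slim_cfg.get("prune", False)

    if quant_cfg:
        print("quantization enabled:", quant_cfg.get("name"))
    elif prune_cfg:
        print("pruning enabled: {name} (ratio {pruned_ratio})".format(**prune_cfg))
    else:
        print("no slim method configured")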
@@ -58,7 +58,7 @@ DataLoader:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      cls_label_path: ./dataset/ILSVRC2012/train.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
@@ -89,7 +89,7 @@ DataLoader:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      cls_label_path: ./dataset/ILSVRC2012/val.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
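Side note: before training it can help to confirm that the image_root and cls_label_path values above actually resolve on disk. A small, hypothetical helper (not part of PaddleClas) that only checks existence and counts non-empty lines, making no assumption about the per-line label format:

    import os

    def check_dataset(image_root, cls_label_path):
        # Fail early if the configured dataset locations are missing.
        assert os.path.isdir(image_root), "image_root not found: " + image_root
        assert os.path.isfile(cls_label_path), "label list not found: " + cls_label_path
        with open(cls_label_path) as f:
            num_samples = sum(1 for line in f if line.strip())
        print("{} samples listed in {}".format(num_samples, cls_label_path))

    check_dataset("./dataset/ILSVRC2012/", "./dataset/ILSVRC2012/val_list.txt")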


@@ -18,10 +18,14 @@ import os
import sys
import paddle
import paddleslim
from paddle.jit import to_static
from paddleslim.analysis import dygraph_flops as flops
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
from paddleslim.dygraph.quant import QAT
from ppcls.engine.trainer import Trainer
from ppcls.utils import config, logger
from ppcls.utils.save_load import load_dygraph_pretrain
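The dygraph_flops import added here is what the trainer uses below to report FLOPs before and after pruning. A standalone sketch of that call (assuming paddle and paddleslim are installed; resnet18 is only a stand-in model):

    import paddle
    from paddleslim.analysis import dygraph_flops as flops

    # Measure FLOPs the same way the trainer does: flops(model, [batch] + image_shape).
    model = paddle.vision.models.resnet18(pretrained=False)
    model.eval()
    print("FLOPs: {}".format(flops(model, [1, 3, 224, 224])))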
@@ -53,9 +57,12 @@ quant_config = {
class Trainer_slim(Trainer):
    def __init__(self, config, mode="train"):
        super().__init__(config, mode)
        # self.pact = self.config["Slim"].get("pact", False)
        self.pact = True
        if self.pact:
        pact = self.config["Slim"].get("quant", False)
        self.pact = pact.get("name", False) if pact else pact
        if self.pact and str(self.pact).lower() != 'pact':
            raise RuntimeError("Only 'PACT' quantization is supported!")
        if pact:
            quant_config["activation_preprocess_type"] = "PACT"
            self.quanter = QAT(config=quant_config)
            self.quanter.quantize(self.model)
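For context, a self-contained sketch of the QAT flow used above (assumptions: paddleslim is installed, mobilenet_v1 is only a stand-in model, and the config keys below follow paddleslim's QAT defaults as far as I know; the authoritative set is the quant_config dict defined earlier in this file):

    import paddle
    from paddleslim.dygraph.quant import QAT

    # Illustrative quant config, not copied from the commit.
    quant_config = {
        "weight_quantize_type": "channel_wise_abs_max",
        "activation_quantize_type": "moving_average_abs_max",
        "weight_bits": 8,
        "activation_bits": 8,
        # Same switch the trainer flips when the Slim config asks for PACT.
        "activation_preprocess_type": "PACT",
    }

    model = paddle.vision.models.mobilenet_v1(pretrained=False)
    quanter = QAT(config=quant_config)
    quanter.quantize(model)  # wraps quantizable layers in place; training then proceeds as usual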
@@ -64,6 +71,31 @@ class Trainer_slim(Trainer):
        else:
            self.quanter = None

        prune_config = self.config["Slim"].get("prune", False)
        if prune_config:
            if prune_config["name"].lower() not in ["fpgm", "l1_norm"]:
                raise RuntimeError(
                    "Only 'fpgm' and 'l1_norm' pruning methods are supported")
            else:
                logger.info("FLOPs before pruning: {}GFLOPs".format(
                    flops(self.model, [1] + self.config["Global"][
                        "image_shape"]) / 1000000))
                self.model.eval()
                if prune_config["name"].lower() == "fpgm":
                    self.pruner = paddleslim.dygraph.FPGMFilterPruner(
                        self.model, [1] + self.config["Global"]["image_shape"])
                else:
                    self.pruner = paddleslim.dygraph.L1NormFilterPruner(
                        self.model, [1] + self.config["Global"]["image_shape"])
                self.prune_model()
        else:
            self.pruner = None

        if self.quanter is None and self.pruner is None:
            logger.info("Training without slim")

    def train(self):
        super().train()
        if self.config["Global"].get("save_inference_dir", None):
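A standalone sketch of the FPGM pruning path set up above (assumptions: paddleslim installed, resnet18 as a stand-in model, a flat 0.3 ratio on every Conv2D weight; the pruner's own dependency analysis may skip some layers):

    import paddle
    import paddleslim
    from paddleslim.analysis import dygraph_flops as flops

    model = paddle.vision.models.resnet18(pretrained=False)
    model.eval()
    print("FLOPs before pruning: {}".format(flops(model, [1, 3, 224, 224])))

    # Build the pruner against a dummy input shape, then prune each Conv2D weight
    # along axis 0 (output channels), mirroring prune_model() below.
    pruner = paddleslim.dygraph.FPGMFilterPruner(model, [1, 3, 224, 224])
    ratios = {
        layer.weight.name: 0.3
        for layer in model.sublayers() if isinstance(layer, paddle.nn.Conv2D)
    }
    plan = pruner.prune_vars(ratios, [0])
    print("FLOPs after pruning: {} (pruned ratio {})".format(
        flops(model, [1, 3, 224, 224]), plan.pruned_flops))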
@@ -86,17 +118,48 @@ class Trainer_slim(Trainer):
            raise RuntimeError(
                "The best_model or pretrained_model should exist to generate inference model"
            )

        save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                 "inference")
        if self.quanter:
            self.quanter.save_quantized_model(
                self.model,
                save_path,
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None] + self.config["Global"]["image_shape"],
                        dtype='float32')
                ])
        else:
            model = to_static(
                self.model,
                input_spec=[
                    paddle.static.InputSpec(
                        shape=[None] + self.config["Global"]["image_shape"],
                        dtype='float32',
                        name="image")
                ])
            paddle.jit.save(model, save_path)

        assert self.quanter
        self.quanter.save_quantized_model(
            self.model,
            os.path.join(self.config["Global"]["save_inference_dir"],
                         "inference"),
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + self.config["Global"]["image_shape"],
                    dtype='float32')
            ])

    def prune_model(self):
        params = []
        for sublayer in self.model.sublayers():
            for param in sublayer.parameters(include_sublayers=False):
                if isinstance(sublayer, paddle.nn.Conv2D):
                    params.append(param.name)

        ratios = {}
        for param in params:
            ratios[param] = self.config["Slim"]["prune"]["pruned_ratio"]
        plan = self.pruner.prune_vars(ratios, [0])
        logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
            flops(self.model, [1] + self.config["Global"]["image_shape"]) /
            1000000, plan.pruned_flops))
        for param in self.model.parameters():
            if "conv2d" in param.name:
                logger.info("{}\t{}".format(param.name, param.shape))
        self.model.train()
if __name__ == "__main__":
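To round out the export path shown in this file, a hedged sketch of the non-quantized branch in isolation (assumptions: resnet18 as a stand-in model and ./inference_output as a made-up output directory; a quantized model would go through quanter.save_quantized_model instead, as in the if self.quanter: branch above):

    import os
    import paddle
    from paddle.jit import to_static

    model = paddle.vision.models.resnet18(pretrained=False)
    model.eval()

    # Convert to a static graph with a named image input, then save the inference.* files.
    static_model = to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, 3, 224, 224], dtype="float32", name="image")
        ])
    paddle.jit.save(static_model, os.path.join("./inference_output", "inference"))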