add slim support
parent 4452565bdf
commit eafcc86457
@@ -18,12 +18,12 @@ Global:
 # for paddleslim
 Slim:
   # for quantalization
-  quant:
-    name: pact
+  # quant:
+  #   name: pact
   ## for prune
-  #prune:
-  #  name: fpgm
-  #  prune_ratio: 0.3
+  prune:
+    name: fpgm
+    pruned_ratio: 0.3

 # model architecture
 Arch:
@@ -58,7 +58,7 @@ DataLoader:
     dataset:
       name: ImageNetDataset
       image_root: ./dataset/ILSVRC2012/
-      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      cls_label_path: ./dataset/ILSVRC2012/train.txt
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -89,7 +89,7 @@ DataLoader:
     dataset:
       name: ImageNetDataset
       image_root: ./dataset/ILSVRC2012/
-      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      cls_label_path: ./dataset/ILSVRC2012/val.txt
      transform_ops:
         - DecodeImage:
             to_rgb: True
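In the first hunk above, the Slim block acts as a switch between the two compression strategies: the commit comments out the PACT quant entry and activates fpgm pruning with pruned_ratio: 0.3. A minimal sketch of how such a block can be interpreted once the YAML is loaded into a nested dict (the file name config.yaml, yaml.safe_load and the print statements are illustrative assumptions, not part of this commit):

import yaml

# Load the training config; "config.yaml" is a placeholder for the file patched above.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

slim_cfg = cfg.get("Slim") or {}
quant_cfg = slim_cfg.get("quant")    # e.g. {"name": "pact"} when quantization is enabled
prune_cfg = slim_cfg.get("prune")    # e.g. {"name": "fpgm", "pruned_ratio": 0.3}

if quant_cfg and str(quant_cfg.get("name", "")).lower() == "pact":
    print("PACT quantization-aware training enabled")
elif prune_cfg and str(prune_cfg.get("name", "")).lower() in ("fpgm", "l1_norm"):
    print("filter pruning enabled, ratio:", prune_cfg.get("pruned_ratio"))
else:
    print("training without slim")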
@@ -18,10 +18,14 @@ import os
 import sys

 import paddle
+import paddleslim
+from paddle.jit import to_static
+from paddleslim.analysis import dygraph_flops as flops
+
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
 from paddleslim.dygraph.quant import QAT

 from ppcls.engine.trainer import Trainer
 from ppcls.utils import config, logger
 from ppcls.utils.save_load import load_dygraph_pretrain
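Among the new imports, paddleslim.analysis.dygraph_flops is what the trainer later uses to log FLOPs before and after pruning. A standalone sketch of the same call; resnet18 is only a stand-in for the configured architecture, and the division by 1000000 mirrors the logging format used in the hunks below:

import paddle
from paddleslim.analysis import dygraph_flops as flops

# Stand-in model; the trainer passes its own self.model here.
model = paddle.vision.models.resnet18()
model.eval()

# The trainer uses [1] + Global.image_shape, i.e. a single-image batch.
total = flops(model, [1, 3, 224, 224])
print("FLOPs: {} ({}GFLOPs as logged by the trainer)".format(total, total / 1000000))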
@@ -53,9 +57,12 @@ quant_config = {
 class Trainer_slim(Trainer):
     def __init__(self, config, mode="train"):
         super().__init__(config, mode)
-        self.pact = True
-        if self.pact:
-
+        # self.pact = self.config["Slim"].get("pact", False)
+        pact = self.config["Slim"].get("quant", False)
+        self.pact = pact.get("name", False) if pact else pact
+        if self.pact and str(self.pact.lower()) != 'pact':
+            raise RuntimeError("The quantization only support 'PACT'!")
+        if pact:
             quant_config["activation_preprocess_type"] = "PACT"
             self.quanter = QAT(config=quant_config)
             self.quanter.quantize(self.model)
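The hunk above reads Slim.quant from the config, accepts only the PACT method, sets activation_preprocess_type to "PACT" in the quant_config dict (defined earlier in this file, outside the shown hunk), and wraps the model with paddleslim's QAT helper. A minimal standalone sketch of that flow; the resnet18 model and the partial quant_config are stand-ins, with unspecified entries left to PaddleSlim's defaults:

import paddle
from paddleslim.dygraph.quant import QAT

# Partial quant config; only the keys shown here are assumed, the file's real quant_config
# is defined above the hunk and may set more options.
quant_config = {
    "activation_preprocess_type": "PACT",  # what the trainer sets when Slim.quant.name == pact
    "weight_bits": 8,
    "activation_bits": 8,
}

model = paddle.vision.models.resnet18()   # stand-in for the configured architecture
quanter = QAT(config=quant_config)
quanter.quantize(model)   # rewrites quantizable layers in place with fake-quant wrappers
# ...train as usual, then export with quanter.save_quantized_model(...) as in the last hunk.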
@@ -64,6 +71,31 @@ class Trainer_slim(Trainer):
         else:
             self.quanter = None

+        prune_config = self.config["Slim"].get("prune", False)
+        if prune_config:
+            if prune_config["name"].lower() not in ["fpgm", "l1_norm"]:
+                raise RuntimeError(
+                    "The prune methods only support 'fpgm' and 'l1_norm'")
+            else:
+                logger.info("FLOPs before pruning: {}GFLOPs".format(
+                    flops(self.model, [1] + self.config["Global"][
+                        "image_shape"]) / 1000000))
+                self.model.eval()
+
+                if prune_config["name"].lower() == "fpgm":
+                    self.model.eval()
+                    self.pruner = paddleslim.dygraph.FPGMFilterPruner(
+                        self.model, [1] + self.config["Global"]["image_shape"])
+                else:
+                    self.pruner = paddleslim.dygraph.L1NormFilterPruner(
+                        self.model, [1] + self.config["Global"]["image_shape"])
+                self.prune_model()
+        else:
+            self.pruner = None
+
+        if self.quanter is None and self.pruner is None:
+            logger.info("Training without slim")
+
     def train(self):
         super().train()
         if self.config["Global"].get("save_inference_dir", None):
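The pruner chosen above (FPGM, or L1-norm as the fallback) is built from the dygraph model and a fixed input shape, which is why the model is put into eval mode first; the pruning itself happens in prune_model(), shown in the last hunk below, through a per-parameter ratio dict and prune_vars. A compact standalone sketch of that sequence; resnet18 and the flat 0.3 ratio stand in for the configured architecture and Slim.prune.pruned_ratio:

import paddle
import paddleslim

model = paddle.vision.models.resnet18()
model.eval()   # the pruner analyses the model at a fixed shape, so run it in eval mode
pruner = paddleslim.dygraph.FPGMFilterPruner(model, [1, 3, 224, 224])

# One ratio per Conv2D weight, pruned along the output-channel axis (axis 0),
# matching what prune_model() does below.
ratios = {
    param.name: 0.3
    for layer in model.sublayers()
    for param in layer.parameters(include_sublayers=False)
    if isinstance(layer, paddle.nn.Conv2D)
}
plan = pruner.prune_vars(ratios, [0])
print("pruned FLOPs ratio:", plan.pruned_flops)
model.train()  # back to train mode before fine-tuning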
@@ -86,17 +118,48 @@ class Trainer_slim(Trainer):
                 raise RuntimeError(
                     "The best_model or pretraine_model should exist to generate inference model"
                 )
+            save_path = os.path.join(self.config["Global"]["save_inference_dir"],
+                                     "inference")
+            if self.quanter:
+                self.quanter.save_quantized_model(
+                    self.model,
+                    save_path,
+                    input_spec=[
+                        paddle.static.InputSpec(
+                            shape=[None] + config["Global"]["image_shape"],
+                            dtype='float32')
+                    ])
+            else:
+                model = to_static(
+                    self.model,
+                    input_spec=[
+                        paddle.static.InputSpec(
+                            shape=[None] + self.config["Global"]["image_shape"],
+                            dtype='float32',
+                            name="image")
+                    ])
+                paddle.jit.save(model, save_path)

-            assert self.quanter
-            self.quanter.save_quantized_model(
-                self.model,
-                os.path.join(self.config["Global"]["save_inference_dir"],
-                             "inference"),
-                input_spec=[
-                    paddle.static.InputSpec(
-                        shape=[None] + config["Global"]["image_shape"],
-                        dtype='float32')
-                ])
+    def prune_model(self):
+        params = []
+        for sublayer in self.model.sublayers():
+            for param in sublayer.parameters(include_sublayers=False):
+                if isinstance(sublayer, paddle.nn.Conv2D):
+                    params.append(param.name)
+        ratios = {}
+        for param in params:
+            ratios[param] = self.config["Slim"]["prune"]["pruned_ratio"]
+        plan = self.pruner.prune_vars(ratios, [0])
+
+        logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
+            flops(self.model, [1] + self.config["Global"]["image_shape"]) /
+            1000000, plan.pruned_flops))
+
+        for param in self.model.parameters():
+            if "conv2d" in param.name:
+                logger.info("{}\t{}".format(param.name, param.shape))
+
+        self.model.train()


 if __name__ == "__main__":
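After training, the new export path saves either a quantized inference model through quanter.save_quantized_model or a plain static-graph model through to_static plus paddle.jit.save, both with an InputSpec of shape [None] + Global.image_shape. A minimal sketch of the non-quantized branch; the resnet18 model and the output directory are stand-ins:

import os
import paddle
from paddle.jit import to_static

model = paddle.vision.models.resnet18()   # stand-in for the trained model
model.eval()

save_path = os.path.join("./inference_out", "inference")   # placeholder for Global.save_inference_dir
static_model = to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 224, 224],   # [None] + Global.image_shape
            dtype="float32",
            name="image")
    ])
paddle.jit.save(static_model, save_path)   # writes inference.pdmodel / inference.pdiparams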