add pruner and quanter for theseus
parent
0c8a082d35
commit
6c5d1ebc28
ppcls
arch
backbone/base
engine
static
|
@ -26,15 +26,20 @@ from .utils import *
|
|||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||
from ppcls.utils import logger
|
||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||
from ppcls.engine.slim import prune_model, quantize_model
|
||||
|
||||
|
||||
__all__ = ["build_model", "RecModel", "DistillationModel"]
|
||||
|
||||
|
||||
def build_model(config):
|
||||
config = copy.deepcopy(config)
|
||||
model_type = config.pop("name")
|
||||
arch_config = copy.deepcopy(config["Arch"])
|
||||
model_type = arch_config.pop("name")
|
||||
mod = importlib.import_module(__name__)
|
||||
arch = getattr(mod, model_type)(**config)
|
||||
arch = getattr(mod, model_type)(**arch_config)
|
||||
if isinstance(arch, TheseusLayer):
|
||||
prune_model(config, arch)
|
||||
quantize_model(config, arch)
|
||||
return arch
|
||||
|
||||
|
||||
|
@ -51,7 +56,7 @@ def apply_to_static(config, model):
|
|||
return model
|
||||
|
||||
|
||||
class RecModel(nn.Layer):
|
||||
class RecModel(TheseusLayer):
|
||||
def __init__(self, **config):
|
||||
super().__init__()
|
||||
backbone_config = config["Backbone"]
|
||||
|
|
|
@ -16,6 +16,8 @@ class TheseusLayer(nn.Layer):
|
|||
super(TheseusLayer, self).__init__()
|
||||
self.res_dict = {}
|
||||
self.res_name = self.full_name()
|
||||
self.pruner = None
|
||||
self.quanter = None
|
||||
|
||||
# stop doesn't work when stop layer has a parallel branch.
|
||||
def stop_after(self, stop_layer_name: str):
|
||||
|
|
|
@ -44,7 +44,6 @@ from ppcls.data import create_operators
|
|||
from ppcls.engine.train import train_epoch
|
||||
from ppcls.engine import evaluation
|
||||
from ppcls.arch.gears.identity_head import IdentityHead
|
||||
from ppcls.engine.slim import get_pruner, get_quaner
|
||||
|
||||
|
||||
class Engine(object):
|
||||
|
@ -186,14 +185,12 @@ class Engine(object):
|
|||
self.eval_metric_func = None
|
||||
|
||||
# build model
|
||||
self.model = build_model(self.config["Arch"])
|
||||
self.model = build_model(self.config)
|
||||
self.quanted = self.config.get("Slim", {}).get("quant", False)
|
||||
self.pruned = self.config.get("Slim", {}).get("prune", False)
|
||||
# set @to_static for benchmark, skip this by default.
|
||||
apply_to_static(self.config, self.model)
|
||||
|
||||
# for slim
|
||||
self.pruner = get_pruner(self.config, self.model)
|
||||
self.quanter = get_quaner(self.config, self.model)
|
||||
|
||||
# load_pretrain
|
||||
if self.config["Global"]["pretrained_model"] is not None:
|
||||
if self.config["Global"]["pretrained_model"].startswith("http"):
|
||||
|
@ -371,8 +368,8 @@ class Engine(object):
|
|||
model.eval()
|
||||
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
|
||||
"inference")
|
||||
if self.quanter:
|
||||
self.quanter.save_quantized_model(
|
||||
if self.quanted:
|
||||
model.quanter.save_quantized_model(
|
||||
model.base_model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
|
|
|
@ -12,5 +12,5 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ppcls.engine.slim.prune import get_pruner
|
||||
from ppcls.engine.slim.quant import get_quaner
|
||||
from ppcls.engine.slim.prune import prune_model
|
||||
from ppcls.engine.slim.quant import quantize_model
|
||||
|
|
|
@ -17,7 +17,7 @@ import paddle
|
|||
from ppcls.utils import logger
|
||||
|
||||
|
||||
def get_pruner(config, model):
|
||||
def prune_model(config, model):
|
||||
if config.get("Slim", False) and config["Slim"].get("prune", False):
|
||||
import paddleslim
|
||||
prune_method_name = config["Slim"]["prune"]["name"].lower()
|
||||
|
@ -25,21 +25,20 @@ def get_pruner(config, model):
|
|||
"fpgm", "l1_norm"
|
||||
], "The prune methods only support 'fpgm' and 'l1_norm'"
|
||||
if prune_method_name == "fpgm":
|
||||
pruner = paddleslim.dygraph.FPGMFilterPruner(
|
||||
model.pruner = paddleslim.dygraph.FPGMFilterPruner(
|
||||
model, [1] + config["Global"]["image_shape"])
|
||||
else:
|
||||
pruner = paddleslim.dygraph.L1NormFilterPruner(
|
||||
model.pruner = paddleslim.dygraph.L1NormFilterPruner(
|
||||
model, [1] + config["Global"]["image_shape"])
|
||||
|
||||
# prune model
|
||||
_prune_model(pruner, config, model)
|
||||
_prune_model(config, model)
|
||||
else:
|
||||
pruner = None
|
||||
|
||||
return pruner
|
||||
model.pruner = None
|
||||
|
||||
|
||||
def _prune_model(pruner, config, model):
|
||||
|
||||
def _prune_model(config, model):
|
||||
from paddleslim.analysis import dygraph_flops as flops
|
||||
logger.info("FLOPs before pruning: {}GFLOPs".format(
|
||||
flops(model, [1] + config["Global"]["image_shape"]) / 1e9))
|
||||
|
@ -53,7 +52,7 @@ def _prune_model(pruner, config, model):
|
|||
ratios = {}
|
||||
for param in params:
|
||||
ratios[param] = config["Slim"]["prune"]["pruned_ratio"]
|
||||
plan = pruner.prune_vars(ratios, [0])
|
||||
plan = model.pruner.prune_vars(ratios, [0])
|
||||
|
||||
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
|
||||
flops(model, [1] + config["Global"]["image_shape"]) / 1e9,
|
||||
|
|
|
@ -40,16 +40,16 @@ QUANT_CONFIG = {
|
|||
}
|
||||
|
||||
|
||||
def get_quaner(config, model):
|
||||
def quantize_model(config, model):
|
||||
if config.get("Slim", False) and config["Slim"].get("quant", False):
|
||||
from paddleslim.dygraph.quant import QAT
|
||||
assert config["Slim"]["quant"]["name"].lower(
|
||||
) == 'pact', 'Only PACT quantization method is supported now'
|
||||
QUANT_CONFIG["activation_preprocess_type"] = "PACT"
|
||||
quanter = QAT(config=QUANT_CONFIG)
|
||||
quanter.quantize(model)
|
||||
model.quanted = QAT(config=QUANT_CONFIG)
|
||||
model.quanted.quantize_model(model)
|
||||
logger.info("QAT model summary:")
|
||||
paddle.summary(model, (1, 3, 224, 224))
|
||||
else:
|
||||
quanter = None
|
||||
return quanter
|
||||
model.quanted = None
|
||||
return model.quanted
|
||||
|
|
|
@ -259,7 +259,7 @@ def build(config,
|
|||
# data_format should be assigned in arch-dict
|
||||
input_image_channel = config["Global"]["image_shape"][
|
||||
0] # default as [3, 224, 224]
|
||||
model = build_model(config["Arch"])
|
||||
model = build_model(config)
|
||||
out = model(feeds["data"])
|
||||
# end of build model
|
||||
|
||||
|
|
Loading…
Reference in New Issue