add pruner and quanter for theseus
parent
0c8a082d35
commit
6c5d1ebc28
|
@ -26,15 +26,20 @@ from .utils import *
|
||||||
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
|
||||||
from ppcls.utils import logger
|
from ppcls.utils import logger
|
||||||
from ppcls.utils.save_load import load_dygraph_pretrain
|
from ppcls.utils.save_load import load_dygraph_pretrain
|
||||||
|
from ppcls.engine.slim import prune_model, quantize_model
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["build_model", "RecModel", "DistillationModel"]
|
__all__ = ["build_model", "RecModel", "DistillationModel"]
|
||||||
|
|
||||||
|
|
||||||
def build_model(config):
|
def build_model(config):
|
||||||
config = copy.deepcopy(config)
|
arch_config = copy.deepcopy(config["Arch"])
|
||||||
model_type = config.pop("name")
|
model_type = arch_config.pop("name")
|
||||||
mod = importlib.import_module(__name__)
|
mod = importlib.import_module(__name__)
|
||||||
arch = getattr(mod, model_type)(**config)
|
arch = getattr(mod, model_type)(**arch_config)
|
||||||
|
if isinstance(arch, TheseusLayer):
|
||||||
|
prune_model(config, arch)
|
||||||
|
quantize_model(config, arch)
|
||||||
return arch
|
return arch
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +56,7 @@ def apply_to_static(config, model):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class RecModel(nn.Layer):
|
class RecModel(TheseusLayer):
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
backbone_config = config["Backbone"]
|
backbone_config = config["Backbone"]
|
||||||
|
|
|
@ -16,6 +16,8 @@ class TheseusLayer(nn.Layer):
|
||||||
super(TheseusLayer, self).__init__()
|
super(TheseusLayer, self).__init__()
|
||||||
self.res_dict = {}
|
self.res_dict = {}
|
||||||
self.res_name = self.full_name()
|
self.res_name = self.full_name()
|
||||||
|
self.pruner = None
|
||||||
|
self.quanter = None
|
||||||
|
|
||||||
# stop doesn't work when stop layer has a parallel branch.
|
# stop doesn't work when stop layer has a parallel branch.
|
||||||
def stop_after(self, stop_layer_name: str):
|
def stop_after(self, stop_layer_name: str):
|
||||||
|
|
|
@ -44,7 +44,6 @@ from ppcls.data import create_operators
|
||||||
from ppcls.engine.train import train_epoch
|
from ppcls.engine.train import train_epoch
|
||||||
from ppcls.engine import evaluation
|
from ppcls.engine import evaluation
|
||||||
from ppcls.arch.gears.identity_head import IdentityHead
|
from ppcls.arch.gears.identity_head import IdentityHead
|
||||||
from ppcls.engine.slim import get_pruner, get_quaner
|
|
||||||
|
|
||||||
|
|
||||||
class Engine(object):
|
class Engine(object):
|
||||||
|
@ -186,14 +185,12 @@ class Engine(object):
|
||||||
self.eval_metric_func = None
|
self.eval_metric_func = None
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
self.model = build_model(self.config["Arch"])
|
self.model = build_model(self.config)
|
||||||
|
self.quanted = self.config.get("Slim", {}).get("quant", False)
|
||||||
|
self.pruned = self.config.get("Slim", {}).get("prune", False)
|
||||||
# set @to_static for benchmark, skip this by default.
|
# set @to_static for benchmark, skip this by default.
|
||||||
apply_to_static(self.config, self.model)
|
apply_to_static(self.config, self.model)
|
||||||
|
|
||||||
# for slim
|
|
||||||
self.pruner = get_pruner(self.config, self.model)
|
|
||||||
self.quanter = get_quaner(self.config, self.model)
|
|
||||||
|
|
||||||
# load_pretrain
|
# load_pretrain
|
||||||
if self.config["Global"]["pretrained_model"] is not None:
|
if self.config["Global"]["pretrained_model"] is not None:
|
||||||
if self.config["Global"]["pretrained_model"].startswith("http"):
|
if self.config["Global"]["pretrained_model"].startswith("http"):
|
||||||
|
@ -371,8 +368,8 @@ class Engine(object):
|
||||||
model.eval()
|
model.eval()
|
||||||
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
|
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
|
||||||
"inference")
|
"inference")
|
||||||
if self.quanter:
|
if self.quanted:
|
||||||
self.quanter.save_quantized_model(
|
model.quanter.save_quantized_model(
|
||||||
model.base_model,
|
model.base_model,
|
||||||
save_path,
|
save_path,
|
||||||
input_spec=[
|
input_spec=[
|
||||||
|
|
|
@ -12,5 +12,5 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ppcls.engine.slim.prune import get_pruner
|
from ppcls.engine.slim.prune import prune_model
|
||||||
from ppcls.engine.slim.quant import get_quaner
|
from ppcls.engine.slim.quant import quantize_model
|
||||||
|
|
|
@ -17,7 +17,7 @@ import paddle
|
||||||
from ppcls.utils import logger
|
from ppcls.utils import logger
|
||||||
|
|
||||||
|
|
||||||
def get_pruner(config, model):
|
def prune_model(config, model):
|
||||||
if config.get("Slim", False) and config["Slim"].get("prune", False):
|
if config.get("Slim", False) and config["Slim"].get("prune", False):
|
||||||
import paddleslim
|
import paddleslim
|
||||||
prune_method_name = config["Slim"]["prune"]["name"].lower()
|
prune_method_name = config["Slim"]["prune"]["name"].lower()
|
||||||
|
@ -25,21 +25,20 @@ def get_pruner(config, model):
|
||||||
"fpgm", "l1_norm"
|
"fpgm", "l1_norm"
|
||||||
], "The prune methods only support 'fpgm' and 'l1_norm'"
|
], "The prune methods only support 'fpgm' and 'l1_norm'"
|
||||||
if prune_method_name == "fpgm":
|
if prune_method_name == "fpgm":
|
||||||
pruner = paddleslim.dygraph.FPGMFilterPruner(
|
model.pruner = paddleslim.dygraph.FPGMFilterPruner(
|
||||||
model, [1] + config["Global"]["image_shape"])
|
model, [1] + config["Global"]["image_shape"])
|
||||||
else:
|
else:
|
||||||
pruner = paddleslim.dygraph.L1NormFilterPruner(
|
model.pruner = paddleslim.dygraph.L1NormFilterPruner(
|
||||||
model, [1] + config["Global"]["image_shape"])
|
model, [1] + config["Global"]["image_shape"])
|
||||||
|
|
||||||
# prune model
|
# prune model
|
||||||
_prune_model(pruner, config, model)
|
_prune_model(config, model)
|
||||||
else:
|
else:
|
||||||
pruner = None
|
model.pruner = None
|
||||||
|
|
||||||
return pruner
|
|
||||||
|
|
||||||
|
|
||||||
def _prune_model(pruner, config, model):
|
|
||||||
|
def _prune_model(config, model):
|
||||||
from paddleslim.analysis import dygraph_flops as flops
|
from paddleslim.analysis import dygraph_flops as flops
|
||||||
logger.info("FLOPs before pruning: {}GFLOPs".format(
|
logger.info("FLOPs before pruning: {}GFLOPs".format(
|
||||||
flops(model, [1] + config["Global"]["image_shape"]) / 1e9))
|
flops(model, [1] + config["Global"]["image_shape"]) / 1e9))
|
||||||
|
@ -53,7 +52,7 @@ def _prune_model(pruner, config, model):
|
||||||
ratios = {}
|
ratios = {}
|
||||||
for param in params:
|
for param in params:
|
||||||
ratios[param] = config["Slim"]["prune"]["pruned_ratio"]
|
ratios[param] = config["Slim"]["prune"]["pruned_ratio"]
|
||||||
plan = pruner.prune_vars(ratios, [0])
|
plan = model.pruner.prune_vars(ratios, [0])
|
||||||
|
|
||||||
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
|
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
|
||||||
flops(model, [1] + config["Global"]["image_shape"]) / 1e9,
|
flops(model, [1] + config["Global"]["image_shape"]) / 1e9,
|
||||||
|
|
|
@ -40,16 +40,16 @@ QUANT_CONFIG = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_quaner(config, model):
|
def quantize_model(config, model):
|
||||||
if config.get("Slim", False) and config["Slim"].get("quant", False):
|
if config.get("Slim", False) and config["Slim"].get("quant", False):
|
||||||
from paddleslim.dygraph.quant import QAT
|
from paddleslim.dygraph.quant import QAT
|
||||||
assert config["Slim"]["quant"]["name"].lower(
|
assert config["Slim"]["quant"]["name"].lower(
|
||||||
) == 'pact', 'Only PACT quantization method is supported now'
|
) == 'pact', 'Only PACT quantization method is supported now'
|
||||||
QUANT_CONFIG["activation_preprocess_type"] = "PACT"
|
QUANT_CONFIG["activation_preprocess_type"] = "PACT"
|
||||||
quanter = QAT(config=QUANT_CONFIG)
|
model.quanted = QAT(config=QUANT_CONFIG)
|
||||||
quanter.quantize(model)
|
model.quanted.quantize_model(model)
|
||||||
logger.info("QAT model summary:")
|
logger.info("QAT model summary:")
|
||||||
paddle.summary(model, (1, 3, 224, 224))
|
paddle.summary(model, (1, 3, 224, 224))
|
||||||
else:
|
else:
|
||||||
quanter = None
|
model.quanted = None
|
||||||
return quanter
|
return model.quanted
|
||||||
|
|
|
@ -259,7 +259,7 @@ def build(config,
|
||||||
# data_format should be assigned in arch-dict
|
# data_format should be assigned in arch-dict
|
||||||
input_image_channel = config["Global"]["image_shape"][
|
input_image_channel = config["Global"]["image_shape"][
|
||||||
0] # default as [3, 224, 224]
|
0] # default as [3, 224, 224]
|
||||||
model = build_model(config["Arch"])
|
model = build_model(config)
|
||||||
out = model(feeds["data"])
|
out = model(feeds["data"])
|
||||||
# end of build model
|
# end of build model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue