diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index d43bc4880..f2c2e412c 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -26,7 +26,7 @@ from .utils import * from ppcls.arch.backbone.base.theseus_layer import TheseusLayer from ppcls.utils import logger from ppcls.utils.save_load import load_dygraph_pretrain -from ppcls.engine.slim import prune_model, quantize_model +from ppcls.arch.slim import prune_model, quantize_model __all__ = ["build_model", "RecModel", "DistillationModel"] diff --git a/ppcls/engine/slim/__init__.py b/ppcls/arch/slim/__init__.py similarity index 86% rename from ppcls/engine/slim/__init__.py rename to ppcls/arch/slim/__init__.py index de3d857b2..e28424722 100644 --- a/ppcls/engine/slim/__init__.py +++ b/ppcls/arch/slim/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ppcls.engine.slim.prune import prune_model -from ppcls.engine.slim.quant import quantize_model +from ppcls.arch.slim.prune import prune_model diff --git a/ppcls/engine/slim/prune.py b/ppcls/arch/slim/prune.py similarity index 100% rename from ppcls/engine/slim/prune.py rename to ppcls/arch/slim/prune.py diff --git a/ppcls/engine/slim/quant.py b/ppcls/arch/slim/quant.py similarity index 100% rename from ppcls/engine/slim/quant.py rename to ppcls/arch/slim/quant.py diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 22b3e0589..c53594889 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -186,8 +186,6 @@ class Engine(object): # build model self.model = build_model(self.config) - self.quanted = self.config.get("Slim", {}).get("quant", False) - self.pruned = self.config.get("Slim", {}).get("prune", False) # set @to_static for benchmark, skip this by default. apply_to_static(self.config, self.model) @@ -368,7 +366,7 @@ class Engine(object): model.eval() save_path = os.path.join(self.config["Global"]["save_inference_dir"], "inference") - if self.quanted: + if model.quanter: model.quanter.save_quantized_model( model.base_model, save_path,