diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 77b66c421..1ec112de5 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -14,14 +14,12 @@ import sys import copy -import importlib import paddle.nn as nn from paddle.jit import to_static from paddle.static import InputSpec from . import backbone -from .backbone import * from .gears import build_gear from .utils import * from .backbone.base.theseus_layer import TheseusLayer @@ -38,8 +36,11 @@ def build_model(config, mode="train"): model_type = arch_config.pop("name") use_sync_bn = arch_config.pop("use_sync_bn", False) - mod = importlib.import_module(__name__) - model = getattr(mod, model_type)(**arch_config) + if hasattr(backbone, model_type): + model = ClassModel(model_type, **arch_config) + else: + model = getattr(sys.modules[__name__], model_type)("ClassModel", + **arch_config) if use_sync_bn: if config["Global"]["device"] == "gpu": @@ -72,6 +73,23 @@ def apply_to_static(config, model): return model +# TODO(gaotingquan): export model +class ClassModel(TheseusLayer): + def __init__(self, model_type, **config): + super().__init__() + if model_type == "ClassModel": + backbone_config = config["Backbone"] + backbone_name = backbone_config.pop("name") + else: + backbone_name = model_type + backbone_config = config + self.backbone = getattr(backbone, backbone_name)(**backbone_config) + + def forward(self, batch): + x, label = batch[0], batch[1] + return self.backbone(x) + + class RecModel(TheseusLayer): def __init__(self, **config): super().__init__() diff --git a/ppcls/arch/backbone/base/__init__.py b/ppcls/arch/backbone/base/__init__.py index 7a1fec9b8..e69de29bb 100644 --- a/ppcls/arch/backbone/base/__init__.py +++ b/ppcls/arch/backbone/base/__init__.py @@ -1,6 +0,0 @@ -def clas_forward_decorator(forward_func): - def parse_batch_wrapper(model, batch): - x, label = batch[0], batch[1] - return forward_func(model, x) - - return parse_batch_wrapper