parent
339be96ef5
commit
f91811dab9
|
@ -14,14 +14,12 @@
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
import importlib
|
|
||||||
|
|
||||||
import paddle.nn as nn
|
import paddle.nn as nn
|
||||||
from paddle.jit import to_static
|
from paddle.jit import to_static
|
||||||
from paddle.static import InputSpec
|
from paddle.static import InputSpec
|
||||||
|
|
||||||
from . import backbone
|
from . import backbone
|
||||||
from .backbone import *
|
|
||||||
from .gears import build_gear
|
from .gears import build_gear
|
||||||
from .utils import *
|
from .utils import *
|
||||||
from .backbone.base.theseus_layer import TheseusLayer
|
from .backbone.base.theseus_layer import TheseusLayer
|
||||||
|
@ -38,8 +36,11 @@ def build_model(config, mode="train"):
|
||||||
model_type = arch_config.pop("name")
|
model_type = arch_config.pop("name")
|
||||||
use_sync_bn = arch_config.pop("use_sync_bn", False)
|
use_sync_bn = arch_config.pop("use_sync_bn", False)
|
||||||
|
|
||||||
mod = importlib.import_module(__name__)
|
if hasattr(backbone, model_type):
|
||||||
model = getattr(mod, model_type)(**arch_config)
|
model = ClassModel(model_type, **arch_config)
|
||||||
|
else:
|
||||||
|
model = getattr(sys.modules[__name__], model_type)("ClassModel",
|
||||||
|
**arch_config)
|
||||||
|
|
||||||
if use_sync_bn:
|
if use_sync_bn:
|
||||||
if config["Global"]["device"] == "gpu":
|
if config["Global"]["device"] == "gpu":
|
||||||
|
@ -72,6 +73,23 @@ def apply_to_static(config, model):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(gaotingquan): export model
|
||||||
|
class ClassModel(TheseusLayer):
|
||||||
|
def __init__(self, model_type, **config):
|
||||||
|
super().__init__()
|
||||||
|
if model_type == "ClassModel":
|
||||||
|
backbone_config = config["Backbone"]
|
||||||
|
backbone_name = backbone_config.pop("name")
|
||||||
|
else:
|
||||||
|
backbone_name = model_type
|
||||||
|
backbone_config = config
|
||||||
|
self.backbone = getattr(backbone, backbone_name)(**backbone_config)
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
x, label = batch[0], batch[1]
|
||||||
|
return self.backbone(x)
|
||||||
|
|
||||||
|
|
||||||
class RecModel(TheseusLayer):
|
class RecModel(TheseusLayer):
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
def clas_forward_decorator(forward_func):
|
|
||||||
def parse_batch_wrapper(model, batch):
|
|
||||||
x, label = batch[0], batch[1]
|
|
||||||
return forward_func(model, x)
|
|
||||||
|
|
||||||
return parse_batch_wrapper
|
|
Loading…
Reference in New Issue