Revert "use decorator to parse batch"

This reverts commit 97935164fe.
pull/2701/head
Tingquan Gao 2023-03-14 16:16:40 +08:00
parent 339be96ef5
commit f91811dab9
2 changed files with 22 additions and 10 deletions

View File

@ -14,14 +14,12 @@
import sys
import copy
import importlib
import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec
from . import backbone
from .backbone import *
from .gears import build_gear
from .utils import *
from .backbone.base.theseus_layer import TheseusLayer
@ -38,8 +36,11 @@ def build_model(config, mode="train"):
model_type = arch_config.pop("name")
use_sync_bn = arch_config.pop("use_sync_bn", False)
mod = importlib.import_module(__name__)
model = getattr(mod, model_type)(**arch_config)
if hasattr(backbone, model_type):
model = ClassModel(model_type, **arch_config)
else:
model = getattr(sys.modules[__name__], model_type)("ClassModel",
**arch_config)
if use_sync_bn:
if config["Global"]["device"] == "gpu":
@ -72,6 +73,23 @@ def apply_to_static(config, model):
return model
# TODO(gaotingquan): export model
class ClassModel(TheseusLayer):
def __init__(self, model_type, **config):
super().__init__()
if model_type == "ClassModel":
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
else:
backbone_name = model_type
backbone_config = config
self.backbone = getattr(backbone, backbone_name)(**backbone_config)
def forward(self, batch):
x, label = batch[0], batch[1]
return self.backbone(x)
class RecModel(TheseusLayer):
def __init__(self, **config):
super().__init__()

View File

@ -1,6 +0,0 @@
def clas_forward_decorator(forward_func):
def parse_batch_wrapper(model, batch):
x, label = batch[0], batch[1]
return forward_func(model, x)
return parse_batch_wrapper