Revert "refactor: add ClassModel to unify model forward interface"

This reverts commit 75a20ba5574340fa5742eba8e41aebe4de6c5eb8.
This commit is contained in:
Tingquan Gao 2023-03-14 16:16:40 +08:00
parent e7e4f68b5c
commit 6aabb94d8c
3 changed files with 36 additions and 47 deletions

View File

@ -12,14 +12,14 @@
#See the License for the specific language governing permissions and #See the License for the specific language governing permissions and
#limitations under the License. #limitations under the License.
import sys
import copy import copy
import importlib
import paddle.nn as nn import paddle.nn as nn
from paddle.jit import to_static from paddle.jit import to_static
from paddle.static import InputSpec from paddle.static import InputSpec
from . import backbone as backbone_zoo from . import backbone, gears
from .backbone import *
from .gears import build_gear from .gears import build_gear
from .utils import * from .utils import *
from .backbone.base.theseus_layer import TheseusLayer from .backbone.base.theseus_layer import TheseusLayer
@ -35,28 +35,20 @@ def build_model(config, mode="train"):
arch_config = copy.deepcopy(config["Arch"]) arch_config = copy.deepcopy(config["Arch"])
model_type = arch_config.pop("name") model_type = arch_config.pop("name")
use_sync_bn = arch_config.pop("use_sync_bn", False) use_sync_bn = arch_config.pop("use_sync_bn", False)
mod = importlib.import_module(__name__)
if hasattr(backbone_zoo, model_type): arch = getattr(mod, model_type)(**arch_config)
model = ClassModel(model_type, **arch_config)
else:
model = getattr(sys.modules[__name__], model_type)("ClassModel",
**arch_config)
if use_sync_bn: if use_sync_bn:
if config["Global"]["device"] == "gpu": if config["Global"]["device"] == "gpu":
model = nn.SyncBatchNorm.convert_sync_batchnorm(model) arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
else: else:
msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored." msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
logger.warning(msg) logger.warning(msg)
if isinstance(model, TheseusLayer): if isinstance(arch, TheseusLayer):
prune_model(config, model) prune_model(config, arch)
quantize_model(config, model, mode) quantize_model(config, arch, mode)
# set @to_static for benchmark, skip this by default. return arch
model = apply_to_static(config, model)
return model
def apply_to_static(config, model): def apply_to_static(config, model):
@ -73,29 +65,12 @@ def apply_to_static(config, model):
return model return model
# TODO(gaotingquan): export model
class ClassModel(TheseusLayer):
def __init__(self, model_type, **config):
super().__init__()
if model_type == "ClassModel":
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
else:
backbone_name = model_type
backbone_config = config
self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
def forward(self, batch):
x, label = batch[0], batch[1]
return self.backbone(x)
class RecModel(TheseusLayer): class RecModel(TheseusLayer):
def __init__(self, **config): def __init__(self, **config):
super().__init__() super().__init__()
backbone_config = config["Backbone"] backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name") backbone_name = backbone_config.pop("name")
self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config) self.backbone = eval(backbone_name)(**backbone_config)
self.head_feature_from = config.get('head_feature_from', 'neck') self.head_feature_from = config.get('head_feature_from', 'neck')
if "BackboneStopLayer" in config: if "BackboneStopLayer" in config:
@ -112,8 +87,8 @@ class RecModel(TheseusLayer):
else: else:
self.head = None self.head = None
def forward(self, batch): def forward(self, x, label=None):
x, label = batch[0], batch[1]
out = dict() out = dict()
x = self.backbone(x) x = self.backbone(x)
out["backbone"] = x out["backbone"] = x
@ -165,8 +140,7 @@ class DistillationModel(nn.Layer):
load_dygraph_pretrain( load_dygraph_pretrain(
self.model_name_list[idx], path=pretrained) self.model_name_list[idx], path=pretrained)
def forward(self, batch): def forward(self, x, label=None):
x, label = batch[0], batch[1]
result_dict = dict() result_dict = dict()
for idx, model_name in enumerate(self.model_name_list): for idx, model_name in enumerate(self.model_name_list):
if label is None: if label is None:
@ -184,8 +158,7 @@ class AttentionModel(DistillationModel):
**kargs): **kargs):
super().__init__(models, pretrained_list, freeze_params_list, **kargs) super().__init__(models, pretrained_list, freeze_params_list, **kargs)
def forward(self, batch): def forward(self, x, label=None):
x, label = batch[0], batch[1]
result_dict = dict() result_dict = dict()
out = x out = x
for idx, model_name in enumerate(self.model_name_list): for idx, model_name in enumerate(self.model_name_list):

View File

@ -28,6 +28,7 @@ from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config from ppcls.utils.config import print_config
from ppcls.data import build_dataloader from ppcls.data import build_dataloader
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss from ppcls.loss import build_loss
from ppcls.metric import build_metrics from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer from ppcls.optimizer import build_optimizer
@ -56,10 +57,18 @@ class Engine(object):
# init logger # init logger
init_logger(self.config, mode=mode) init_logger(self.config, mode=mode)
print_config(config)
# for visualdl # for visualdl
self.vdl_writer = self._init_vdl() self.vdl_writer = self._init_vdl()
# is_rec
if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
False):
self.is_rec = True
else:
self.is_rec = False
# init train_func and eval_func # init train_func and eval_func
self.train_mode = self.config["Global"].get("train_mode", None) self.train_mode = self.config["Global"].get("train_mode", None)
if self.train_mode is None: if self.train_mode is None:
@ -99,6 +108,8 @@ class Engine(object):
# build model # build model
self.model = build_model(self.config, self.mode) self.model = build_model(self.config, self.mode)
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
# load_pretrain # load_pretrain
self._init_pretrained() self._init_pretrained()
@ -114,8 +125,6 @@ class Engine(object):
# for distributed # for distributed
self._init_dist() self._init_dist()
print_config(config)
def train(self): def train(self):
assert self.mode == "train" assert self.mode == "train"
print_batch_step = self.config['Global']['print_batch_step'] print_batch_step = self.config['Global']['print_batch_step']

View File

@ -55,10 +55,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
"flatten_contiguous_range", "greater_than" "flatten_contiguous_range", "greater_than"
}, },
level=amp_level): level=amp_level):
out = engine.model(batch) out = forward(engine, batch)
loss_dict = engine.train_loss_func(out, batch[1]) loss_dict = engine.train_loss_func(out, batch[1])
else: else:
out = engine.model(batch) out = forward(engine, batch)
loss_dict = engine.train_loss_func(out, batch[1]) loss_dict = engine.train_loss_func(out, batch[1])
# loss # loss
@ -104,3 +104,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
if getattr(engine.lr_sch[i], "by_epoch", False) and \ if getattr(engine.lr_sch[i], "by_epoch", False) and \
type_name(engine.lr_sch[i]) != "ReduceOnPlateau": type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
engine.lr_sch[i].step() engine.lr_sch[i].step()
def forward(engine, batch):
if not engine.is_rec:
return engine.model(batch[0])
else:
return engine.model(batch[0], batch[1])