Revert "refactor: add ClassModel to unify model forward interface"
This reverts commit 75a20ba557.

Branch: pull/2701/head
Parent: e7e4f68b5c
Commit: 6aabb94d8c
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import copy
+import importlib
 
 import paddle.nn as nn
 from paddle.jit import to_static
 from paddle.static import InputSpec
 
-from . import backbone as backbone_zoo
+from . import backbone, gears
 from .backbone import *
 from .gears import build_gear
 from .utils import *
 from .backbone.base.theseus_layer import TheseusLayer
@@ -35,28 +35,20 @@ def build_model(config, mode="train"):
     arch_config = copy.deepcopy(config["Arch"])
     model_type = arch_config.pop("name")
     use_sync_bn = arch_config.pop("use_sync_bn", False)
 
-    if hasattr(backbone_zoo, model_type):
-        model = ClassModel(model_type, **arch_config)
-    else:
-        model = getattr(sys.modules[__name__], model_type)("ClassModel",
-                                                           **arch_config)
-
+    mod = importlib.import_module(__name__)
+    arch = getattr(mod, model_type)(**arch_config)
     if use_sync_bn:
         if config["Global"]["device"] == "gpu":
-            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
         else:
             msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
             logger.warning(msg)
 
-    if isinstance(model, TheseusLayer):
-        prune_model(config, model)
-        quantize_model(config, model, mode)
+    if isinstance(arch, TheseusLayer):
+        prune_model(config, arch)
+        quantize_model(config, arch, mode)
 
-    # set @to_static for benchmark, skip this by default.
-    model = apply_to_static(config, model)
-
-    return model
+    return arch
 
 
 def apply_to_static(config, model):
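With the revert, build_model again resolves the architecture class purely by its configured name from this module's namespace, instead of routing backbone names through the removed ClassModel wrapper. A minimal sketch of that lookup, with a hypothetical Arch config standing in for the YAML-driven one:

    import importlib

    # Hypothetical Arch section: "name" picks a class defined in (or star-imported
    # into) ppcls.arch; the remaining keys become constructor kwargs.
    arch_config = {"name": "RecModel", "Backbone": {"name": "ResNet50"}}

    model_type = arch_config.pop("name")
    mod = importlib.import_module("ppcls.arch")     # namespace holding RecModel, backbones, ...
    arch = getattr(mod, model_type)(**arch_config)  # equivalent to RecModel(Backbone={"name": "ResNet50"})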
@@ -73,29 +65,12 @@ def apply_to_static(config, model):
     return model
 
 
-# TODO(gaotingquan): export model
-class ClassModel(TheseusLayer):
-    def __init__(self, model_type, **config):
-        super().__init__()
-        if model_type == "ClassModel":
-            backbone_config = config["Backbone"]
-            backbone_name = backbone_config.pop("name")
-        else:
-            backbone_name = model_type
-            backbone_config = config
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
-
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
-        return self.backbone(x)
-
-
 class RecModel(TheseusLayer):
     def __init__(self, **config):
         super().__init__()
         backbone_config = config["Backbone"]
         backbone_name = backbone_config.pop("name")
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
+        self.backbone = eval(backbone_name)(**backbone_config)
         self.head_feature_from = config.get('head_feature_from', 'neck')
 
         if "BackboneStopLayer" in config:
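The restored RecModel resolves its backbone with eval(backbone_name), which relies on every backbone class being star-imported into ppcls.arch's namespace via `from .backbone import *`; the reverted code looked the class up on the backbone_zoo alias instead. A toy illustration of the two lookup styles, using a made-up backbone class rather than a real one:

    class ToyBackbone:                      # stand-in for a real backbone class
        def __init__(self, class_num=1000):
            self.class_num = class_num

    backbone_name = "ToyBackbone"
    backbone_config = {"class_num": 10}

    # Style restored by this revert: evaluate the class name in the enclosing namespace.
    backbone_a = eval(backbone_name)(**backbone_config)

    # Style being removed: attribute-style lookup on an imported module (faked here with a dict).
    backbone_zoo = {"ToyBackbone": ToyBackbone}
    backbone_b = backbone_zoo[backbone_name](**backbone_config)

    assert type(backbone_a) is type(backbone_b) and backbone_a.class_num == 10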
@@ -112,8 +87,8 @@ class RecModel(TheseusLayer):
         else:
             self.head = None
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
+
         out = dict()
         x = self.backbone(x)
         out["backbone"] = x
@@ -165,8 +140,7 @@ class DistillationModel(nn.Layer):
                 load_dygraph_pretrain(
                     self.model_name_list[idx], path=pretrained)
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         for idx, model_name in enumerate(self.model_name_list):
             if label is None:
@@ -184,8 +158,7 @@ class AttentionModel(DistillationModel):
                  **kargs):
         super().__init__(models, pretrained_list, freeze_params_list, **kargs)
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         out = x
         for idx, model_name in enumerate(self.model_name_list):
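All three wrappers (RecModel, DistillationModel, AttentionModel) return to the explicit forward(self, x, label=None) signature, so callers pass the image tensor and an optional label instead of handing over the whole batch for the model to unpack. A rough sketch of the restored calling convention with a dummy layer (the class below is illustrative, not part of the codebase):

    import paddle
    import paddle.nn as nn

    class ToyRecModel(nn.Layer):
        """Minimal stand-in mirroring the restored interface."""
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(4, 2)

        def forward(self, x, label=None):       # restored signature
            out = {"backbone": self.backbone(x)}
            if label is not None:               # label is only used when provided
                out["label"] = label
            return out

    x = paddle.rand([8, 4])
    label = paddle.randint(0, 2, [8])
    out = ToyRecModel()(x, label)               # caller passes x and label separately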
@@ -28,6 +28,7 @@ from ppcls.utils.logger import init_logger
 from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
+from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
@@ -56,10 +57,18 @@ class Engine(object):
 
         # init logger
         init_logger(self.config, mode=mode)
+        print_config(config)
 
         # for visualdl
         self.vdl_writer = self._init_vdl()
 
+        # is_rec
+        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
+                                                                    False):
+            self.is_rec = True
+        else:
+            self.is_rec = False
+
         # init train_func and eval_func
         self.train_mode = self.config["Global"].get("train_mode", None)
         if self.train_mode is None:
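The Engine now decides up front whether it is driving a rec-style model: the flag is set when the Arch config defines a Head or explicitly sets is_rec. A quick sketch of that predicate applied to two hypothetical configs:

    def is_rec_arch(arch_config):
        # Mirrors the restored Engine check: a Head section or an explicit
        # is_rec flag marks the model as a rec-style model.
        return "Head" in arch_config or arch_config.get("is_rec", False)

    rec_cfg = {"name": "RecModel", "Backbone": {"name": "ResNet50"}, "Head": {"name": "ArcMargin"}}
    cls_cfg = {"name": "ResNet50", "class_num": 1000}

    print(is_rec_arch(rec_cfg))  # True
    print(is_rec_arch(cls_cfg))  # False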
@@ -99,6 +108,8 @@ class Engine(object):
 
         # build model
         self.model = build_model(self.config, self.mode)
+        # set @to_static for benchmark, skip this by default.
+        apply_to_static(self.config, self.model)
 
         # load_pretrain
         self._init_pretrained()
@@ -114,8 +125,6 @@
         # for distributed
         self._init_dist()
 
-        print_config(config)
-
     def train(self):
         assert self.mode == "train"
         print_batch_step = self.config['Global']['print_batch_step']
@@ -55,10 +55,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
                         "flatten_contiguous_range", "greater_than"
                     },
                     level=amp_level):
-                out = engine.model(batch)
+                out = forward(engine, batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
         else:
-            out = engine.model(batch)
+            out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])
 
         # loss
@@ -104,3 +104,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
         if getattr(engine.lr_sch[i], "by_epoch", False) and \
                 type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
             engine.lr_sch[i].step()
+
+
+def forward(engine, batch):
+    if not engine.is_rec:
+        return engine.model(batch[0])
+    else:
+        return engine.model(batch[0], batch[1])
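The module-level forward helper restores the dispatch that ClassModel had absorbed: plain classification models receive only the images, while rec-style models also receive the labels. A rough usage sketch with dummy engines (the stand-in objects below are illustrative, not part of the codebase):

    from types import SimpleNamespace

    def forward(engine, batch):
        # Same dispatch as the helper added back above.
        if not engine.is_rec:
            return engine.model(batch[0])
        return engine.model(batch[0], batch[1])

    # Dummy stand-ins: a classification "model" ignores labels, a rec "model" may use them.
    cls_engine = SimpleNamespace(is_rec=False, model=lambda x: {"logits": x})
    rec_engine = SimpleNamespace(is_rec=True, model=lambda x, label=None: {"features": x, "label": label})

    batch = (["img0", "img1"], [3, 7])       # (images, labels) placeholder
    print(forward(cls_engine, batch))        # -> model(images)
    print(forward(rec_engine, batch))        # -> model(images, labels)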