Mirror of https://github.com/PaddlePaddle/PaddleClas.git (synced 2025-06-03 21:55:06 +08:00)
Revert "refactor: add ClassModel to unify model forward interface"
This reverts commit 75a20ba5574340fa5742eba8e41aebe4de6c5eb8.
parent e7e4f68b5c
commit 6aabb94d8c
@@ -12,14 +12,14 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.
 
-import sys
 import copy
+import importlib
 import paddle.nn as nn
 from paddle.jit import to_static
 from paddle.static import InputSpec
 
-from . import backbone as backbone_zoo
+from . import backbone, gears
+from .backbone import *
 from .gears import build_gear
 from .utils import *
 from .backbone.base.theseus_layer import TheseusLayer
@@ -35,28 +35,20 @@ def build_model(config, mode="train"):
     arch_config = copy.deepcopy(config["Arch"])
     model_type = arch_config.pop("name")
     use_sync_bn = arch_config.pop("use_sync_bn", False)
-    if hasattr(backbone_zoo, model_type):
-        model = ClassModel(model_type, **arch_config)
-    else:
-        model = getattr(sys.modules[__name__], model_type)("ClassModel",
-                                                           **arch_config)
+    mod = importlib.import_module(__name__)
+    arch = getattr(mod, model_type)(**arch_config)
 
     if use_sync_bn:
         if config["Global"]["device"] == "gpu":
-            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
         else:
             msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
             logger.warning(msg)
 
-    if isinstance(model, TheseusLayer):
-        prune_model(config, model)
-        quantize_model(config, model, mode)
+    if isinstance(arch, TheseusLayer):
+        prune_model(config, arch)
+        quantize_model(config, arch, mode)
 
-    # set @to_static for benchmark, skip this by default.
-    model = apply_to_static(config, model)
-
-    return model
+    return arch
 
 
 def apply_to_static(config, model):
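After this hunk, build_model no longer wraps backbones in a ClassModel: it looks the architecture up by name in the module namespace and returns that instance directly. A minimal usage sketch, assuming a hypothetical config with a ResNet50 backbone (any registered architecture name works the same way). Note that after the revert the @to_static wrapping is no longer done inside build_model; the Engine applies it separately (see the engine hunks below).

from ppcls.arch import build_model

# Hypothetical config: "Arch.name" selects the architecture, the remaining Arch
# keys go to its constructor; "use_sync_bn" only takes effect on GPU devices.
config = {
    "Arch": {"name": "ResNet50", "class_num": 1000, "use_sync_bn": False},
    "Global": {"device": "gpu"},
}
model = build_model(config, mode="train")  # returns the architecture instance itself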
@@ -73,29 +65,12 @@ def apply_to_static(config, model):
     return model
 
 
-# TODO(gaotingquan): export model
-class ClassModel(TheseusLayer):
-    def __init__(self, model_type, **config):
-        super().__init__()
-        if model_type == "ClassModel":
-            backbone_config = config["Backbone"]
-            backbone_name = backbone_config.pop("name")
-        else:
-            backbone_name = model_type
-            backbone_config = config
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
-
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
-        return self.backbone(x)
-
-
 class RecModel(TheseusLayer):
     def __init__(self, **config):
         super().__init__()
         backbone_config = config["Backbone"]
         backbone_name = backbone_config.pop("name")
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
+        self.backbone = eval(backbone_name)(**backbone_config)
         self.head_feature_from = config.get('head_feature_from', 'neck')
 
         if "BackboneStopLayer" in config:
@@ -112,8 +87,8 @@ class RecModel(TheseusLayer):
         else:
             self.head = None
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         out = dict()
         x = self.backbone(x)
         out["backbone"] = x
@@ -165,8 +140,7 @@ class DistillationModel(nn.Layer):
                 load_dygraph_pretrain(
                     self.model_name_list[idx], path=pretrained)
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         for idx, model_name in enumerate(self.model_name_list):
             if label is None:
@@ -184,8 +158,7 @@ class AttentionModel(DistillationModel):
                  **kargs):
         super().__init__(models, pretrained_list, freeze_params_list, **kargs)
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         out = x
         for idx, model_name in enumerate(self.model_name_list):
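The three forward reverts above undo the same interface change: instead of receiving the whole dataloader batch and unpacking it internally, RecModel, DistillationModel and AttentionModel again take the image tensor and an optional label as separate arguments. A rough caller-side sketch, assuming hypothetical tensor shapes and a RecModel configured with only a backbone:

import paddle
from ppcls.arch import RecModel

# Hypothetical RecModel: a ResNet50 backbone only, no Neck or Head section.
model = RecModel(Backbone={"name": "ResNet50", "class_num": 512})

images = paddle.rand([8, 3, 224, 224])          # hypothetical image batch
labels = paddle.randint(0, 1000, shape=[8, 1])  # hypothetical labels

# Before the revert the model unpacked the batch itself: out = model([images, labels])
# After the revert the label is an explicit, optional argument:
out = model(images, labels)
print(out["backbone"].shape)  # e.g. [8, 512]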
@@ -28,6 +28,7 @@ from ppcls.utils.logger import init_logger
 from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
+from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
@@ -56,10 +57,18 @@ class Engine(object):
 
         # init logger
         init_logger(self.config, mode=mode)
+        print_config(config)
 
         # for visualdl
         self.vdl_writer = self._init_vdl()
 
+        # is_rec
+        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
+                                                                    False):
+            self.is_rec = True
+        else:
+            self.is_rec = False
+
         # init train_func and eval_func
         self.train_mode = self.config["Global"].get("train_mode", None)
         if self.train_mode is None:
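The restored is_rec flag is derived purely from the Arch section of the config: any architecture that declares a Head (or sets is_rec explicitly) is treated as a recognition model and will later receive labels in its forward call. A small sketch of two hypothetical Arch sections and the flag they produce:

# Hypothetical recognition config: "Head" present -> is_rec = True
arch_rec = {
    "name": "RecModel",
    "Backbone": {"name": "ResNet50"},
    "Head": {"name": "ArcMargin", "embedding_size": 512, "class_num": 1000},
}

# Hypothetical classification config: no "Head", no "is_rec" -> is_rec = False
arch_cls = {"name": "ResNet50", "class_num": 1000}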
@@ -99,6 +108,8 @@ class Engine(object):
 
         # build model
         self.model = build_model(self.config, self.mode)
+        # set @to_static for benchmark, skip this by default.
+        apply_to_static(self.config, self.model)
 
         # load_pretrain
         self._init_pretrained()
@@ -114,8 +125,6 @@ class Engine(object):
         # for distributed
         self._init_dist()
 
-        print_config(config)
-
     def train(self):
         assert self.mode == "train"
         print_batch_step = self.config['Global']['print_batch_step']
@@ -55,10 +55,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
                            "flatten_contiguous_range", "greater_than"
                        },
                        level=amp_level):
-                    out = engine.model(batch)
+                    out = forward(engine, batch)
                     loss_dict = engine.train_loss_func(out, batch[1])
             else:
-                out = engine.model(batch)
+                out = forward(engine, batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
 
             # loss
@@ -104,3 +104,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
         if getattr(engine.lr_sch[i], "by_epoch", False) and \
                 type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
             engine.lr_sch[i].step()
+
+
+def forward(engine, batch):
+    if not engine.is_rec:
+        return engine.model(batch[0])
+    else:
+        return engine.model(batch[0], batch[1])
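This restored module-level helper is where the Engine's is_rec flag pays off: classification models receive only the image tensor, recognition models also receive the labels, and the training loss always gets the labels explicitly. A minimal sketch of the dispatch, assuming a hypothetical batch layout (index 0 images, index 1 labels):

import paddle

# Hypothetical dataloader batch.
batch = [paddle.rand([8, 3, 224, 224]), paddle.randint(0, 1000, shape=[8, 1])]

# Equivalent of the restored helper; engine is assumed to expose .model and .is_rec.
def forward(engine, batch):
    if not engine.is_rec:
        return engine.model(batch[0])            # classification: image only
    else:
        return engine.model(batch[0], batch[1])  # recognition: image + label

# In train_epoch:  out = forward(engine, batch)
#                  loss_dict = engine.train_loss_func(out, batch[1])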