mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix train without eval bug
This commit is contained in:
parent
bcc88db6e2
commit
583c88b5dc
@ -116,7 +116,8 @@ class Engine(object):
|
|||||||
if self.mode == 'train':
|
if self.mode == 'train':
|
||||||
self.train_dataloader = build_dataloader(
|
self.train_dataloader = build_dataloader(
|
||||||
self.config["DataLoader"], "Train", self.device, self.use_dali)
|
self.config["DataLoader"], "Train", self.device, self.use_dali)
|
||||||
if self.mode in ["train", "eval"]:
|
if self.mode == "eval" or (self.mode == "train" and
|
||||||
|
self.config["Global"]["eval_during_train"]):
|
||||||
if self.eval_mode == "classification":
|
if self.eval_mode == "classification":
|
||||||
self.eval_dataloader = build_dataloader(
|
self.eval_dataloader = build_dataloader(
|
||||||
self.config["DataLoader"], "Eval", self.device,
|
self.config["DataLoader"], "Eval", self.device,
|
||||||
@ -140,7 +141,8 @@ class Engine(object):
|
|||||||
if self.mode == "train":
|
if self.mode == "train":
|
||||||
loss_info = self.config["Loss"]["Train"]
|
loss_info = self.config["Loss"]["Train"]
|
||||||
self.train_loss_func = build_loss(loss_info)
|
self.train_loss_func = build_loss(loss_info)
|
||||||
if self.mode in ["train", "eval"]:
|
if self.mode == "eval" or (self.mode == "train" and
|
||||||
|
self.config["Global"]["eval_during_train"]):
|
||||||
loss_config = self.config.get("Loss", None)
|
loss_config = self.config.get("Loss", None)
|
||||||
if loss_config is not None:
|
if loss_config is not None:
|
||||||
loss_config = loss_config.get("Eval")
|
loss_config = loss_config.get("Eval")
|
||||||
@ -163,7 +165,8 @@ class Engine(object):
|
|||||||
else:
|
else:
|
||||||
self.train_metric_func = None
|
self.train_metric_func = None
|
||||||
|
|
||||||
if self.mode in ["train", "eval"]:
|
if self.mode == "eval" or (self.mode == "train" and
|
||||||
|
self.config["Global"]["eval_during_train"]):
|
||||||
metric_config = self.config.get("Metric")
|
metric_config = self.config.get("Metric")
|
||||||
if self.eval_mode == "classification":
|
if self.eval_mode == "classification":
|
||||||
if metric_config is not None:
|
if metric_config is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user