mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Merge pull request #820 from cuicheng01/develop_reg
fix eval_interval bugs
This commit is contained in:
commit
4905424f31
@ -35,8 +35,8 @@ from ppcls.data.preprocess import transform
|
|||||||
|
|
||||||
|
|
||||||
def build_dataloader(config, mode, device, seed=None):
|
def build_dataloader(config, mode, device, seed=None):
|
||||||
assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
|
assert mode in ['Train', 'Eval', 'Test',
|
||||||
], "Mode should be Train, Eval, Test, Gallery or Query"
|
], "Mode should be Train, Eval, Test"
|
||||||
# build dataset
|
# build dataset
|
||||||
config_dataset = config[mode]['dataset']
|
config_dataset = config[mode]['dataset']
|
||||||
config_dataset = copy.deepcopy(config_dataset)
|
config_dataset = copy.deepcopy(config_dataset)
|
||||||
|
@ -223,7 +223,7 @@ class Trainer(object):
|
|||||||
# eval model and save model if possible
|
# eval model and save model if possible
|
||||||
if self.config["Global"][
|
if self.config["Global"][
|
||||||
"eval_during_train"] and epoch_id % self.config["Global"][
|
"eval_during_train"] and epoch_id % self.config["Global"][
|
||||||
"eval_during_train"] == 0:
|
"eval_interval"] == 0:
|
||||||
acc = self.eval(epoch_id)
|
acc = self.eval(epoch_id)
|
||||||
if acc > best_metric["metric"]:
|
if acc > best_metric["metric"]:
|
||||||
best_metric["metric"] = acc
|
best_metric["metric"] = acc
|
||||||
|
Loading…
x
Reference in New Issue
Block a user