fix eval_interval bugs
parent
3b4f5f4dfc
commit
fcc1b857d4
|
@ -35,13 +35,13 @@ from ppcls.data.preprocess import transform
|
|||
|
||||
|
||||
def build_dataloader(config, mode, device, seed=None):
|
||||
assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
|
||||
], "Mode should be Train, Eval, Test, Gallery or Query"
|
||||
assert mode in ['Train', 'Eval', 'Test',
|
||||
], "Mode should be Train, Eval, Test"
|
||||
# build dataset
|
||||
config_dataset = config[mode]['dataset']
|
||||
config_dataset = copy.deepcopy(config_dataset)
|
||||
dataset_name = config_dataset.pop('name')
|
||||
if 'batch_transform_ops' in config_dataset:
|
||||
if 'batch_transform_ops' in config_dataset:h
|
||||
batch_transform = config_dataset.pop('batch_transform_ops')
|
||||
else:
|
||||
batch_transform = None
|
||||
|
|
|
@ -223,7 +223,7 @@ class Trainer(object):
|
|||
# eval model and save model if possible
|
||||
if self.config["Global"][
|
||||
"eval_during_train"] and epoch_id % self.config["Global"][
|
||||
"eval_during_train"] == 0:
|
||||
"eval_interval"] == 0:
|
||||
acc = self.eval(epoch_id)
|
||||
if acc > best_metric["metric"]:
|
||||
best_metric["metric"] = acc
|
||||
|
|
Loading…
Reference in New Issue