parent
915dde176a
commit
339be96ef5
|
@ -110,7 +110,7 @@ def build(config, mode, use_dali=False, seed=None):
|
|||
config_dataset = copy.deepcopy(config_dataset)
|
||||
dataset_name = config_dataset.pop('name')
|
||||
if 'batch_transform_ops' in config_dataset:
|
||||
batch_transform = config_dataset['batch_transform_ops']
|
||||
batch_transform = config_dataset.pop('batch_transform_ops')
|
||||
else:
|
||||
batch_transform = None
|
||||
|
||||
|
@ -254,11 +254,10 @@ def build_dataloader(config, mode):
|
|||
|
||||
if mode == "eval" or (mode == "train" and
|
||||
config["Global"]["eval_during_train"]):
|
||||
task = config["Global"].get("task", "classification")
|
||||
if task in ["classification", "adaface"]:
|
||||
if config["Global"]["eval_mode"] in ["classification", "adaface"]:
|
||||
dataloader_dict["Eval"] = build(
|
||||
config["DataLoader"], "Eval", use_dali, seed=None)
|
||||
elif task == "retrieval":
|
||||
elif config["Global"]["eval_mode"] == "retrieval":
|
||||
if len(config["DataLoader"]["Eval"].keys()) == 1:
|
||||
key = list(config["DataLoader"]["Eval"].keys())[0]
|
||||
dataloader_dict["GalleryQuery"] = build(
|
||||
|
|
|
@ -42,7 +42,7 @@ from ppcls.data.preprocess.ops.dali_operators import RandomRot90
|
|||
from ppcls.data.preprocess.ops.dali_operators import RandomRotation
|
||||
from ppcls.data.preprocess.ops.dali_operators import ResizeImage
|
||||
from ppcls.data.preprocess.ops.dali_operators import ToCHWImage
|
||||
from ppcls.utils import type_name
|
||||
from ppcls.engine.train.utils import type_name
|
||||
from ppcls.utils import logger
|
||||
|
||||
INTERP_MAP = {
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from ppcls.data import build_dataloader
|
||||
from ppcls.utils import logger, type_name
|
||||
from ppcls.engine.train.utils import type_name
|
||||
from ppcls.utils import logger
|
||||
from .regular_train_epoch import regular_train_epoch
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import datetime
|
||||
from ppcls.utils import logger, type_name
|
||||
from ppcls.utils import logger
|
||||
from ppcls.utils.misc import AverageMeter
|
||||
|
||||
|
||||
|
@ -75,3 +75,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
|
|||
value=trainer.output_info[key].avg,
|
||||
step=trainer.global_step,
|
||||
writer=trainer.vdl_writer)
|
||||
|
||||
|
||||
def type_name(object: object) -> str:
|
||||
"""get class name of an object"""
|
||||
return object.__class__.__name__
|
||||
|
|
|
@ -94,7 +94,6 @@ def build_loss(config, mode="train"):
|
|||
if unlabel_loss_info:
|
||||
unlabel_train_loss_func = CombinedLoss(
|
||||
copy.deepcopy(unlabel_loss_info))
|
||||
return train_loss_func, unlabel_train_loss_func
|
||||
if mode == "eval" or (mode == "train" and
|
||||
config["Global"]["eval_during_train"]):
|
||||
loss_config = config.get("Loss", None)
|
||||
|
@ -102,4 +101,5 @@ def build_loss(config, mode="train"):
|
|||
loss_config = loss_config.get("Eval")
|
||||
if loss_config is not None:
|
||||
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
|
||||
return eval_loss_func
|
||||
|
||||
return train_loss_func, unlabel_train_loss_func, eval_loss_func
|
||||
|
|
|
@ -65,19 +65,22 @@ class CombinedMetrics(AvgMetrics):
|
|||
metric.reset()
|
||||
|
||||
|
||||
def build_metrics(config, mode):
|
||||
def build_metrics(engine):
|
||||
config, mode = engine.config, engine.mode
|
||||
if mode == 'train' and "Metric" in config and "Train" in config[
|
||||
"Metric"] and config["Metric"]["Train"]:
|
||||
metric_config = config["Metric"]["Train"]
|
||||
if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
|
||||
None):
|
||||
if hasattr(engine.dataloader_dict["Train"],
|
||||
"collate_fn") and engine.dataloader_dict[
|
||||
"Train"].collate_fn is not None:
|
||||
for m_idx, m in enumerate(metric_config):
|
||||
if "TopkAcc" in m:
|
||||
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
|
||||
logger.warning(msg)
|
||||
metric_config.pop(m_idx)
|
||||
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
|
||||
return train_metric_func
|
||||
else:
|
||||
train_metric_func = None
|
||||
|
||||
if mode == "eval" or (mode == "train" and
|
||||
config["Global"]["eval_during_train"]):
|
||||
|
@ -94,4 +97,7 @@ def build_metrics(config, mode):
|
|||
else:
|
||||
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
|
||||
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
|
||||
return eval_metric_func
|
||||
else:
|
||||
eval_metric_func = None
|
||||
|
||||
return train_metric_func, eval_metric_func
|
||||
|
|
|
@ -21,7 +21,8 @@ import copy
|
|||
import paddle
|
||||
from typing import Dict, List
|
||||
|
||||
from ..utils import logger, type_name
|
||||
from ppcls.engine.train.utils import type_name
|
||||
from ppcls.utils import logger
|
||||
|
||||
from . import optimizer
|
||||
|
||||
|
@ -44,10 +45,14 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
|
|||
|
||||
|
||||
# model_list is None in static graph
|
||||
def build_optimizer(config, max_iter, model_list, update_freq):
|
||||
def build_optimizer(engine):
|
||||
if engine.mode != "train":
|
||||
return None, None
|
||||
config, max_iter, model_list = engine.config, engine.dataloader_dict[
|
||||
"Train"].max_iter, [engine.model, engine.train_loss_func]
|
||||
optim_config = copy.deepcopy(config["Optimizer"])
|
||||
epochs = config["Global"]["epochs"]
|
||||
update_freq = config["Global"].get("update_freq", 1)
|
||||
update_freq = engine.update_freq
|
||||
step_each_epoch = max_iter // update_freq
|
||||
if isinstance(optim_config, dict):
|
||||
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
|
||||
|
|
|
@ -26,8 +26,3 @@ from .metrics import multi_hot_encode
|
|||
from .metrics import precision_recall_fscore
|
||||
from .misc import AverageMeter
|
||||
from .save_load import init_model
|
||||
|
||||
|
||||
def type_name(object: object) -> str:
|
||||
"""get class name of an object"""
|
||||
return object.__class__.__name__
|
||||
|
|
Loading…
Reference in New Issue