parent
915dde176a
commit
339be96ef5
|
@ -110,7 +110,7 @@ def build(config, mode, use_dali=False, seed=None):
|
||||||
config_dataset = copy.deepcopy(config_dataset)
|
config_dataset = copy.deepcopy(config_dataset)
|
||||||
dataset_name = config_dataset.pop('name')
|
dataset_name = config_dataset.pop('name')
|
||||||
if 'batch_transform_ops' in config_dataset:
|
if 'batch_transform_ops' in config_dataset:
|
||||||
batch_transform = config_dataset['batch_transform_ops']
|
batch_transform = config_dataset.pop('batch_transform_ops')
|
||||||
else:
|
else:
|
||||||
batch_transform = None
|
batch_transform = None
|
||||||
|
|
||||||
|
@ -254,11 +254,10 @@ def build_dataloader(config, mode):
|
||||||
|
|
||||||
if mode == "eval" or (mode == "train" and
|
if mode == "eval" or (mode == "train" and
|
||||||
config["Global"]["eval_during_train"]):
|
config["Global"]["eval_during_train"]):
|
||||||
task = config["Global"].get("task", "classification")
|
if config["Global"]["eval_mode"] in ["classification", "adaface"]:
|
||||||
if task in ["classification", "adaface"]:
|
|
||||||
dataloader_dict["Eval"] = build(
|
dataloader_dict["Eval"] = build(
|
||||||
config["DataLoader"], "Eval", use_dali, seed=None)
|
config["DataLoader"], "Eval", use_dali, seed=None)
|
||||||
elif task == "retrieval":
|
elif config["Global"]["eval_mode"] == "retrieval":
|
||||||
if len(config["DataLoader"]["Eval"].keys()) == 1:
|
if len(config["DataLoader"]["Eval"].keys()) == 1:
|
||||||
key = list(config["DataLoader"]["Eval"].keys())[0]
|
key = list(config["DataLoader"]["Eval"].keys())[0]
|
||||||
dataloader_dict["GalleryQuery"] = build(
|
dataloader_dict["GalleryQuery"] = build(
|
||||||
|
|
|
@ -42,7 +42,7 @@ from ppcls.data.preprocess.ops.dali_operators import RandomRot90
|
||||||
from ppcls.data.preprocess.ops.dali_operators import RandomRotation
|
from ppcls.data.preprocess.ops.dali_operators import RandomRotation
|
||||||
from ppcls.data.preprocess.ops.dali_operators import ResizeImage
|
from ppcls.data.preprocess.ops.dali_operators import ResizeImage
|
||||||
from ppcls.data.preprocess.ops.dali_operators import ToCHWImage
|
from ppcls.data.preprocess.ops.dali_operators import ToCHWImage
|
||||||
from ppcls.utils import type_name
|
from ppcls.engine.train.utils import type_name
|
||||||
from ppcls.utils import logger
|
from ppcls.utils import logger
|
||||||
|
|
||||||
INTERP_MAP = {
|
INTERP_MAP = {
|
||||||
|
|
|
@ -14,7 +14,8 @@
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
from ppcls.data import build_dataloader
|
from ppcls.data import build_dataloader
|
||||||
from ppcls.utils import logger, type_name
|
from ppcls.engine.train.utils import type_name
|
||||||
|
from ppcls.utils import logger
|
||||||
from .regular_train_epoch import regular_train_epoch
|
from .regular_train_epoch import regular_train_epoch
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from ppcls.utils import logger, type_name
|
from ppcls.utils import logger
|
||||||
from ppcls.utils.misc import AverageMeter
|
from ppcls.utils.misc import AverageMeter
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,3 +75,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
|
||||||
value=trainer.output_info[key].avg,
|
value=trainer.output_info[key].avg,
|
||||||
step=trainer.global_step,
|
step=trainer.global_step,
|
||||||
writer=trainer.vdl_writer)
|
writer=trainer.vdl_writer)
|
||||||
|
|
||||||
|
|
||||||
|
def type_name(object: object) -> str:
|
||||||
|
"""get class name of an object"""
|
||||||
|
return object.__class__.__name__
|
||||||
|
|
|
@ -94,7 +94,6 @@ def build_loss(config, mode="train"):
|
||||||
if unlabel_loss_info:
|
if unlabel_loss_info:
|
||||||
unlabel_train_loss_func = CombinedLoss(
|
unlabel_train_loss_func = CombinedLoss(
|
||||||
copy.deepcopy(unlabel_loss_info))
|
copy.deepcopy(unlabel_loss_info))
|
||||||
return train_loss_func, unlabel_train_loss_func
|
|
||||||
if mode == "eval" or (mode == "train" and
|
if mode == "eval" or (mode == "train" and
|
||||||
config["Global"]["eval_during_train"]):
|
config["Global"]["eval_during_train"]):
|
||||||
loss_config = config.get("Loss", None)
|
loss_config = config.get("Loss", None)
|
||||||
|
@ -102,4 +101,5 @@ def build_loss(config, mode="train"):
|
||||||
loss_config = loss_config.get("Eval")
|
loss_config = loss_config.get("Eval")
|
||||||
if loss_config is not None:
|
if loss_config is not None:
|
||||||
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
|
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
|
||||||
return eval_loss_func
|
|
||||||
|
return train_loss_func, unlabel_train_loss_func, eval_loss_func
|
||||||
|
|
|
@ -65,19 +65,22 @@ class CombinedMetrics(AvgMetrics):
|
||||||
metric.reset()
|
metric.reset()
|
||||||
|
|
||||||
|
|
||||||
def build_metrics(config, mode):
|
def build_metrics(engine):
|
||||||
|
config, mode = engine.config, engine.mode
|
||||||
if mode == 'train' and "Metric" in config and "Train" in config[
|
if mode == 'train' and "Metric" in config and "Train" in config[
|
||||||
"Metric"] and config["Metric"]["Train"]:
|
"Metric"] and config["Metric"]["Train"]:
|
||||||
metric_config = config["Metric"]["Train"]
|
metric_config = config["Metric"]["Train"]
|
||||||
if config["DataLoader"]["Train"]["dataset"].get("batch_transform_ops",
|
if hasattr(engine.dataloader_dict["Train"],
|
||||||
None):
|
"collate_fn") and engine.dataloader_dict[
|
||||||
|
"Train"].collate_fn is not None:
|
||||||
for m_idx, m in enumerate(metric_config):
|
for m_idx, m in enumerate(metric_config):
|
||||||
if "TopkAcc" in m:
|
if "TopkAcc" in m:
|
||||||
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
|
msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
metric_config.pop(m_idx)
|
metric_config.pop(m_idx)
|
||||||
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
|
train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
|
||||||
return train_metric_func
|
else:
|
||||||
|
train_metric_func = None
|
||||||
|
|
||||||
if mode == "eval" or (mode == "train" and
|
if mode == "eval" or (mode == "train" and
|
||||||
config["Global"]["eval_during_train"]):
|
config["Global"]["eval_during_train"]):
|
||||||
|
@ -94,4 +97,7 @@ def build_metrics(config, mode):
|
||||||
else:
|
else:
|
||||||
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
|
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
|
||||||
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
|
eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
|
||||||
return eval_metric_func
|
else:
|
||||||
|
eval_metric_func = None
|
||||||
|
|
||||||
|
return train_metric_func, eval_metric_func
|
||||||
|
|
|
@ -21,7 +21,8 @@ import copy
|
||||||
import paddle
|
import paddle
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from ..utils import logger, type_name
|
from ppcls.engine.train.utils import type_name
|
||||||
|
from ppcls.utils import logger
|
||||||
|
|
||||||
from . import optimizer
|
from . import optimizer
|
||||||
|
|
||||||
|
@ -44,10 +45,14 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
|
||||||
|
|
||||||
|
|
||||||
# model_list is None in static graph
|
# model_list is None in static graph
|
||||||
def build_optimizer(config, max_iter, model_list, update_freq):
|
def build_optimizer(engine):
|
||||||
|
if engine.mode != "train":
|
||||||
|
return None, None
|
||||||
|
config, max_iter, model_list = engine.config, engine.dataloader_dict[
|
||||||
|
"Train"].max_iter, [engine.model, engine.train_loss_func]
|
||||||
optim_config = copy.deepcopy(config["Optimizer"])
|
optim_config = copy.deepcopy(config["Optimizer"])
|
||||||
epochs = config["Global"]["epochs"]
|
epochs = config["Global"]["epochs"]
|
||||||
update_freq = config["Global"].get("update_freq", 1)
|
update_freq = engine.update_freq
|
||||||
step_each_epoch = max_iter // update_freq
|
step_each_epoch = max_iter // update_freq
|
||||||
if isinstance(optim_config, dict):
|
if isinstance(optim_config, dict):
|
||||||
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
|
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
|
||||||
|
|
|
@ -26,8 +26,3 @@ from .metrics import multi_hot_encode
|
||||||
from .metrics import precision_recall_fscore
|
from .metrics import precision_recall_fscore
|
||||||
from .misc import AverageMeter
|
from .misc import AverageMeter
|
||||||
from .save_load import init_model
|
from .save_load import init_model
|
||||||
|
|
||||||
|
|
||||||
def type_name(object: object) -> str:
|
|
||||||
"""get class name of an object"""
|
|
||||||
return object.__class__.__name__
|
|
||||||
|
|
Loading…
Reference in New Issue