add default metrics
parent
5dc93ac013
commit
3e4af0448e
|
@ -93,37 +93,17 @@ class Trainer(object):
|
|||
self.train_metric_func = None
|
||||
self.eval_metric_func = None
|
||||
|
||||
def _build_metric_info(self, metric_config, mode="train"):
|
||||
"""
|
||||
_build_metric_info: build metrics according to current mode
|
||||
Return:
|
||||
metric: dict of the metrics info
|
||||
"""
|
||||
metric = None
|
||||
mode = mode.capitalize()
|
||||
if mode in metric_config and metric_config[mode] is not None:
|
||||
metric = build_metrics(metric_config[mode])
|
||||
return metric
|
||||
|
||||
def _build_loss_info(self, loss_config, mode="train"):
|
||||
"""
|
||||
_build_loss_info: build loss according to current mode
|
||||
Return:
|
||||
loss_dict: dict of the loss info
|
||||
"""
|
||||
loss = None
|
||||
mode = mode.capitalize()
|
||||
if mode in loss_config and loss_config[mode] is not None:
|
||||
loss = build_loss(loss_config[mode])
|
||||
return loss
|
||||
|
||||
def train(self):
|
||||
# build train loss and metric info
|
||||
if self.train_loss_func is None:
|
||||
self.train_loss_func = self._build_loss_info(self.config["Loss"])
|
||||
if "Metric" in self.config and self.train_metric_func is None:
|
||||
self.train_metric_func = self._build_metric_info(self.config[
|
||||
"Metric"])
|
||||
self.train_loss_func = build_loss(self.config["Loss"])
|
||||
if self.train_metric_func is None:
|
||||
metric_config = self.config.get("Metric", None)
|
||||
if metric_config is None:
|
||||
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
|
||||
else:
|
||||
metric_config = metric_config["Train"]
|
||||
self.train_metric_func = build_metrics(metric_config)
|
||||
|
||||
if self.train_dataloader is None:
|
||||
self.train_dataloader = build_dataloader(self.config["DataLoader"],
|
||||
|
@ -241,10 +221,26 @@ class Trainer(object):
|
|||
@paddle.no_grad()
|
||||
def eval(self, epoch_id=0):
|
||||
self.model.eval()
|
||||
if self.eval_loss_func is None:
|
||||
loss_info = self.config.get("Loss", None)
|
||||
if loss_info is None:
|
||||
loss_info = [{"CELoss": {"weight": 1.0}}]
|
||||
else:
|
||||
loss_info = loss_info["Eval"]
|
||||
self.eval_loss_func = build_loss(loss_info)
|
||||
if self.eval_mode == "classification":
|
||||
if self.eval_dataloader is None:
|
||||
self.eval_dataloader = build_dataloader(
|
||||
self.config["DataLoader"], "Eval", self.device)
|
||||
|
||||
if self.eval_metric_func is None:
|
||||
metric_config = self.config.get("Metric", None)
|
||||
if metric_config is None:
|
||||
metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
|
||||
else:
|
||||
metric_config = metric_config["Eval"]
|
||||
self.eval_metric_func = build_metrics(metric_config)
|
||||
|
||||
eval_result = self.eval_cls(epoch_id)
|
||||
|
||||
elif self.eval_mode == "retrieval":
|
||||
|
@ -255,13 +251,14 @@ class Trainer(object):
|
|||
if self.query_dataloader is None:
|
||||
self.query_dataloader = build_dataloader(
|
||||
self.config["DataLoader"], "Query", self.device)
|
||||
# build train loss and metric info
|
||||
if self.eval_loss_func is None:
|
||||
self.eval_loss_func = self._build_loss_info(
|
||||
self.config["Loss"], "eval")
|
||||
# build metric info
|
||||
if self.eval_metric_func is None:
|
||||
self.eval_metric_func = self._build_metric_info(
|
||||
self.config["Metric"], "eval")
|
||||
metric_config = self.config.get("Metric", None)
|
||||
if metric_config is None:
|
||||
metric_config = [{"name": "Recallk", "topk": (1, 5)}]
|
||||
else:
|
||||
metric_config = metric_config["Eval"]
|
||||
self.eval_metric_func = build_metrics(metric_config)
|
||||
eval_result = self.eval_retrieval(epoch_id)
|
||||
else:
|
||||
logger.warning("Invalid eval mode: {}".format(self.eval_mode))
|
||||
|
|
|
@ -16,7 +16,7 @@ from paddle import nn
|
|||
import copy
|
||||
from collections import OrderedDict
|
||||
|
||||
from .metrics import Topk, mAP, mINP, Recallk
|
||||
from .metrics import TopkAcc, mAP, mINP, Recallk
|
||||
|
||||
|
||||
class CombinedMetrics(nn.Layer):
|
||||
|
|
|
@ -18,7 +18,7 @@ import paddle.nn as nn
|
|||
|
||||
|
||||
# TODO: fix the format
|
||||
class Topk(nn.Layer):
|
||||
class TopkAcc(nn.Layer):
|
||||
def __init__(self, topk=(1, 5)):
|
||||
super().__init__()
|
||||
assert isinstance(topk, (int, list, tuple))
|
||||
|
|
Loading…
Reference in New Issue