fix loss reduce from dict to list (#679)
* fix loss reduce from dict to list * remove notepull/677/head^2
parent
42d2962d90
commit
2e62e2e25e
tools
|
@ -143,8 +143,8 @@ def create_metric(out,
|
|||
out = out[1]
|
||||
softmax_out = F.softmax(out)
|
||||
|
||||
fetchs = OrderedDict()
|
||||
metric_names = set()
|
||||
fetch_list = []
|
||||
metric_names = []
|
||||
if not multilabel:
|
||||
softmax_out = F.softmax(out)
|
||||
|
||||
|
@ -154,12 +154,11 @@ def create_metric(out,
|
|||
k = min(topk, classes_num)
|
||||
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
|
||||
|
||||
metric_names.add("top1")
|
||||
metric_names.add("top{}".format(k))
|
||||
metric_names.append("top1")
|
||||
metric_names.append("top{}".format(k))
|
||||
|
||||
fetchs['top1'] = top1
|
||||
topk_name = "top{}".format(k)
|
||||
fetchs[topk_name] = topk
|
||||
fetch_list.append(top1)
|
||||
fetch_list.append(topk)
|
||||
else:
|
||||
out = F.sigmoid(out)
|
||||
preds = multi_hot_encode(out.numpy())
|
||||
|
@ -169,19 +168,22 @@ def create_metric(out,
|
|||
|
||||
ham_dist_name = "hamming_distance"
|
||||
accuracy_name = "multilabel_accuracy"
|
||||
metric_names.add(ham_dist_name)
|
||||
metric_names.add(accuracy_name)
|
||||
metric_names.append(ham_dist_name)
|
||||
metric_names.append(accuracy_name)
|
||||
|
||||
fetchs[accuracy_name] = accuracy
|
||||
fetchs[ham_dist_name] = ham_dist
|
||||
fetch_list.append(accuracy)
|
||||
fetch_list.append(ham_dist)
|
||||
|
||||
# multi cards' eval
|
||||
if mode != "train" and paddle.distributed.get_world_size() > 1:
|
||||
for metric_name in metric_names:
|
||||
fetchs[metric_name] = paddle.distributed.all_reduce(
|
||||
fetchs[metric_name], op=paddle.distributed.ReduceOp.
|
||||
for idx, fetch in enumerate(fetch_list):
|
||||
fetch_list[idx] = paddle.distributed.all_reduce(
|
||||
fetch, op=paddle.distributed.ReduceOp.
|
||||
SUM) / paddle.distributed.get_world_size()
|
||||
|
||||
fetchs = OrderedDict()
|
||||
for idx, name in enumerate(metric_names):
|
||||
fetchs[name] = fetch_list[idx]
|
||||
return fetchs
|
||||
|
||||
|
||||
|
@ -282,7 +284,8 @@ def create_feeds(batch, use_mix, num_classes, multilabel=False):
|
|||
if not multilabel:
|
||||
label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
|
||||
else:
|
||||
label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes))
|
||||
label = to_tensor(batch[1].numpy().astype('float32').reshape(
|
||||
-1, num_classes))
|
||||
feeds = {"image": image, "label": label}
|
||||
return feeds
|
||||
|
||||
|
@ -336,10 +339,12 @@ def run(dataloader,
|
|||
0, ("top1", AverageMeter(
|
||||
"top1", '.5f', postfix=",")))
|
||||
else:
|
||||
metric_list.insert(0, ("multilabel_accuracy", AverageMeter(
|
||||
"multilabel_accuracy", '.5f', postfix=",")))
|
||||
metric_list.insert(0, ("hamming_distance", AverageMeter(
|
||||
"hamming_distance", '.5f', postfix=",")))
|
||||
metric_list.insert(
|
||||
0, ("multilabel_accuracy", AverageMeter(
|
||||
"multilabel_accuracy", '.5f', postfix=",")))
|
||||
metric_list.insert(
|
||||
0, ("hamming_distance", AverageMeter(
|
||||
"hamming_distance", '.5f', postfix=",")))
|
||||
|
||||
metric_list = OrderedDict(metric_list)
|
||||
|
||||
|
|
Loading…
Reference in New Issue