fix loss reduce from dict to list (#679)

* fix loss reduce from dict to list
* remove note

parent 42d2962d90
commit 2e62e2e25e
@@ -143,8 +143,8 @@ def create_metric(out,
             out = out[1]
     softmax_out = F.softmax(out)
 
-    fetchs = OrderedDict()
-    metric_names = set()
+    fetch_list = []
+    metric_names = []
     if not multilabel:
         softmax_out = F.softmax(out)
 
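Note on the container change above: metric_names moves from a set to a plain list (and the fetchs OrderedDict becomes fetch_list) so that names and values can later be paired by position. A standalone sketch of the ordering guarantee this relies on; the names and values below are invented for illustration:

# Standalone sketch (not from the patch): a set does not guarantee
# insertion order, so pairing metric names with values positionally is
# only safe once both sides live in lists.
names_as_list = []
names_as_list.append("top1")
names_as_list.append("top5")
assert names_as_list == ["top1", "top5"]  # order of appends is preserved

values = [0.91, 0.99]
paired = dict(zip(names_as_list, values))  # positions line up reliably
assert paired == {"top1": 0.91, "top5": 0.99}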
@@ -154,12 +154,11 @@ def create_metric(out,
         k = min(topk, classes_num)
         topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
 
-        metric_names.add("top1")
-        metric_names.add("top{}".format(k))
+        metric_names.append("top1")
+        metric_names.append("top{}".format(k))
 
-        fetchs['top1'] = top1
-        topk_name = "top{}".format(k)
-        fetchs[topk_name] = topk
+        fetch_list.append(top1)
+        fetch_list.append(topk)
     else:
         out = F.sigmoid(out)
         preds = multi_hot_encode(out.numpy())
@@ -169,19 +168,22 @@ def create_metric(out,
 
         ham_dist_name = "hamming_distance"
         accuracy_name = "multilabel_accuracy"
-        metric_names.add(ham_dist_name)
-        metric_names.add(accuracy_name)
+        metric_names.append(ham_dist_name)
+        metric_names.append(accuracy_name)
 
-        fetchs[accuracy_name] = accuracy
-        fetchs[ham_dist_name] = ham_dist
+        fetch_list.append(accuracy)
+        fetch_list.append(ham_dist)
 
     # multi cards' eval
     if mode != "train" and paddle.distributed.get_world_size() > 1:
-        for metric_name in metric_names:
-            fetchs[metric_name] = paddle.distributed.all_reduce(
-                fetchs[metric_name], op=paddle.distributed.ReduceOp.
+        for idx, fetch in enumerate(fetch_list):
+            fetch_list[idx] = paddle.distributed.all_reduce(
+                fetch, op=paddle.distributed.ReduceOp.
                 SUM) / paddle.distributed.get_world_size()
 
+    fetchs = OrderedDict()
+    for idx, name in enumerate(metric_names):
+        fetchs[name] = fetch_list[idx]
     return fetchs
 
 
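This hunk carries the actual fix: the all-reduce now walks a list whose positions match metric_names, and the name-to-value OrderedDict is rebuilt only after every value has been averaged across cards. A minimal self-contained sketch of that pattern; fake_all_reduce and the world_size value are stand-ins for the paddle.distributed calls, not real API, and the metric values are invented:

# List-then-rebuild pattern used in the patch, in plain Python.
from collections import OrderedDict

world_size = 4

def fake_all_reduce(value):
    # Stand-in for paddle.distributed.all_reduce with ReduceOp.SUM:
    # pretend every card contributed the same value.
    return value * world_size

metric_names = ["top1", "top5"]   # plain list: insertion order is kept
fetch_list = [0.91, 0.99]         # values aligned with metric_names

# Average each metric across cards, in place, preserving order.
for idx, fetch in enumerate(fetch_list):
    fetch_list[idx] = fake_all_reduce(fetch) / world_size

# Rebuild the name -> value mapping only after all values are reduced.
fetchs = OrderedDict(zip(metric_names, fetch_list))
print(fetchs)  # OrderedDict([('top1', 0.91), ('top5', 0.99)])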
@@ -282,7 +284,8 @@ def create_feeds(batch, use_mix, num_classes, multilabel=False):
     if not multilabel:
         label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
     else:
-        label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes))
+        label = to_tensor(batch[1].numpy().astype('float32').reshape(
+            -1, num_classes))
     feeds = {"image": image, "label": label}
     return feeds
 
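Only the line wrapping changes here, but the two label layouts are easy to mix up. A small NumPy illustration of the shapes create_feeds produces, with invented sample data:

# Single-label: class indices, int64, shape (batch, 1).
# Multilabel: multi-hot vectors, float32, shape (batch, num_classes).
import numpy as np

num_classes = 5

single = np.array([2, 0, 4]).astype("int64").reshape(-1, 1)

multi = np.array([[0, 1, 1, 0, 0],
                  [1, 0, 0, 0, 1]]).astype("float32").reshape(-1, num_classes)

print(single.shape, multi.shape)  # (3, 1) (2, 5)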
@@ -336,10 +339,12 @@ def run(dataloader,
             0, ("top1", AverageMeter(
                 "top1", '.5f', postfix=",")))
     else:
-        metric_list.insert(0, ("multilabel_accuracy", AverageMeter(
-            "multilabel_accuracy", '.5f', postfix=",")))
-        metric_list.insert(0, ("hamming_distance", AverageMeter(
-            "hamming_distance", '.5f', postfix=",")))
+        metric_list.insert(
+            0, ("multilabel_accuracy", AverageMeter(
+                "multilabel_accuracy", '.5f', postfix=",")))
+        metric_list.insert(
+            0, ("hamming_distance", AverageMeter(
+                "hamming_distance", '.5f', postfix=",")))
 
     metric_list = OrderedDict(metric_list)
 
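The else branch is only re-indented, but the repeated insert(0, ...) calls have a side effect worth noting: the meter inserted last ends up first in the final OrderedDict. A tiny sketch with a stand-in MiniMeter class (AverageMeter itself is not needed for the point; the "loss" entry is invented):

from collections import OrderedDict

class MiniMeter:
    def __init__(self, name):
        self.name = name

metric_list = [("loss", MiniMeter("loss"))]
metric_list.insert(0, ("multilabel_accuracy", MiniMeter("multilabel_accuracy")))
metric_list.insert(0, ("hamming_distance", MiniMeter("hamming_distance")))

metric_list = OrderedDict(metric_list)
print(list(metric_list))  # ['hamming_distance', 'multilabel_accuracy', 'loss']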