Update GoogLeNetLoss

pull/944/head
cuicheng01 2021-06-22 01:58:03 +00:00
parent 4e154aed4f
commit 707e01aed5
3 changed files with 16 additions and 4 deletions

View File

@ -122,8 +122,8 @@ Infer:
Metric:
Train:
- TopkAcc:
- GoogLeNetTopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
- GoogLeNetTopkAcc:
topk: [1, 5]

View File

@ -18,6 +18,7 @@ from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import DistillationTopkAcc
from .metrics import GoogLeNetTopkAcc
class CombinedMetrics(nn.Layer):
def __init__(self, config_list):

View File

@ -25,8 +25,6 @@ class TopkAcc(nn.Layer):
self.topk = topk
def forward(self, x, label):
if isinstance(x, list):
x = x[0]
if isinstance(x, dict):
x = x["logits"]
@ -122,3 +120,16 @@ class DistillationTopkAcc(TopkAcc):
if self.feature_key is not None:
x = x[self.feature_key]
return super().forward(x, label)
class GoogLeNetTopkAcc(TopkAcc):
def __init__(self, topk=(1, 5)):
super().__init__()
assert isinstance(topk, (int, list, tuple))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
def forward(self, x, label):
return super().forward(x[0], label)