Update GoogLeNetLoss
parent
4e154aed4f
commit
707e01aed5
|
@ -122,8 +122,8 @@ Infer:
|
|||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
- GoogLeNetTopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
- GoogLeNetTopkAcc:
|
||||
topk: [1, 5]
|
||||
|
|
|
@ -18,6 +18,7 @@ from collections import OrderedDict
|
|||
|
||||
from .metrics import TopkAcc, mAP, mINP, Recallk
|
||||
from .metrics import DistillationTopkAcc
|
||||
from .metrics import GoogLeNetTopkAcc
|
||||
|
||||
class CombinedMetrics(nn.Layer):
|
||||
def __init__(self, config_list):
|
||||
|
|
|
@ -25,8 +25,6 @@ class TopkAcc(nn.Layer):
|
|||
self.topk = topk
|
||||
|
||||
def forward(self, x, label):
|
||||
if isinstance(x, list):
|
||||
x = x[0]
|
||||
if isinstance(x, dict):
|
||||
x = x["logits"]
|
||||
|
||||
|
@ -122,3 +120,16 @@ class DistillationTopkAcc(TopkAcc):
|
|||
if self.feature_key is not None:
|
||||
x = x[self.feature_key]
|
||||
return super().forward(x, label)
|
||||
|
||||
|
||||
class GoogLeNetTopkAcc(TopkAcc):
|
||||
def __init__(self, topk=(1, 5)):
|
||||
super().__init__()
|
||||
assert isinstance(topk, (int, list, tuple))
|
||||
if isinstance(topk, int):
|
||||
topk = [topk]
|
||||
self.topk = topk
|
||||
|
||||
def forward(self, x, label):
|
||||
return super().forward(x[0], label)
|
||||
|
||||
|
|
Loading…
Reference in New Issue