polish googlenet
parent
7d5006055f
commit
10005abed6
|
@ -157,7 +157,11 @@ def create_loss(out,
|
|||
return loss(out, target)
|
||||
|
||||
|
||||
def create_metric(out, feeds, topk=5, classes_num=1000,
|
||||
def create_metric(out,
|
||||
feeds,
|
||||
architecture,
|
||||
topk=5,
|
||||
classes_num=1000,
|
||||
use_distillation=False):
|
||||
"""
|
||||
Create measures of model accuracy, such as top1 and top5
|
||||
|
@ -171,16 +175,22 @@ def create_metric(out, feeds, topk=5, classes_num=1000,
|
|||
Returns:
|
||||
fetchs(dict): dict of measures
|
||||
"""
|
||||
# just need student label to get metrics
|
||||
if use_distillation:
|
||||
out = out[1]
|
||||
if architecture["name"] == "GoogLeNet":
|
||||
assert len(out) == 3, "GoogLeNet should have 3 outputs"
|
||||
softmax_out = out[0]
|
||||
else:
|
||||
# just need student label to get metrics
|
||||
if use_distillation:
|
||||
out = out[1]
|
||||
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
|
||||
|
||||
fetchs = OrderedDict()
|
||||
label = feeds['label']
|
||||
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
|
||||
top1 = fluid.layers.accuracy(softmax_out, label=label, k=1)
|
||||
# set top1 to fetchs
|
||||
top1 = fluid.layers.accuracy(softmax_out, label=feeds['label'], k=1)
|
||||
fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True))
|
||||
# set topk to fetchs
|
||||
k = min(topk, classes_num)
|
||||
topk = fluid.layers.accuracy(softmax_out, label=label, k=k)
|
||||
topk = fluid.layers.accuracy(softmax_out, label=feeds['label'], k=k)
|
||||
topk_name = 'top{}'.format(k)
|
||||
fetchs[topk_name] = (topk, AverageMeter(topk_name, '.4f', need_avg=True))
|
||||
|
||||
|
@ -201,7 +211,8 @@ def create_fetchs(out,
|
|||
|
||||
Args:
|
||||
out(variable): model output variable
|
||||
feeds(dict): dict of model input variables(included label)
|
||||
feeds(dict): dict of model input variables.
|
||||
If use mix_up, it will not include label.
|
||||
architecture(dict): architecture information,
|
||||
name(such as ResNet50) is needed
|
||||
topk(int): usually top5
|
||||
|
@ -217,7 +228,8 @@ def create_fetchs(out,
|
|||
use_distillation)
|
||||
fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
|
||||
if not use_mix:
|
||||
metric = create_metric(out, feeds, topk, classes_num, use_distillation)
|
||||
metric = create_metric(out, feeds, architecture, topk, classes_num,
|
||||
use_distillation)
|
||||
fetchs.update(metric)
|
||||
|
||||
return fetchs
|
||||
|
|
Loading…
Reference in New Issue