fix static training bugs
parent
7bf9b40bf2
commit
33a15cfdae
|
@ -46,7 +46,7 @@ class TopkAcc(AvgMetrics):
|
|||
for k in self.topk:
|
||||
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
|
||||
x, label, k=k)
|
||||
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)].numpy()[0], x.shape[0])
|
||||
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
|
||||
return metric_dict
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
|
||||
__all__ = ['AverageMeter']
|
||||
|
||||
|
||||
|
@ -44,6 +46,8 @@ class AverageMeter(object):
|
|||
|
||||
@property
|
||||
def avg_info(self):
|
||||
if isinstance(self.avg, paddle.Tensor):
|
||||
self.avg = self.avg.numpy()[0]
|
||||
return "{}: {:.5f}".format(self.name, self.avg)
|
||||
|
||||
@property
|
||||
|
|
Loading…
Reference in New Issue