fix static training bugs

pull/1954/head
cuicheng01 2022-05-24 17:02:44 +00:00
parent 7bf9b40bf2
commit 33a15cfdae
2 changed files with 5 additions and 1 deletions

View File

@ -46,7 +46,7 @@ class TopkAcc(AvgMetrics):
for k in self.topk:
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k)
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)].numpy()[0], x.shape[0])
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
return metric_dict

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
__all__ = ['AverageMeter']
@ -44,6 +46,8 @@ class AverageMeter(object):
@property
def avg_info(self):
if isinstance(self.avg, paddle.Tensor):
self.avg = self.avg.numpy()[0]
return "{}: {:.5f}".format(self.name, self.avg)
@property