[Fix] Fix calculation errors on ARM chip. (#1592)

master
Yinlei Sun 2023-06-01 16:53:32 +08:00 committed by GitHub
parent f2adad2729
commit dd657320a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -4,8 +4,8 @@ repos:
rev: 4.0.1 rev: 4.0.1
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://github.com/PyCQA/isort - repo: https://github.com/zhouzaida/isort
rev: 5.10.1 rev: 5.12.1
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf

View File

@ -1,10 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import platform
from numbers import Number from numbers import Number
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcls.utils import auto_select_device
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.): def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
if isinstance(thrs, Number): if isinstance(thrs, Number):
@ -112,6 +115,12 @@ def accuracy(pred, target, topk=1, thrs=0.):
if isinstance(x, np.ndarray) else x) if isinstance(x, np.ndarray) else x)
pred = to_tensor(pred) pred = to_tensor(pred)
target = to_tensor(target) target = to_tensor(target)
if platform.machine() == 'aarch64':
# ARM chip with low version GCC version may cause calculation errors,
# attempt to calculate on cuda or npu.
# reference: https://github.com/pytorch/pytorch/issues/75411
pred = pred.to(auto_select_device())
target = target.to(auto_select_device())
res = accuracy_torch(pred, target, topk, thrs) res = accuracy_torch(pred, target, topk, thrs)