[Fix] Fix calculation errors on ARM chip. (#1592)

master
Yinlei Sun 2023-06-01 16:53:32 +08:00 committed by GitHub
parent f2adad2729
commit dd657320a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -4,8 +4,8 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
- repo: https://github.com/zhouzaida/isort
rev: 5.12.1
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf

View File

@ -1,10 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from numbers import Number
import numpy as np
import torch
import torch.nn as nn
from mmcls.utils import auto_select_device
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
if isinstance(thrs, Number):
@ -112,6 +115,12 @@ def accuracy(pred, target, topk=1, thrs=0.):
if isinstance(x, np.ndarray) else x)
pred = to_tensor(pred)
target = to_tensor(target)
if platform.machine() == 'aarch64':
# ARM chip with low version GCC version may cause calculation errors,
# attempt to calculate on cuda or npu.
# reference: https://github.com/pytorch/pytorch/issues/75411
pred = pred.to(auto_select_device())
target = target.to(auto_select_device())
res = accuracy_torch(pred, target, topk, thrs)