[Fix] Fix calculation errors on ARM chip. (#1592)
parent
f2adad2729
commit
dd657320a4
|
@ -4,8 +4,8 @@ repos:
|
|||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.10.1
|
||||
- repo: https://github.com/zhouzaida/isort
|
||||
rev: 5.12.1
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmcls.utils import auto_select_device
|
||||
|
||||
|
||||
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
|
||||
if isinstance(thrs, Number):
|
||||
|
@ -112,6 +115,12 @@ def accuracy(pred, target, topk=1, thrs=0.):
|
|||
if isinstance(x, np.ndarray) else x)
|
||||
pred = to_tensor(pred)
|
||||
target = to_tensor(target)
|
||||
if platform.machine() == 'aarch64':
|
||||
# ARM chip with low version GCC version may cause calculation errors,
|
||||
# attempt to calculate on cuda or npu.
|
||||
# reference: https://github.com/pytorch/pytorch/issues/75411
|
||||
pred = pred.to(auto_select_device())
|
||||
target = target.to(auto_select_device())
|
||||
|
||||
res = accuracy_torch(pred, target, topk, thrs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue