[Fix] Fix calculation errors on ARM chip. (#1592)
parent
f2adad2729
commit
dd657320a4
|
@ -4,8 +4,8 @@ repos:
|
||||||
rev: 4.0.1
|
rev: 4.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/PyCQA/isort
|
- repo: https://github.com/zhouzaida/isort
|
||||||
rev: 5.10.1
|
rev: 5.12.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import platform
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from mmcls.utils import auto_select_device
|
||||||
|
|
||||||
|
|
||||||
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
|
def accuracy_numpy(pred, target, topk=(1, ), thrs=0.):
|
||||||
if isinstance(thrs, Number):
|
if isinstance(thrs, Number):
|
||||||
|
@ -112,6 +115,12 @@ def accuracy(pred, target, topk=1, thrs=0.):
|
||||||
if isinstance(x, np.ndarray) else x)
|
if isinstance(x, np.ndarray) else x)
|
||||||
pred = to_tensor(pred)
|
pred = to_tensor(pred)
|
||||||
target = to_tensor(target)
|
target = to_tensor(target)
|
||||||
|
if platform.machine() == 'aarch64':
|
||||||
|
# ARM chip with low version GCC version may cause calculation errors,
|
||||||
|
# attempt to calculate on cuda or npu.
|
||||||
|
# reference: https://github.com/pytorch/pytorch/issues/75411
|
||||||
|
pred = pred.to(auto_select_device())
|
||||||
|
target = target.to(auto_select_device())
|
||||||
|
|
||||||
res = accuracy_torch(pred, target, topk, thrs)
|
res = accuracy_torch(pred, target, topk, thrs)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue