From dd657320a40f7f69c9d3bb41f6d11d80cea27f75 Mon Sep 17 00:00:00 2001 From: Yinlei Sun Date: Thu, 1 Jun 2023 16:53:32 +0800 Subject: [PATCH] [Fix] Fix calculation errors on ARM chip. (#1592) --- .pre-commit-config.yaml | 4 ++-- mmcls/models/losses/accuracy.py | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 55138ce8..212601d5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,8 +4,8 @@ repos: rev: 4.0.1 hooks: - id: flake8 - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + - repo: https://github.com/zhouzaida/isort + rev: 5.12.1 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-yapf diff --git a/mmcls/models/losses/accuracy.py b/mmcls/models/losses/accuracy.py index 1b142bc7..bd14a5a9 100644 --- a/mmcls/models/losses/accuracy.py +++ b/mmcls/models/losses/accuracy.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform from numbers import Number import numpy as np import torch import torch.nn as nn +from mmcls.utils import auto_select_device + def accuracy_numpy(pred, target, topk=(1, ), thrs=0.): if isinstance(thrs, Number): @@ -112,6 +115,12 @@ def accuracy(pred, target, topk=1, thrs=0.): if isinstance(x, np.ndarray) else x) pred = to_tensor(pred) target = to_tensor(target) + if platform.machine() == 'aarch64': + # ARM chip with low version GCC version may cause calculation errors, + # attempt to calculate on cuda or npu. + # reference: https://github.com/pytorch/pytorch/issues/75411 + pred = pred.to(auto_select_device()) + target = target.to(auto_select_device()) res = accuracy_torch(pred, target, topk, thrs)