[Fix] Adapt to PyTorch v2.1 on Ascend (#1332)

pull/1334/head
LRJKD 2023-09-01 16:55:45 +08:00 committed by GitHub
parent 762c9a25b6
commit 5671b53bc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 3 deletions

View File

@ -190,11 +190,17 @@ class BaseModel(BaseModule):
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple(
[list(args)[0].replace('npu', torch.npu.native_device)])
import torch_npu
args = tuple([
list(args)[0].replace(
'npu', torch_npu.npu.native_device if hasattr(
torch_npu.npu, 'native_device') else 'privateuseone')
])
if kwargs and 'npu' in str(kwargs.get('device', '')):
import torch_npu
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)
'npu', torch_npu.npu.native_device if hasattr(
torch_npu.npu, 'native_device') else 'privateuseone')
device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None: