[Fix] Adapt to PyTorch v2.1 on Ascend (#1332)
parent
762c9a25b6
commit
5671b53bc5
|
@ -190,11 +190,17 @@ class BaseModel(BaseModule):
|
|||
# directly will cause the NPU to not be found.
|
||||
# Here, the input parameters are processed to avoid errors.
|
||||
if args and isinstance(args[0], str) and 'npu' in args[0]:
|
||||
args = tuple(
|
||||
[list(args)[0].replace('npu', torch.npu.native_device)])
|
||||
import torch_npu
|
||||
args = tuple([
|
||||
list(args)[0].replace(
|
||||
'npu', torch_npu.npu.native_device if hasattr(
|
||||
torch_npu.npu, 'native_device') else 'privateuseone')
|
||||
])
|
||||
if kwargs and 'npu' in str(kwargs.get('device', '')):
|
||||
import torch_npu
|
||||
kwargs['device'] = kwargs['device'].replace(
|
||||
'npu', torch.npu.native_device)
|
||||
'npu', torch_npu.npu.native_device if hasattr(
|
||||
torch_npu.npu, 'native_device') else 'privateuseone')
|
||||
|
||||
device = torch._C._nn._parse_to(*args, **kwargs)[0]
|
||||
if device is not None:
|
||||
|
|
Loading…
Reference in New Issue