[Improve] Enable binary operators on the NPU. (#1714)

* [Improve] Enable binary operators on the NPU.
* Update docs link

---------

Co-authored-by: mzr1996 <mzr1996@163.com>

Branch: mmcls-0.x
parent 748ab7aa7d
commit 4ec2509715
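The functional change is a single call added on both NPU code paths in the diff below: turning off JIT compilation so the Ascend runtime uses its pre-compiled ("binary") operator packages, which is what the commit title refers to. A standalone sketch of that call, assuming a torch_npu installation that exposes the torch.npu namespace (the explicit torch_npu import is an assumption and may not be required on every build):

import torch
import torch_npu  # noqa: F401  # assumption: registers the torch.npu namespace

# Disable JIT compilation on the Ascend NPU so the pre-compiled binary
# operator kernels are used; this is the call added by this commit.
torch.npu.set_compile_mode(jit_compile=False)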
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 
 
 def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
@@ -18,8 +19,8 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     """
     if device == 'npu':
         from mmcv.device.npu import NPUDataParallel
-        import torch
         torch.npu.set_device(kwargs['device_ids'][0])
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
     elif device == 'mlu':
         from mmcv.device.mlu import MLUDataParallel
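For reference, a minimal usage sketch of the non-distributed path changed above. It assumes an Ascend environment with torch_npu and mmcv's NPU device support installed, and that wrap_non_distributed_model is importable from mmcls.utils as in mmcls 0.x; the toy model is a placeholder, not part of this commit.

import torch.nn as nn

from mmcls.utils import wrap_non_distributed_model

# Any nn.Module works here; a real run would use a classifier built from a config.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1))

# With device='npu' the wrapper pins the NPU given in device_ids, disables JIT
# compilation (the line added in this commit), moves the model to the NPU and
# returns it wrapped in NPUDataParallel.
model = wrap_non_distributed_model(model, device='npu', dim=0, device_ids=[0])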
@@ -60,6 +61,7 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
     if device == 'npu':
         from mmcv.device.npu import NPUDistributedDataParallel
         from torch.npu import current_device
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDistributedDataParallel(
             model.npu(), *args, device_ids=[current_device()], **kwargs)
     elif device == 'mlu':
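And a hedged launch sketch for the distributed path, assuming one process per NPU on a single node, the HCCL collective backend, and a torch_npu build that exposes torch.npu; the linear model and the broadcast_buffers flag are placeholders for whatever the training script actually passes.

import torch
import torch.distributed as dist
import torch.nn as nn

from mmcls.utils import wrap_distributed_model

dist.init_process_group(backend='hccl')  # Ascend collective backend
torch.npu.set_device(dist.get_rank())    # single-node sketch: rank == local NPU index

model = nn.Linear(128, 10)

# The wrapper disables JIT compilation (the line added in this commit), moves
# the model to the current NPU and wraps it in NPUDistributedDataParallel with
# device_ids=[current_device()].
model = wrap_distributed_model(model, device='npu', broadcast_buffers=False)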