[Improve] Enable binary operators on the NPU. (#1714)

* [Improve] Enable binary operators on the NPU.

* Update docs link

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
mmcls-0.x
LRJKD 2023-07-19 18:03:49 +08:00 committed by GitHub
parent 748ab7aa7d
commit 4ec2509715
1 changed file with 3 additions and 1 deletion


@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
@@ -18,8 +19,8 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     """
     if device == 'npu':
         from mmcv.device.npu import NPUDataParallel
-        import torch
         torch.npu.set_device(kwargs['device_ids'][0])
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
     elif device == 'mlu':
         from mmcv.device.mlu import MLUDataParallel
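
A minimal sketch of how this non-distributed path is typically reached, so that the new set_compile_mode call runs before the model is wrapped. The import path assumes the usual mmcls.utils export, and the model plus the device_ids value are placeholders for illustration, not part of this diff:

    # Hypothetical call site; the Linear module stands in for a real classifier.
    import torch.nn as nn
    from mmcls.utils import wrap_non_distributed_model

    model = nn.Linear(8, 8)
    # The 'npu' branch reads kwargs['device_ids'][0], so device_ids must be supplied.
    model = wrap_non_distributed_model(model, device='npu', device_ids=[0])
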
@@ -60,6 +61,7 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
     if device == 'npu':
         from mmcv.device.npu import NPUDistributedDataParallel
         from torch.npu import current_device
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDistributedDataParallel(
             model.npu(), *args, device_ids=[current_device()], **kwargs)
     elif device == 'mlu':
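
The distributed wrapper receives the same flag. Per the torch_npu option it sets, jit_compile=False asks the Ascend runtime to use pre-built binary operator kernels instead of compiling them online, which is what the commit title refers to. Below is a hedged sketch of a plausible call site; the distributed initialization and the broadcast_buffers kwarg are assumptions, not taken from this diff:

    # Hypothetical usage; torch.distributed / HCCL init is assumed to have run already.
    import torch.nn as nn
    from mmcls.utils import wrap_distributed_model

    model = nn.Linear(8, 8)
    # device_ids is filled in by the wrapper via torch.npu.current_device().
    model = wrap_distributed_model(model, device='npu', broadcast_buffers=False)
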