From 4ec25097152bd7c21b1fd087695df6fea57ddc23 Mon Sep 17 00:00:00 2001
From: LRJKD <101466907+LRJKD@users.noreply.github.com>
Date: Wed, 19 Jul 2023 18:03:49 +0800
Subject: [PATCH] [Improve] Enable binary operators on the NPU. (#1714)

* [Improve]Enable binary operators on the NPU.

* Update docs link

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
---
 mmcls/utils/distribution.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mmcls/utils/distribution.py b/mmcls/utils/distribution.py
index b74761bf..22a6f74e 100644
--- a/mmcls/utils/distribution.py
+++ b/mmcls/utils/distribution.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch


 def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
@@ -18,8 +19,8 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     """
     if device == 'npu':
         from mmcv.device.npu import NPUDataParallel
-        import torch
         torch.npu.set_device(kwargs['device_ids'][0])
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
     elif device == 'mlu':
         from mmcv.device.mlu import MLUDataParallel
@@ -60,6 +61,7 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
     if device == 'npu':
         from mmcv.device.npu import NPUDistributedDataParallel
         from torch.npu import current_device
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDistributedDataParallel(
             model.npu(), *args, device_ids=[current_device()], **kwargs)
     elif device == 'mlu':
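
Usage note (not part of the patch): a minimal sketch of how the patched wrapper
might be called on an Ascend NPU host. It assumes torch_npu and mmcv's NPU device
support are installed, and uses nn.Linear as a stand-in for a real mmcls classifier.

    import torch.nn as nn

    from mmcls.utils import wrap_non_distributed_model

    # Placeholder model; any torch.nn.Module with an .npu() method works here.
    model = nn.Linear(8, 2)

    # With device='npu', the wrapper now calls
    # torch.npu.set_compile_mode(jit_compile=False) before building
    # NPUDataParallel, so binary operators run in online (non-JIT) mode.
    model = wrap_non_distributed_model(model, device='npu', device_ids=[0])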