From 4ec25097152bd7c21b1fd087695df6fea57ddc23 Mon Sep 17 00:00:00 2001
From: LRJKD <101466907+LRJKD@users.noreply.github.com>
Date: Wed, 19 Jul 2023 18:03:49 +0800
Subject: [PATCH] [Improve] Enable binary operators on the NPU. (#1714)

* [Improve]Enable binary operators on the NPU.

* Update docs link

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
---
 mmcls/utils/distribution.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mmcls/utils/distribution.py b/mmcls/utils/distribution.py
index b74761bf..22a6f74e 100644
--- a/mmcls/utils/distribution.py
+++ b/mmcls/utils/distribution.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 
 
 def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
@@ -18,8 +19,8 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     """
     if device == 'npu':
         from mmcv.device.npu import NPUDataParallel
-        import torch
         torch.npu.set_device(kwargs['device_ids'][0])
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
     elif device == 'mlu':
         from mmcv.device.mlu import MLUDataParallel
@@ -60,6 +61,7 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
     if device == 'npu':
         from mmcv.device.npu import NPUDistributedDataParallel
         from torch.npu import current_device
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDistributedDataParallel(
             model.npu(), *args, device_ids=[current_device()], **kwargs)
     elif device == 'mlu':