diff --git a/mmcls/utils/distribution.py b/mmcls/utils/distribution.py
index b74761bf..22a6f74e 100644
--- a/mmcls/utils/distribution.py
+++ b/mmcls/utils/distribution.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 
 
 def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
@@ -18,8 +19,8 @@ def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
     """
     if device == 'npu':
         from mmcv.device.npu import NPUDataParallel
-        import torch
         torch.npu.set_device(kwargs['device_ids'][0])
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
     elif device == 'mlu':
         from mmcv.device.mlu import MLUDataParallel
@@ -60,6 +61,7 @@ def wrap_distributed_model(model, device='cuda', *args, **kwargs):
     if device == 'npu':
        from mmcv.device.npu import NPUDistributedDataParallel
        from torch.npu import current_device
+        torch.npu.set_compile_mode(jit_compile=False)
         model = NPUDistributedDataParallel(
             model.npu(), *args, device_ids=[current_device()], **kwargs)
     elif device == 'mlu':
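
For context, a minimal usage sketch (an assumption, not part of this patch) of how the patched helper might be invoked for single-process training on an Ascend NPU; `model` and `device_ids=[0]` are placeholder values:

    # Hypothetical call site: with this patch, torch.npu.set_compile_mode(jit_compile=False)
    # is applied before the model is wrapped in NPUDataParallel, so the NPU executes
    # operators in eager (non-JIT) mode.
    from mmcls.utils import wrap_non_distributed_model

    model = wrap_non_distributed_model(model, device='npu', device_ids=[0])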