# Copyright (c) OpenMMLab. All rights reserved.
import torch


def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
    """Wrap module in a non-distributed environment by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
    - For NPU, wrap as :obj:`mmcv.device.npu.NPUDataParallel`.
    - For MLU, wrap as :obj:`mmcv.device.mlu.MLUDataParallel`.
    - For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
    - For CPU & IPU, the model is not wrapped.

    Args:
        model (:class:`nn.Module`): model to be parallelized.
        device (str): device type, one of 'cuda', 'npu', 'mlu', 'mps',
            'cpu' or 'ipu'. Defaults to 'cuda'.
        dim (int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        nn.Module: the wrapped model, moved to the target device and
        parallelized where the device type supports it.
    """
    if device == 'npu':
        from mmcv.device.npu import NPUDataParallel
        torch.npu.set_device(kwargs['device_ids'][0])
        torch.npu.set_compile_mode(jit_compile=False)
        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
    elif device == 'mlu':
        from mmcv.device.mlu import MLUDataParallel
        model = MLUDataParallel(model.mlu(), dim=dim, *args, **kwargs)
    elif device == 'cuda':
        from mmcv.parallel import MMDataParallel
        model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
    elif device == 'cpu':
        model = model.cpu()
    elif device == 'ipu':
        model = model.cpu()
    elif device == 'mps':
        from mmcv.device import mps
        model = mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model
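

# Usage sketch (not part of the original module): a minimal, hedged example of
# exercising ``wrap_non_distributed_model``. The 'cpu' branch simply moves the
# model to CPU without wrapping it, so this sketch runs on any machine with
# PyTorch installed.
def _demo_wrap_non_distributed():
    """Run a toy ``nn.Linear`` through the CPU branch of the wrapper."""
    import torch.nn as nn
    model = wrap_non_distributed_model(nn.Linear(4, 2), device='cpu')
    # With a single CUDA device one would instead call, e.g.:
    #   wrap_non_distributed_model(net, device='cuda', device_ids=[0])
    return model(torch.zeros(1, 4))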


def wrap_distributed_model(model, device='cuda', *args, **kwargs):
    """Build a DistributedDataParallel module by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
    - For NPU, wrap as :obj:`mmcv.device.npu.NPUDistributedDataParallel`.
    - For MLU, wrap as :obj:`mmcv.device.mlu.MLUDistributedDataParallel`.
    - Other device types are not supported yet.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, 'cuda', 'npu' or 'mlu'. Defaults to 'cuda'.

    Returns:
        nn.Module: the parallelized module.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    if device == 'npu':
        from mmcv.device.npu import NPUDistributedDataParallel
        from torch.npu import current_device
        torch.npu.set_compile_mode(jit_compile=False)
        model = NPUDistributedDataParallel(
            model.npu(), *args, device_ids=[current_device()], **kwargs)
    elif device == 'mlu':
        import os

        from mmcv.device.mlu import MLUDistributedDataParallel
        model = MLUDistributedDataParallel(
            model.mlu(),
            *args,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            **kwargs)
    elif device == 'cuda':
        from mmcv.parallel import MMDistributedDataParallel
        from torch.cuda import current_device
        model = MMDistributedDataParallel(
            model.cuda(), *args, device_ids=[current_device()], **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model
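

# Launch sketch (not part of the original module; assumptions labeled): wraps
# a toy model with ``wrap_distributed_model``. Assumes one process per GPU
# started via ``torchrun`` (which sets the LOCAL_RANK environment variable)
# and an NCCL-capable CUDA setup.
def _demo_wrap_distributed():
    """Initialize the default process group and wrap a toy model."""
    import os

    import torch.distributed as dist
    import torch.nn as nn
    if not dist.is_initialized():
        dist.init_process_group(backend='nccl')
    # Bind this process to its GPU so that ``torch.cuda.current_device()``
    # inside the wrapper resolves to the right device.
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    return wrap_distributed_model(nn.Linear(4, 2), device='cuda')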