mirror of https://github.com/alibaba/EasyCV.git
27 lines
642 B
Python
27 lines
642 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import os
|
|
from distutils.version import LooseVersion
|
|
|
|
import torch
|
|
import torchacc.torch_xla.core.xla_model as xm
|
|
|
|
from .convert_ops import convert_timm_ops, convert_torch_ops_to_torchacc
|
|
|
|
|
|
def patch_ops():
|
|
convert_timm_ops()
|
|
convert_torch_ops_to_torchacc()
|
|
|
|
|
|
def torchacc_init():
|
|
assert LooseVersion(torch.__version__) >= LooseVersion(
|
|
'1.10.0'), 'torchacc only supports torch versions greater than 1.10'
|
|
|
|
if xm.xrt_world_size() == 1:
|
|
os.environ['GPU_NUM_DEVICES'] = '1'
|
|
|
|
device = xm.xla_device()
|
|
xm.set_replication(device, [device])
|
|
|
|
patch_ops()
|