EasyCV/easycv/toolkit/torchacc/initilization.py

27 lines
642 B
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from distutils.version import LooseVersion
import torch
import torchacc.torch_xla.core.xla_model as xm
from .convert_ops import convert_timm_ops, convert_torch_ops_to_torchacc
def patch_ops():
    """Apply the torchacc operator patches.

    Invokes ``convert_timm_ops`` first and then
    ``convert_torch_ops_to_torchacc``, in that fixed order. Takes no
    arguments and returns ``None``.
    """
    for apply_patch in (convert_timm_ops, convert_torch_ops_to_torchacc):
        apply_patch()
def torchacc_init():
    """Check the torch version, configure XLA for one device, patch ops.

    Steps, in order:

    1. Verify that the installed torch is at least version 1.10.0.
    2. If ``xm.xrt_world_size()`` reports a world size of 1, set the
       ``GPU_NUM_DEVICES`` environment variable to ``'1'`` and call
       ``xm.set_replication`` with the XLA device as the only entry.
    3. Call ``patch_ops()``.

    Raises:
        AssertionError: If the installed torch version is older than 1.10.0.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # stripped when Python runs with ``-O``, which would silently skip this
    # compatibility check. AssertionError is kept so existing callers see
    # the same exception type and message as before.
    # NOTE(review): ``distutils`` is deprecated (PEP 632) and removed in
    # Python 3.12; the module-level LooseVersion import will need a
    # replacement before this file can be imported there.
    if not (LooseVersion(torch.__version__) >= LooseVersion('1.10.0')):
        raise AssertionError(
            'torchacc only supports torch versions greater than 1.10')

    if xm.xrt_world_size() == 1:
        # Single-process run: presumably the XLA runtime needs to be told
        # there is exactly one GPU -- TODO confirm against torchacc docs.
        os.environ['GPU_NUM_DEVICES'] = '1'
        device = xm.xla_device()
        xm.set_replication(device, [device])

    patch_ops()