# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
from collections import namedtuple

import torch
import torch.distributed
import torchacc.torch_xla.amp as torchacc_amp
import torchacc.torch_xla.amp.syncfree as xla_optim
import torchacc.torch_xla.core.xla_model as xm
from prettytable import PrettyTable
from torch.distributed import ReduceOp

DEFAULT_TAG = 'EasyCV-default-barrier-tag'

OpSpec = namedtuple('OpSpec', ['module', 'name', 'value'])
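
# Each OpSpec records where an op lives (``module``), the attribute name it is
# bound to (``name``) and the callable itself (``value``), so the op can later
# be swapped via ``setattr(module, name, replacement)``. Illustrative sketch,
# not executed here; ``some_replacement`` is a hypothetical stand-in:
#
#   spec = OpSpec(module=torch, name='zeros', value=torch.zeros)
#   setattr(spec.module, spec.name, some_replacement)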


class OpsCollector(object):

    def __init__(self):
        self._ops = {}

    def add_op(self, op_name, op_spec):
        self._ops[op_name] = op_spec

    def get_op(self, op_name):
        return self._ops.get(op_name, None)

    @property
    def ops(self):
        return self._ops


class OpsCollectorManager(object):

    def __init__(self):
        self._registries = {}

    def get_collector(self, name):
        return self._registries[name]

    @property
    def registries(self):
        return self._registries

    def register(self, name):
        if name not in self._registries:
            self._registries[name] = OpsCollector()
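
# Minimal usage sketch of the collector registry (illustrative only; the real
# entries are created by the ``_register_*_ops`` helpers below):
#
#   manager = OpsCollectorManager()
#   manager.register('torch')
#   manager.get_collector('torch').add_op(
#       'ZEROS', OpSpec(module=torch, name='zeros', value=torch.zeros))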


_ops_manager = OpsCollectorManager()

TORCH_MODULE_NAME = 'torch'
TORCHACC_MODULE_NAME = 'torchacc'


def _register_torch_ops():
    global _ops_manager

    module_name = TORCH_MODULE_NAME
    _ops_manager.register(module_name)
    collector = _ops_manager.get_collector(module_name)

    collector.add_op(
        'TO', OpSpec(module=torch.Tensor, name='to', value=torch.Tensor.to))
    collector.add_op(
        'CUDA',
        OpSpec(module=torch.Tensor, name='cuda', value=torch.Tensor.cuda))
    collector.add_op('TENSOR',
                     OpSpec(module=torch, name='tensor', value=torch.tensor))
    collector.add_op('ZEROS',
                     OpSpec(module=torch, name='zeros', value=torch.zeros))
    collector.add_op(
        'GET_RANK',
        OpSpec(
            module=torch.distributed,
            name='get_rank',
            value=torch.distributed.get_rank))
    collector.add_op(
        'GET_WORLD_SIZE',
        OpSpec(
            module=torch.distributed,
            name='get_world_size',
            value=torch.distributed.get_world_size))
    collector.add_op(
        'BARRIER',
        OpSpec(
            module=torch.distributed,
            name='barrier',
            value=torch.distributed.barrier))
    collector.add_op(
        'ALL_REDUCE',
        OpSpec(
            module=torch.distributed,
            name='all_reduce',
            value=torch.distributed.all_reduce))
    collector.add_op(
        'REDUCE',
        OpSpec(
            module=torch.distributed,
            name='reduce',
            value=torch.distributed.reduce))
    collector.add_op(
        'BROADCAST',
        OpSpec(
            module=torch.distributed,
            name='broadcast',
            value=torch.distributed.broadcast))
    collector.add_op(
        'ALL_GATHER',
        OpSpec(
            module=torch.distributed,
            name='all_gather',
            value=torch.distributed.all_gather))
    collector.add_op(
        'IS_INITIALIZED',
        OpSpec(
            module=torch.distributed,
            name='is_initialized',
            value=torch.distributed.is_initialized))
    collector.add_op(
        'ADAM',
        OpSpec(module=torch.optim, name='Adam', value=torch.optim.Adam))
    collector.add_op(
        'ADAMW',
        OpSpec(module=torch.optim, name='AdamW', value=torch.optim.AdamW))
    collector.add_op(
        'SGD', OpSpec(module=torch.optim, name='SGD', value=torch.optim.SGD))
    collector.add_op(
        'GRADSCALER',
        OpSpec(
            module=torch.cuda.amp,
            name='GradScaler',
            value=torch.cuda.amp.GradScaler))

    return collector
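
# Note: the ``value`` stored for each torch op above is the original callable,
# captured before any patching. ``convert_ops`` only uses it for reporting,
# but keeping the reference would also allow restoring the original behaviour,
# e.g. (hypothetical, not done in this module):
#
#   torch_ops = _ops_manager.get_collector(TORCH_MODULE_NAME).ops
#   setattr(torch, 'zeros', torch_ops['ZEROS'].value)  # restore torch.zeros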


def _register_torchacc_ops():
    global _ops_manager

    reduce_op_map = {
        ReduceOp.SUM: xm.REDUCE_SUM,
        ReduceOp.PRODUCT: xm.REDUCE_MUL,
        ReduceOp.MIN: xm.REDUCE_MIN,
        ReduceOp.MAX: xm.REDUCE_MAX,
        ReduceOp.BAND: xm.REDUCE_AND,
        ReduceOp.BOR: xm.REDUCE_OR,
    }

    module_name = TORCHACC_MODULE_NAME
    _ops_manager.register(module_name)
    collector = _ops_manager.get_collector(module_name)

    origin_to = torch.Tensor.to
    origin_tensor = torch.tensor
    origin_zeros = torch.zeros
    torchacc_device = xm.xla_device()

    from typing import Any, Optional, Union
    from torch.types import _int, _bool, _dtype, _device

    def torchacc_is_initialized():
        # TODO: add an initialized attribute to keep consistent with
        # torch dist behavior
        return xm.xrt_world_size() > 1

    def torchacc_to(self,
                    device: Optional[Union[_device, str]] = None,
                    dtype: Optional[_dtype] = None,
                    non_blocking: _bool = False,
                    copy: _bool = False) -> torch.Tensor:
        if device is not None and str(device).startswith('cuda'):
            device = torchacc_device
        return origin_to(
            self,
            device=device,
            dtype=dtype,
            non_blocking=non_blocking,
            copy=copy)

    # must setattr after torchacc_to
    def torchacc_cuda(self,
                      device: Optional[Union[_device, _int, str]] = None,
                      non_blocking: _bool = False) -> torch.Tensor:
        assert torch.cuda.is_available()
        device = torchacc_device
        return self.to(device=device, non_blocking=non_blocking)

    def torchacc_tensor(data: Any,
                        dtype: Optional[_dtype] = None,
                        device: Union[_device, str, None] = None,
                        requires_grad: _bool = False) -> torch.Tensor:
        if str(device).startswith('cuda'):
            device = torchacc_device
        return origin_tensor(
            data=data, dtype=dtype, device=device, requires_grad=requires_grad)

    def torchacc_zeros(*args, device=None, **kwargs):
        if str(device).startswith('cuda'):
            device = torchacc_device
        return origin_zeros(*args, device=device, **kwargs)

    def torchacc_barrier(tag=None, **kwargs):
        if tag is None:
            tag = DEFAULT_TAG
        return xm.rendezvous(tag, **kwargs)
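
    # ``torch.distributed.barrier()`` and ``xm.rendezvous(tag)`` have
    # different signatures; callers that pass only the default arguments map
    # cleanly, e.g. (illustrative):
    #
    #   torch.distributed.barrier()  ->  xm.rendezvous(DEFAULT_TAG)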

    def torchacc_all_reduce(tensor, op=ReduceOp.SUM, **kwargs):
        if not isinstance(tensor, list):
            return xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=[tensor], **kwargs)
        else:
            return xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=tensor, **kwargs)

    def torchacc_reduce(tensor, dst, op=ReduceOp.SUM, **kwargs):
        # if tensor.device.type != 'gpu':
        #     return
        if xm.get_ordinal() != dst:
            # non-destination ranks participate with a clone and discard the
            # result, leaving the caller's tensor untouched to mimic
            # ``torch.distributed.reduce`` semantics
            tensor = tensor.clone()
            xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=tensor, **kwargs)
        else:
            # the destination rank passes a list so the reduced result is
            # written back in place
            xm.all_reduce(
                reduce_type=reduce_op_map[op], inputs=[tensor], **kwargs)

    def torchacc_broadcast(**kwargs):
        raise ValueError('broadcast is not supported for torchacc yet!')

    def torchacc_all_gather(tensor_list, tensor, **kwargs):
        if len(tensor.size()) == 0:
            raise ValueError(
                'all_gather of scalar tensors is not supported for torchacc!')

        res = xm.all_gather(value=tensor, dim=0, **kwargs)
        splits = torch.tensor_split(res, len(tensor_list))

        for i in range(len(tensor_list)):
            assert splits[i].size() == tensor.size(), \
                'mismatch size: {}, {}'.format(splits[i].size(), tensor.size())
            tensor_list[i] = splits[i]
        del splits
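
    # ``xm.all_gather`` concatenates every replica's tensor along ``dim=0``,
    # so the result is split back into ``world_size`` equal chunks and copied
    # into ``tensor_list`` to mirror ``torch.distributed.all_gather``.
    # Illustrative shapes (assuming 2 replicas, each holding a (4, 3) tensor):
    #
    #   res: (8, 3)  ->  tensor_list: [(4, 3), (4, 3)]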

    # ops whose torchacc replacement is a plain local function (no module
    # attribute to report) are registered with module=None, name=None
    collector.add_op('TO', OpSpec(module=None, name=None, value=torchacc_to))
    collector.add_op(
        'CUDA', OpSpec(module=None, name=None, value=torchacc_cuda))
    collector.add_op('TENSOR',
                     OpSpec(module=None, name=None, value=torchacc_tensor))
    collector.add_op('ZEROS',
                     OpSpec(module=None, name=None, value=torchacc_zeros))
    collector.add_op(
        'GET_RANK',
        OpSpec(module=xm, name='get_ordinal', value=xm.get_ordinal))
    collector.add_op(
        'GET_WORLD_SIZE',
        OpSpec(module=xm, name='xrt_world_size', value=xm.xrt_world_size))
    collector.add_op(
        'BARRIER',
        OpSpec(module=xm, name='rendezvous', value=torchacc_barrier))
    collector.add_op(
        'ALL_REDUCE',
        OpSpec(module=xm, name='all_reduce', value=torchacc_all_reduce))
    collector.add_op('REDUCE',
                     OpSpec(module=None, name=None, value=torchacc_reduce))
    collector.add_op(
        'BROADCAST',
        OpSpec(module=None, name=None, value=torchacc_broadcast))
    collector.add_op(
        'ALL_GATHER',
        OpSpec(module=xm, name='all_gather', value=torchacc_all_gather))
    collector.add_op(
        'IS_INITIALIZED',
        OpSpec(module=None, name=None, value=torchacc_is_initialized))
    collector.add_op(
        'ADAM', OpSpec(module=xla_optim, name='Adam', value=xla_optim.Adam))
    collector.add_op(
        'ADAMW', OpSpec(module=xla_optim, name='AdamW', value=xla_optim.AdamW))
    collector.add_op('SGD',
                     OpSpec(module=xla_optim, name='SGD', value=xla_optim.SGD))
    collector.add_op(
        'GRADSCALER',
        OpSpec(
            module=torchacc_amp,
            name='GradScaler',
            value=torchacc_amp.GradScaler))

    return collector


def convert_ops(src_collector, target_collector):
    reset_ops_map = {}
    for src_op_key, src_op in src_collector.ops.items():
        target_op = target_collector.get_op(src_op_key)
        if src_op.module is None or target_op is None:
            logging.info(
                'Skip {}: source op module is None or no target op found!'
                .format(src_op_key))
            continue

        setattr(src_op.module, src_op.name, target_op.value)

        src_name = '{}.{}'.format(src_op.module.__name__, src_op.name)
        target_prefix = (
            target_op.module.__name__ + '.'
            if target_op.module is not None else '')
        target_name = target_prefix + (
            target_op.name or target_op.value.__name__)
        reset_ops_map[src_name] = target_name

    show = PrettyTable()
    show.field_names = ['source ops', 'replaced ops']
    for k, v in reset_ops_map.items():
        show.add_row([k, v])

    from torch.distributed import get_rank
    if get_rank() == 0:
        print('Replaced ops are as follows:')
        print(show)
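
# After ``convert_ops`` runs, the patched symbols resolve to their torchacc
# counterparts, e.g. (illustrative, assuming a torchacc/XLA environment):
#
#   torch.distributed.get_rank        -> xm.get_ordinal
#   torch.distributed.get_world_size  -> xm.xrt_world_size
#   torch.optim.AdamW                 -> torchacc.torch_xla.amp.syncfree.AdamW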


def convert_torch_ops_to_torchacc():
    print(
        'To adapt to torchacc, we replaced some of torch\'s operators with '
        'torchacc\'s operators.')

    src_collector = _register_torch_ops()
    target_collector = _register_torchacc_ops()
    convert_ops(src_collector, target_collector)
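
# Minimal usage sketch (assuming torchacc is installed and an XLA device is
# available; call this before building models, optimizers and dist utilities
# so they pick up the patched ops). The import path is an assumption based on
# the ``easycv.toolkit.torchacc`` reference in ``convert_timm_ops`` below:
#
#   from easycv.toolkit.torchacc import convert_torch_ops_to_torchacc
#   convert_torch_ops_to_torchacc()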


def convert_timm_ops():
    import timm

    # TODO: remove it, fix torch.tensor to adapt to torch.jit
    if hasattr(timm.models.layers, 'anti_aliasing'):
        from timm.models.layers import anti_aliasing
        _ori_DownsampleJIT = anti_aliasing.DownsampleJIT

        class FixedDownsampleJIT(_ori_DownsampleJIT):
            pass

        setattr(anti_aliasing, 'DownsampleJIT', FixedDownsampleJIT)

    # TODO: remove it, fix torch.cat to support multiple types of arguments
    def _fix_forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(
            x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks

        # torch.cat does not support arguments of mixed dtypes
        # ======================== added fix =========================
        if x.dtype == torch.float16:
            cls_token = cls_token.half()
        # ======================== end fix ===========================

        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat(
                (cls_token, self.dist_token.expand(x.shape[0], -1, -1), x),
                dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    from timm.models.vision_transformer import VisionTransformer
    setattr(VisionTransformer, 'forward_features', _fix_forward_features)

    print(
        'To adapt to torchacc, we have modified some APIs of timm. '
        'Please refer to ``easycv.toolkit.torchacc.convert_timm_ops`` for details.'
    )
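
# Usage sketch for the timm patches (illustrative; assumes timm is installed
# and, as above, that the package re-exports this helper):
#
#   from easycv.toolkit.torchacc import convert_timm_ops
#   convert_timm_ops()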