Merge pull request #2308 from huggingface/device_amp_cleanup

Cleanup some amp related behaviour to better support different (non-cuda) devices
This commit is contained in:
Ross Wightman 2024-10-19 08:19:27 -07:00 committed by GitHub
commit 5081b53e48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 61 additions and 52 deletions

View File

@ -32,13 +32,6 @@ try:
except ImportError: except ImportError:
pass pass
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try: try:
from deepspeed.profiling.flops_profiler import get_model_profile from deepspeed.profiling.flops_profiler import get_model_profile
has_deepspeed_profiling = True has_deepspeed_profiling = True
@ -242,7 +235,7 @@ class BenchmarkRunner:
self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision) self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False) self.channels_last = kwargs.pop('channels_last', False)
if self.amp_dtype is not None: if self.amp_dtype is not None:
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype) self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
else: else:
self.amp_autocast = suppress self.amp_autocast = suppress

View File

@ -28,13 +28,6 @@ try:
except ImportError: except ImportError:
has_apex = False has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try: try:
from functorch.compile import memory_efficient_fusion from functorch.compile import memory_efficient_fusion
has_functorch = True has_functorch = True
@ -170,7 +163,6 @@ def main():
# resolve AMP arguments based on PyTorch / Apex availability # resolve AMP arguments based on PyTorch / Apex availability
amp_autocast = suppress amp_autocast = suppress
if args.amp: if args.amp:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
assert args.amp_dtype in ('float16', 'bfloat16') assert args.amp_dtype in ('float16', 'bfloat16')
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)

View File

@ -113,13 +113,17 @@ class PrefetchLoader:
) )
else: else:
self.random_erasing = None self.random_erasing = None
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda' self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
self.is_npu = device.type == 'npu' and torch.npu.is_available()
def __iter__(self): def __iter__(self):
first = True first = True
if self.is_cuda: if self.is_cuda:
stream = torch.cuda.Stream() stream = torch.cuda.Stream()
stream_context = partial(torch.cuda.stream, stream=stream) stream_context = partial(torch.cuda.stream, stream=stream)
elif self.is_npu:
stream = torch.npu.Stream()
stream_context = partial(torch.npu.stream, stream=stream)
else: else:
stream = None stream = None
stream_context = suppress stream_context = suppress
@ -139,7 +143,10 @@ class PrefetchLoader:
first = False first = False
if stream is not None: if stream is not None:
torch.cuda.current_stream().wait_stream(stream) if self.is_cuda:
torch.cuda.current_stream().wait_stream(stream)
elif self.is_npu:
torch.npu.current_stream().wait_stream(stream)
input = next_input input = next_input
target = next_target target = next_target

View File

@ -28,6 +28,30 @@ except ImportError:
_USE_FAST_NORM = False # defaulting to False for now _USE_FAST_NORM = False # defaulting to False for now
def get_autocast_dtype(device: str = 'cuda'):
try:
return torch.get_autocast_dtype(device)
except (AttributeError, TypeError):
# dispatch to older device specific fns, only covering cuda/cpu devices here
if device == 'cpu':
return torch.get_autocast_cpu_dtype()
else:
assert device == 'cuda'
return torch.get_autocast_gpu_dtype()
def is_autocast_enabled(device: str = 'cuda'):
try:
return torch.is_autocast_enabled(device)
except TypeError:
# dispatch to older device specific fns, only covering cuda/cpu devices here
if device == 'cpu':
return torch.is_autocast_cpu_enabled()
else:
assert device == 'cuda'
return torch.is_autocast_enabled() # defaults cuda (only cuda on older pytorch)
def is_fast_norm(): def is_fast_norm():
return _USE_FAST_NORM return _USE_FAST_NORM
@ -48,14 +72,14 @@ def fast_group_norm(
# currently cannot use is_autocast_enabled within torchscript # currently cannot use is_autocast_enabled within torchscript
return F.group_norm(x, num_groups, weight, bias, eps) return F.group_norm(x, num_groups, weight, bias, eps)
if torch.is_autocast_enabled(): if is_autocast_enabled(x.device.type):
# normally native AMP casts GN inputs to float32 # normally native AMP casts GN inputs to float32
# here we use the low precision autocast dtype # here we use the low precision autocast dtype
# FIXME what to do re CPU autocast? # FIXME what to do re CPU autocast?
dt = torch.get_autocast_gpu_dtype() dt = get_autocast_dtype(x.device.type)
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast(device_type=x.device.type, enabled=False):
return F.group_norm(x, num_groups, weight, bias, eps) return F.group_norm(x, num_groups, weight, bias, eps)
@ -73,14 +97,14 @@ def fast_layer_norm(
if has_apex: if has_apex:
return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
if torch.is_autocast_enabled(): if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32 # normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex # apex LN does not, this is behaving like Apex
dt = torch.get_autocast_gpu_dtype() dt = get_autocast_dtype(x.device.type)
# FIXME what to do re CPU autocast? # FIXME what to do re CPU autocast?
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast(device_type=x.device.type, enabled=False):
return F.layer_norm(x, normalized_shape, weight, bias, eps) return F.layer_norm(x, normalized_shape, weight, bias, eps)

View File

@ -46,8 +46,11 @@ class ApexScaler:
class NativeScaler: class NativeScaler:
state_dict_key = "amp_scaler" state_dict_key = "amp_scaler"
def __init__(self): def __init__(self, device='cuda'):
self._scaler = torch.cuda.amp.GradScaler() try:
self._scaler = torch.amp.GradScaler(device=device)
except (AttributeError, TypeError) as e:
self._scaler = torch.cuda.amp.GradScaler()
def __call__( def __call__(
self, self,

View File

@ -116,6 +116,7 @@ def init_distributed_device_so(
"xpu": "ccl", "xpu": "ccl",
"hpu": "hccl", "hpu": "hccl",
"cuda": "nccl", "cuda": "nccl",
"npu": "hccl",
} }
dist_backend = dist_backends.get(device_type, 'gloo') dist_backend = dist_backends.get(device_type, 'gloo')
dist_url = dist_url or 'env://' dist_url = dist_url or 'env://'
@ -159,6 +160,8 @@ def init_distributed_device_so(
if device_type == 'cuda': if device_type == 'cuda':
assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.' assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
if device_type == 'npu':
assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
if distributed and device != 'cpu': if distributed and device != 'cpu':
# Ignore manually specified device index in distributed mode and # Ignore manually specified device index in distributed mode and

View File

@ -48,12 +48,6 @@ try:
except ImportError: except ImportError:
has_apex = False has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try: try:
import wandb import wandb
@ -442,7 +436,6 @@ def main():
use_amp = 'apex' use_amp = 'apex'
assert args.amp_dtype == 'float16' assert args.amp_dtype == 'float16'
else: else:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
use_amp = 'native' use_amp = 'native'
assert args.amp_dtype in ('float16', 'bfloat16') assert args.amp_dtype in ('float16', 'bfloat16')
if args.amp_dtype == 'bfloat16': if args.amp_dtype == 'bfloat16':
@ -572,15 +565,10 @@ def main():
if utils.is_primary(args): if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native': elif use_amp == 'native':
try: amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) if device.type in ('cuda',) and amp_dtype == torch.float16:
except (AttributeError, TypeError):
# fallback to CUDA only AMP for PyTorch < 1.10
assert device.type == 'cuda'
amp_autocast = torch.cuda.amp.autocast
if device.type == 'cuda' and amp_dtype == torch.float16:
# loss scaler only used for float16 (half) dtype, bfloat16 does not need it # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
loss_scaler = NativeScaler() loss_scaler = NativeScaler(device=device.type)
if utils.is_primary(args): if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.') _logger.info('Using native Torch AMP. Training in mixed precision.')
else: else:
@ -1054,8 +1042,11 @@ def train_one_epoch(
if model_ema is not None: if model_ema is not None:
model_ema.update(model, step=num_updates) model_ema.update(model, step=num_updates)
if args.synchronize_step and device.type == 'cuda': if args.synchronize_step:
torch.cuda.synchronize() if device.type == 'cuda':
torch.cuda.synchronize()
elif device.type == 'npu':
torch.npu.synchronize()
time_now = time.time() time_now = time.time()
update_time_m.update(time.time() - update_start_time) update_time_m.update(time.time() - update_start_time)
update_start_time = time_now update_start_time = time_now
@ -1155,6 +1146,8 @@ def validate(
if device.type == 'cuda': if device.type == 'cuda':
torch.cuda.synchronize() torch.cuda.synchronize()
elif device.type == "npu":
torch.npu.synchronize()
losses_m.update(reduced_loss.item(), input.size(0)) losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0)) top1_m.update(acc1.item(), output.size(0))

View File

@ -34,13 +34,6 @@ try:
except ImportError: except ImportError:
has_apex = False has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try: try:
from functorch.compile import memory_efficient_fusion from functorch.compile import memory_efficient_fusion
has_functorch = True has_functorch = True
@ -183,7 +176,6 @@ def validate(args):
use_amp = 'apex' use_amp = 'apex'
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.') _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else: else:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
assert args.amp_dtype in ('float16', 'bfloat16') assert args.amp_dtype in ('float16', 'bfloat16')
use_amp = 'native' use_amp = 'native'
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
@ -395,8 +387,10 @@ def _try_run(args, initial_batch_size):
while batch_size: while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try: try:
if torch.cuda.is_available() and 'cuda' in args.device: if 'cuda' in args.device and torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
elif "npu" in args.device and torch.npu.is_available():
torch.npu.empty_cache()
results = validate(args) results = validate(args)
return results return results
except RuntimeError as e: except RuntimeError as e: