Clean up some AMP-related behaviour to better support different (non-CUDA) devices

Ross Wightman 2024-10-18 13:54:16 -07:00
parent a852318b63
commit 1766a01f96
6 changed files with 39 additions and 47 deletions
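Not part of the commit, but for orientation: the recurring change below is replacing the CUDA-only torch.cuda.amp.autocast with the device-agnostic torch.amp.autocast(device_type=...). A minimal sketch of the pattern on a recent PyTorch (the helper, model, and input names here are illustrative, not repo code):

from contextlib import suppress
from functools import partial

import torch


def make_autocast(device: str = 'cuda', amp_dtype=torch.bfloat16):
    # no AMP requested -> return a no-op context manager factory
    if amp_dtype is None:
        return suppress
    # device-agnostic autocast factory; works for 'cuda', 'cpu', etc.
    return partial(torch.amp.autocast, device_type=device, dtype=amp_dtype)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
amp_autocast = make_autocast(device)
model = torch.nn.Linear(8, 4).to(device)
with amp_autocast():
    out = model(torch.randn(2, 8, device=device))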

View File

@@ -32,13 +32,6 @@ try:
 except ImportError:
     pass
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True
@@ -242,7 +235,7 @@ class BenchmarkRunner:
         self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         if self.amp_dtype is not None:
-            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+            self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
         else:
             self.amp_autocast = suppress

View File

@@ -28,13 +28,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -170,7 +163,6 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     amp_autocast = suppress
     if args.amp:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         assert args.amp_dtype in ('float16', 'bfloat16')
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)

View File

@@ -28,6 +28,30 @@ except ImportError:
     _USE_FAST_NORM = False  # defaulting to False for now
 
 
+def get_autocast_dtype(device: str = 'cuda'):
+    try:
+        return torch.get_autocast_dtype(device)
+    except (AttributeError, TypeError):
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.get_autocast_cpu_dtype()
+        else:
+            assert device == 'cuda'
+            return torch.get_autocast_gpu_dtype()
+
+
+def is_autocast_enabled(device: str = 'cuda'):
+    try:
+        return torch.is_autocast_enabled(device)
+    except TypeError:
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.is_autocast_cpu_enabled()
+        else:
+            assert device == 'cuda'
+            return torch.is_autocast_enabled()  # defaults cuda (only cuda on older pytorch)
+
+
 def is_fast_norm():
     return _USE_FAST_NORM
@@ -48,14 +72,14 @@ def fast_group_norm(
         # currently cannot use is_autocast_enabled within torchscript
         return F.group_norm(x, num_groups, weight, bias, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)
@@ -73,14 +97,14 @@ def fast_layer_norm(
     if has_apex:
         return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
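The two helpers added above bridge the old and new autocast query APIs. Purely as an illustration (not from the diff), this is roughly how the device-keyed queries behave on a PyTorch recent enough to accept a device argument:

import torch

# assumes a PyTorch new enough that the autocast queries take a device_type string
with torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16):
    print(torch.is_autocast_enabled('cpu'))   # expected: True inside the region
    print(torch.get_autocast_dtype('cpu'))    # expected: torch.bfloat16
    # the fast_norm functions re-enter with enabled=False to run the norm themselves
    with torch.amp.autocast(device_type='cpu', enabled=False):
        print(torch.is_autocast_enabled('cpu'))  # expected: False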

View File

@@ -46,8 +46,11 @@ class ApexScaler:
 class NativeScaler:
     state_dict_key = "amp_scaler"
 
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
+    def __init__(self, device='cuda'):
+        try:
+            self._scaler = torch.amp.GradScaler(device=device)
+        except (AttributeError, TypeError) as e:
+            self._scaler = torch.cuda.amp.GradScaler()
 
     def __call__(
             self,
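For context on the NativeScaler change, a standalone fp16-style training step using the same try/except construction; everything below is an illustrative sketch, not code from the repo:

import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

try:
    scaler = torch.amp.GradScaler(device=device)  # newer, device-aware constructor
except (AttributeError, TypeError):
    scaler = torch.cuda.amp.GradScaler()  # older CUDA-only fallback

x = torch.randn(8, 16, device=device)
target = torch.randn(8, 4, device=device)

amp_dtype = torch.float16 if device == 'cuda' else torch.bfloat16  # bf16 is the safer CPU choice
with torch.amp.autocast(device_type=device, dtype=amp_dtype):
    loss = F.mse_loss(model(x), target)

# a scaler is only strictly needed for float16, but the calls below are harmless otherwise
scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
scaler.update()
optimizer.zero_grad()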

View File

@@ -48,12 +48,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
 
 try:
     import wandb
@@ -442,7 +436,6 @@ def main():
             use_amp = 'apex'
             assert args.amp_dtype == 'float16'
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             use_amp = 'native'
             assert args.amp_dtype in ('float16', 'bfloat16')
         if args.amp_dtype == 'bfloat16':
@@ -572,15 +565,10 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        try:
-            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        except (AttributeError, TypeError):
-            # fallback to CUDA only AMP for PyTorch < 1.10
-            assert device.type == 'cuda'
-            amp_autocast = torch.cuda.amp.autocast
-        if device.type == 'cuda' and amp_dtype == torch.float16:
+        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
-            loss_scaler = NativeScaler()
+            loss_scaler = NativeScaler(device=device.type)
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
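The "bfloat16 does not need it" comment is a dtype-range fact rather than anything timm-specific. A quick check (not part of the diff) of why float16 training wants a loss scaler while bfloat16 does not:

import torch

# a gradient magnitude that underflows in float16 but survives in bfloat16
g = torch.tensor(1e-8)
print(g.to(torch.float16))   # prints 0.0: fp16 cannot represent it, hence loss scaling
print(g.to(torch.bfloat16))  # prints ~1e-08: bf16 keeps fp32's exponent range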

View File

@@ -34,13 +34,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -183,7 +176,6 @@ def validate(args):
             use_amp = 'apex'
             _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             assert args.amp_dtype in ('float16', 'bfloat16')
             use_amp = 'native'
             amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16