diff --git a/benchmark.py b/benchmark.py
index 422b9a35..d31395b8 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -32,13 +32,6 @@ try:
 except ImportError:
     pass
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True
@@ -242,7 +235,7 @@ class BenchmarkRunner:
         self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         if self.amp_dtype is not None:
-            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+            self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
         else:
             self.amp_autocast = suppress
 
diff --git a/inference.py b/inference.py
index e6bd4ae1..60581978 100755
--- a/inference.py
+++ b/inference.py
@@ -28,13 +28,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -170,7 +163,6 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     amp_autocast = suppress
     if args.amp:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         assert args.amp_dtype in ('float16', 'bfloat16')
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py
index 8bcf1de1..3bbb0b4f 100644
--- a/timm/layers/fast_norm.py
+++ b/timm/layers/fast_norm.py
@@ -28,6 +28,30 @@ except ImportError:
 _USE_FAST_NORM = False  # defaulting to False for now
 
 
+def get_autocast_dtype(device: str = 'cuda'):
+    try:
+        return torch.get_autocast_dtype(device)
+    except (AttributeError, TypeError):
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.get_autocast_cpu_dtype()
+        else:
+            assert device == 'cuda'
+            return torch.get_autocast_gpu_dtype()
+
+
+def is_autocast_enabled(device: str = 'cuda'):
+    try:
+        return torch.is_autocast_enabled(device)
+    except TypeError:
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.is_autocast_cpu_enabled()
+        else:
+            assert device == 'cuda'
+            return torch.is_autocast_enabled()  # defaults cuda (only cuda on older pytorch)
+
+
 def is_fast_norm():
     return _USE_FAST_NORM
 
@@ -48,14 +72,14 @@ def fast_group_norm(
         # currently cannot use is_autocast_enabled within torchscript
         return F.group_norm(x, num_groups, weight, bias, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)
 
 
@@ -73,14 +97,14 @@ def fast_layer_norm(
     if has_apex:
         return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 
diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index de0b881c..a0a2770c 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -46,8 +46,11 @@ class ApexScaler:
 class NativeScaler:
     state_dict_key = "amp_scaler"
 
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
+    def __init__(self, device='cuda'):
+        try:
+            self._scaler = torch.amp.GradScaler(device=device)
+        except (AttributeError, TypeError) as e:
+            self._scaler = torch.cuda.amp.GradScaler()
 
     def __call__(
             self,
diff --git a/train.py b/train.py
index ebd9bc80..bbcb0cac 100755
--- a/train.py
+++ b/train.py
@@ -48,12 +48,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
 
 try:
     import wandb
@@ -442,7 +436,6 @@ def main():
         use_amp = 'apex'
         assert args.amp_dtype == 'float16'
     else:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         use_amp = 'native'
         assert args.amp_dtype in ('float16', 'bfloat16')
     if args.amp_dtype == 'bfloat16':
@@ -572,15 +565,10 @@
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        try:
-            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        except (AttributeError, TypeError):
-            # fallback to CUDA only AMP for PyTorch < 1.10
-            assert device.type == 'cuda'
-            amp_autocast = torch.cuda.amp.autocast
-        if device.type == 'cuda' and amp_dtype == torch.float16:
+        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
-            loss_scaler = NativeScaler()
+            loss_scaler = NativeScaler(device=device.type)
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
diff --git a/validate.py b/validate.py
index 6115de7a..d8a4283e 100755
--- a/validate.py
+++ b/validate.py
@@ -34,13 +34,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -183,7 +176,6 @@ def validate(args):
         use_amp = 'apex'
         _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
     else:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         assert args.amp_dtype in ('float16', 'bfloat16')
         use_amp = 'native'
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
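
The sketch below is not part of the patch; it is a minimal, hedged illustration of the call-site pattern the diff moves to: a device-agnostic torch.amp.autocast context plus a GradScaler created via the new torch.amp.GradScaler(device=...) signature with a fallback to the older torch.cuda.amp.GradScaler, mirroring the NativeScaler change above. The tiny linear model, optimizer, and random data are placeholder assumptions, not timm code.

# illustrative sketch only: one fp16/bf16 training step using the device-agnostic AMP API
from functools import partial

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
amp_dtype = torch.float16 if device == 'cuda' else torch.bfloat16

# device-agnostic autocast context, as used in benchmark.py / train.py above
amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=amp_dtype)

# new GradScaler signature first, older CUDA-only class as fallback (cf. NativeScaler)
try:
    scaler = torch.amp.GradScaler(device=device, enabled=(amp_dtype == torch.float16))
except (AttributeError, TypeError):
    scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))

# placeholder model and data for the demonstration
model = nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device=device)
target = torch.randint(0, 2, (4,), device=device)

optimizer.zero_grad()
with amp_autocast():
    loss = nn.functional.cross_entropy(model(x), target)
scaler.scale(loss).backward()  # loss scaling only matters for float16; no-op when disabled
scaler.step(optimizer)
scaler.update()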