diff --git a/benchmark.py b/benchmark.py
index 422b9a35..d31395b8 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -32,13 +32,6 @@ try:
 except ImportError:
     pass
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True
@@ -242,7 +235,7 @@ class BenchmarkRunner:
         self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         if self.amp_dtype is not None:
-            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+            self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
         else:
             self.amp_autocast = suppress
 
diff --git a/inference.py b/inference.py
index e6bd4ae1..60581978 100755
--- a/inference.py
+++ b/inference.py
@@ -28,13 +28,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -170,7 +163,6 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     amp_autocast = suppress
     if args.amp:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         assert args.amp_dtype in ('float16', 'bfloat16')
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
diff --git a/timm/data/loader.py b/timm/data/loader.py
index ff61ad56..3b4a6d0e 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -113,13 +113,17 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ class PrefetchLoader:
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target
diff --git a/timm/layers/fast_norm.py b/timm/layers/fast_norm.py
index 8bcf1de1..3bbb0b4f 100644
--- a/timm/layers/fast_norm.py
+++ b/timm/layers/fast_norm.py
@@ -28,6 +28,30 @@ except ImportError:
 _USE_FAST_NORM = False  # defaulting to False for now
 
 
+def get_autocast_dtype(device: str = 'cuda'):
+    try:
+        return torch.get_autocast_dtype(device)
+    except (AttributeError, TypeError):
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.get_autocast_cpu_dtype()
+        else:
+            assert device == 'cuda'
+            return torch.get_autocast_gpu_dtype()
+
+
+def is_autocast_enabled(device: str = 'cuda'):
+    try:
+        return torch.is_autocast_enabled(device)
+    except TypeError:
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.is_autocast_cpu_enabled()
+        else:
+            assert device == 'cuda'
+            return torch.is_autocast_enabled()  # defaults cuda (only cuda on older pytorch)
+
+
 def is_fast_norm():
     return _USE_FAST_NORM
 
@@ -48,14 +72,14 @@ def fast_group_norm(
         # currently cannot use is_autocast_enabled within torchscript
         return F.group_norm(x, num_groups, weight, bias, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)
 
 
@@ -73,14 +97,14 @@ def fast_layer_norm(
     if has_apex:
         return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 
diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index de0b881c..a0a2770c 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -46,8 +46,11 @@ class ApexScaler:
 class NativeScaler:
     state_dict_key = "amp_scaler"
 
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
+    def __init__(self, device='cuda'):
+        try:
+            self._scaler = torch.amp.GradScaler(device=device)
+        except (AttributeError, TypeError) as e:
+            self._scaler = torch.cuda.amp.GradScaler()
 
     def __call__(
             self,
diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
index 18f526bb..cca2cdbb 100644
--- a/timm/utils/distributed.py
+++ b/timm/utils/distributed.py
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
diff --git a/train.py b/train.py
index ebd9bc80..a82fa0a8 100755
--- a/train.py
+++ b/train.py
@@ -48,12 +48,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
 
 try:
     import wandb
@@ -442,7 +436,6 @@ def main():
             use_amp = 'apex'
             assert args.amp_dtype == 'float16'
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             use_amp = 'native'
             assert args.amp_dtype in ('float16', 'bfloat16')
         if args.amp_dtype == 'bfloat16':
@@ -572,15 +565,10 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        try:
-            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        except (AttributeError, TypeError):
-            # fallback to CUDA only AMP for PyTorch < 1.10
-            assert device.type == 'cuda'
-            amp_autocast = torch.cuda.amp.autocast
-        if device.type == 'cuda' and amp_dtype == torch.float16:
+        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
-            loss_scaler = NativeScaler()
+            loss_scaler = NativeScaler(device=device.type)
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
@@ -1054,8 +1042,11 @@ def train_one_epoch(
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1155,6 +1146,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
diff --git a/validate.py b/validate.py
index 6115de7a..602111bb 100755
--- a/validate.py
+++ b/validate.py
@@ -34,13 +34,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -183,7 +176,6 @@ def validate(args):
         use_amp = 'apex'
         _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
     else:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
        assert args.amp_dtype in ('float16', 'bfloat16')
         use_amp = 'native'
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
@@ -395,8 +387,10 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif "npu" in args.device and torch.npu.is_available():
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
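
The hunks above all follow one pattern: replace the CUDA-only `torch.cuda.amp` entry points with the device-agnostic `torch.autocast` / `torch.amp` APIs, and fall back to the older calls on PyTorch versions that lack them. Below is a minimal standalone sketch of that pattern for reference; the helper name `make_amp_tools` is hypothetical and not part of this diff, which wires the same logic directly into `NativeScaler` and the train/validate scripts.

```python
from functools import partial

import torch


def make_amp_tools(device_type: str = 'cuda', amp_dtype: torch.dtype = torch.float16):
    """Hypothetical helper mirroring the fallback pattern used in this diff.

    Returns an autocast context factory and, for float16 only, a gradient
    scaler, preferring the device-agnostic torch.amp APIs and falling back
    to the CUDA-specific ones on older PyTorch releases.
    """
    # torch.autocast accepts device_type on PyTorch >= 1.10 (cuda, cpu, npu, ...)
    amp_autocast = partial(torch.autocast, device_type=device_type, dtype=amp_dtype)

    loss_scaler = None
    if amp_dtype == torch.float16:
        # loss scaling is only needed for float16; bfloat16 does not need it
        try:
            # torch.amp.GradScaler(device=...) is available on newer PyTorch (>= 2.3)
            loss_scaler = torch.amp.GradScaler(device=device_type)
        except (AttributeError, TypeError):
            # older releases only provide the CUDA-specific scaler
            loss_scaler = torch.cuda.amp.GradScaler()
    return amp_autocast, loss_scaler


# usage sketch
# amp_autocast, loss_scaler = make_amp_tools('cuda', torch.float16)
# with amp_autocast():
#     output = model(input)
```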