Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Merge pull request #2308 from huggingface/device_amp_cleanup
Cleanup some amp related behaviour to better support different (non-cuda) devices
This commit is contained in: commit 5081b53e48
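In short, the PR swaps the CUDA-only torch.cuda.amp entry points for the device-typed torch.amp / torch.autocast APIs and threads a device type through loaders, gradient scalers, and synchronization points. A minimal, hedged sketch of the autocast side of that pattern (the helper name and arguments here are illustrative, not timm code):

```python
# Minimal sketch of the device-agnostic autocast setup this PR moves toward.
# make_autocast and its defaults are illustrative; only the torch.autocast /
# torch.amp.autocast call pattern comes from the diff below.
from contextlib import suppress
from functools import partial

import torch


def make_autocast(device_type: str = 'cuda', amp_dtype=torch.float16, enabled: bool = True):
    """Return a context-manager factory for autocast on the given device, or a no-op."""
    if not enabled:
        return suppress  # AMP disabled: plain pass-through context
    # torch.autocast accepts an explicit device_type ('cuda', 'cpu', ...)
    return partial(torch.autocast, device_type=device_type, dtype=amp_dtype)


# usage: call the returned factory once per forward pass
amp_autocast = make_autocast('cuda' if torch.cuda.is_available() else 'cpu', torch.bfloat16)
# with amp_autocast():
#     output = model(input)
```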
@@ -32,13 +32,6 @@ try:
 except ImportError:
     pass
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True
@@ -242,7 +235,7 @@ class BenchmarkRunner:
         self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         if self.amp_dtype is not None:
-            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+            self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
         else:
             self.amp_autocast = suppress
 
@@ -28,13 +28,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -170,7 +163,6 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     amp_autocast = suppress
     if args.amp:
-        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
         assert args.amp_dtype in ('float16', 'bfloat16')
         amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
@@ -113,13 +113,17 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ class PrefetchLoader:
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target
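The PrefetchLoader changes above generalise the CUDA-only stream prefetching to NPU. A hedged, standalone sketch of the same device dispatch (not the timm class itself; torch.npu only exists when the Ascend torch_npu extension is installed, so it is probed defensively here):

```python
from contextlib import suppress
from functools import partial

import torch


def make_stream_context(device: torch.device):
    """Pick a side stream and a stream context manager for the given device, or a no-op."""
    if device.type == 'cuda' and torch.cuda.is_available():
        stream = torch.cuda.Stream()
        return stream, partial(torch.cuda.stream, stream=stream)
    if device.type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available():
        stream = torch.npu.Stream()
        return stream, partial(torch.npu.stream, stream=stream)
    return None, suppress  # CPU / unknown device: no side stream, plain context


# usage inside a prefetching loop (copies overlap with compute on the side stream)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
stream, stream_context = make_stream_context(device)
# with stream_context():
#     next_input = next_input.to(device, non_blocking=True)
```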
@@ -28,6 +28,30 @@ except ImportError:
 _USE_FAST_NORM = False  # defaulting to False for now
 
 
+def get_autocast_dtype(device: str = 'cuda'):
+    try:
+        return torch.get_autocast_dtype(device)
+    except (AttributeError, TypeError):
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.get_autocast_cpu_dtype()
+        else:
+            assert device == 'cuda'
+            return torch.get_autocast_gpu_dtype()
+
+
+def is_autocast_enabled(device: str = 'cuda'):
+    try:
+        return torch.is_autocast_enabled(device)
+    except TypeError:
+        # dispatch to older device specific fns, only covering cuda/cpu devices here
+        if device == 'cpu':
+            return torch.is_autocast_cpu_enabled()
+        else:
+            assert device == 'cuda'
+            return torch.is_autocast_enabled()  # defaults cuda (only cuda on older pytorch)
+
+
 def is_fast_norm():
     return _USE_FAST_NORM
 
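The two helpers added above wrap the newer device-aware torch.get_autocast_dtype / torch.is_autocast_enabled signatures and fall back to the older cuda/cpu-specific functions. A hedged usage sketch (the import path assumes the helpers live in timm's fast_norm module; adjust if they land elsewhere):

```python
import torch

# assumed import location for the helpers introduced in this diff
from timm.layers.fast_norm import get_autocast_dtype, is_autocast_enabled

device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(2, 8, device=device_type)

with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    if is_autocast_enabled(device_type):       # device-aware on new PyTorch, falls back on old
        dt = get_autocast_dtype(device_type)   # torch.bfloat16 inside this context
        x = x.to(dt)
```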
@@ -48,14 +72,14 @@ def fast_group_norm(
         # currently cannot use is_autocast_enabled within torchscript
         return F.group_norm(x, num_groups, weight, bias, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)
 
 
@@ -73,14 +97,14 @@ def fast_layer_norm(
     if has_apex:
         return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
 
-    if torch.is_autocast_enabled():
+    if is_autocast_enabled(x.device.type):
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
-        dt = torch.get_autocast_gpu_dtype()
+        dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
 
-    with torch.cuda.amp.autocast(enabled=False):
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 
@@ -46,8 +46,11 @@ class ApexScaler:
 class NativeScaler:
     state_dict_key = "amp_scaler"
 
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
+    def __init__(self, device='cuda'):
+        try:
+            self._scaler = torch.amp.GradScaler(device=device)
+        except (AttributeError, TypeError) as e:
+            self._scaler = torch.cuda.amp.GradScaler()
 
     def __call__(
             self,
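NativeScaler now prefers the device-aware torch.amp.GradScaler and only falls back to the CUDA-only class on older PyTorch. A hedged sketch of how such a scaler is constructed and then used in a training step (model, optimizer, and loss names are placeholders, not from the diff):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
    scaler = torch.amp.GradScaler(device=device)   # device-aware GradScaler on newer PyTorch
except (AttributeError, TypeError):
    scaler = torch.cuda.amp.GradScaler()           # older releases: CUDA-only scaler

# inside the training loop (float16 only; bfloat16 does not need loss scaling):
# with torch.autocast(device_type=device, dtype=torch.float16):
#     loss = loss_fn(model(inputs), targets)
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
```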
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
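The backend table and device assertion above add Ascend NPU (HCCL) alongside CUDA/NCCL. A hedged sketch of the same lookup wired into a standard torch.distributed init call (the mapping mirrors the hunk; the init wrapper itself is generic PyTorch usage, not copied from timm):

```python
import torch.distributed as dist

dist_backends = {
    "xpu": "ccl",
    "hpu": "hccl",
    "cuda": "nccl",
    "npu": "hccl",   # Ascend NPU also uses HCCL
}


def init_distributed(device_type: str, dist_url: str = 'env://', world_size: int = 1, rank: int = 0) -> str:
    """Initialise the default process group with a backend matched to the device type."""
    backend = dist_backends.get(device_type, 'gloo')
    dist.init_process_group(backend=backend, init_method=dist_url, world_size=world_size, rank=rank)
    return backend
```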
train.py (27 changed lines)
@@ -48,12 +48,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
 
 try:
     import wandb
@@ -442,7 +436,6 @@ def main():
             use_amp = 'apex'
             assert args.amp_dtype == 'float16'
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             use_amp = 'native'
             assert args.amp_dtype in ('float16', 'bfloat16')
         if args.amp_dtype == 'bfloat16':
@@ -572,15 +565,10 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        try:
-            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        except (AttributeError, TypeError):
-            # fallback to CUDA only AMP for PyTorch < 1.10
-            assert device.type == 'cuda'
-            amp_autocast = torch.cuda.amp.autocast
-        if device.type == 'cuda' and amp_dtype == torch.float16:
+        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
-            loss_scaler = NativeScaler()
+            loss_scaler = NativeScaler(device=device.type)
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
@@ -1054,8 +1042,11 @@ def train_one_epoch(
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
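--synchronize-step now syncs whichever accelerator is active rather than assuming CUDA. A hedged helper expressing the same dispatch (a no-op on devices without an explicit synchronize; torch.npu requires the torch_npu extension):

```python
import torch


def synchronize_device(device: torch.device) -> None:
    """Block until outstanding kernels on the given device have finished."""
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu' and hasattr(torch, 'npu'):
        torch.npu.synchronize()
    # cpu and other backends: nothing to do


# e.g. call synchronize_device(device) before reading per-step wall-clock timings
```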
@@ -1155,6 +1146,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
validate.py (12 changed lines)
@@ -34,13 +34,6 @@ try:
 except ImportError:
     has_apex = False
 
-has_native_amp = False
-try:
-    if getattr(torch.cuda.amp, 'autocast') is not None:
-        has_native_amp = True
-except AttributeError:
-    pass
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -183,7 +176,6 @@ def validate(args):
             use_amp = 'apex'
             _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
         else:
-            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
             assert args.amp_dtype in ('float16', 'bfloat16')
             use_amp = 'native'
             amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
@@ -395,8 +387,10 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif "npu" in args.device and torch.npu.is_available():
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
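The batch-size retry loop above now also clears the NPU cache before re-running validation. A hedged device-aware equivalent (args.device is assumed to be a string such as 'cuda', 'cuda:0', or 'npu'; torch.npu is probed defensively):

```python
import torch


def empty_device_cache(device_str: str) -> None:
    """Release cached allocator blocks on the active accelerator, if any."""
    if 'cuda' in device_str and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif 'npu' in device_str and hasattr(torch, 'npu') and torch.npu.is_available():
        torch.npu.empty_cache()
```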