Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Merge pull request #2397 from huggingface/half_prec_trainval
Add half-precision (bfloat16, float16) support to train & validate scripts
Commit 2d0ac6f567
@@ -77,18 +77,18 @@ class PrefetchLoader:
     def __init__(
             self,
-            loader,
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD,
-            channels=3,
-            device=torch.device('cuda'),
-            img_dtype=torch.float32,
-            fp16=False,
-            re_prob=0.,
-            re_mode='const',
-            re_count=1,
-            re_num_splits=0):
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            fp16: bool = False,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
         normalization_shape = (1, channels, 1, 1)

@@ -98,7 +98,7 @@ class PrefetchLoader:
         if fp16:
             # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
             img_dtype = torch.float16
-        self.img_dtype = img_dtype
+        self.img_dtype = img_dtype or torch.float32
         self.mean = torch.tensor(
             [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
         self.std = torch.tensor(
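For reference, the dtype precedence these new arguments encode can be summarized in a small standalone sketch (not part of the PR; the function name is illustrative): the deprecated fp16 flag wins, an explicit img_dtype is used as given, and None falls back to float32.

    import torch

    def resolve_img_dtype(img_dtype=None, fp16=False):
        # Mirrors the precedence in the hunk above: the deprecated fp16 flag
        # overrides img_dtype, and a None img_dtype falls back to float32.
        if fp16:
            img_dtype = torch.float16
        return img_dtype or torch.float32

    assert resolve_img_dtype() is torch.float32
    assert resolve_img_dtype(torch.bfloat16) is torch.bfloat16
    assert resolve_img_dtype(torch.bfloat16, fp16=True) is torch.float16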
train.py (36 lines changed)
@@ -178,6 +178,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -434,10 +436,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0

+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
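Pulled out of main() for readability, the new dtype resolution plus the AMP guard behaves roughly like this standalone sketch (function and variable names are illustrative, not part of the PR):

    import torch

    def resolve_model_dtype(model_dtype_arg, use_amp):
        # model_dtype_arg mirrors args.model_dtype; use_amp mirrors args.amp.
        model_dtype = None
        if model_dtype_arg:
            assert model_dtype_arg in ('float32', 'float16', 'bfloat16')
            model_dtype = getattr(torch, model_dtype_arg)
        if use_amp:
            # AMP autocasts from float32 weights, so a half-precision model dtype would conflict.
            assert model_dtype is None or model_dtype == torch.float32, \
                'float32 model dtype must be used with AMP'
        return model_dtype

    print(resolve_model_dtype('bfloat16', use_amp=False))  # torch.bfloat16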
@@ -517,7 +527,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))

     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)

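The unchanged default still trains in float32 because nn.Module.to() treats dtype=None as "leave the dtype alone". A quick standalone check (CPU-only, purely illustrative):

    import torch

    m = torch.nn.Linear(4, 4)
    m.to(device='cpu', dtype=None)      # dtype=None leaves parameters untouched
    assert m.weight.dtype is torch.float32

    m.to(dtype=torch.bfloat16)          # an explicit dtype casts all parameters and buffers
    assert m.weight.dtype is torch.bfloat16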
@@ -587,7 +597,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')

     # optionally resume from a checkpoint
     resume_epoch = None
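The torch dtype object renders directly inside the f-string, so the resulting log line looks like this (illustrative snippet, not part of the PR):

    import torch

    model_dtype = torch.bfloat16
    print(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
    # AMP not enabled. Training in torch.bfloat16.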
@@ -732,6 +742,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -756,6 +767,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
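Outside the script, the same plumbing can be exercised directly. A hedged sketch assuming the usual timm.data entry points and an ImageFolder-style dataset (the root path, input size, and batch size are placeholders):

    import torch
    from timm.data import create_dataset, create_loader

    dataset = create_dataset('', root='/path/to/imagenet', split='validation')
    loader = create_loader(
        dataset,
        input_size=(3, 224, 224),
        batch_size=64,
        use_prefetcher=True,
        device=torch.device('cuda'),
        img_dtype=torch.bfloat16,   # batches arrive on-device already in the model dtype
    )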
@@ -823,9 +835,13 @@ def main():
     if utils.is_primary(args) and args.log_wandb:
         if has_wandb:
             assert not args.wandb_resume_id or args.resume
-            wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
-                resume='must' if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None)
+            wandb.init(
+                project=args.experiment,
+                config=args,
+                tags=args.wandb_tags,
+                resume='must' if args.wandb_resume_id else None,
+                id=args.wandb_resume_id if args.wandb_resume_id else None,
+            )
         else:
             _logger.warning(
                 "You've requested to log metrics to wandb but package not found. "
@@ -879,6 +895,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -897,6 +914,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )

                 if model_ema is not None and not args.model_ema_force_cpu:
@@ -979,6 +997,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1015,7 +1034,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps

         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
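Note that only the image tensor picks up the model dtype; the target tensor keeps whatever dtype it already has (typically int64 class indices), so the loss still sees integer labels. A minimal illustration (not part of the PR):

    import torch

    model_dtype = torch.bfloat16
    input = torch.randn(2, 3, 224, 224).to(device='cpu', dtype=model_dtype)
    target = torch.tensor([3, 7]).to(device='cpu')
    print(input.dtype, target.dtype)  # torch.bfloat16 torch.int64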
@@ -1142,6 +1161,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1157,8 +1177,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)

validate.py (19 lines changed)
@@ -123,6 +123,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--amp-impl', default='native', type=str,
                     help='AMP impl to use, "native" or "apex" (default: native)')
+parser.add_argument('--model-dtype', default=None, type=str,
+                    help='Model dtype override (non-AMP) (default: float32)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -168,10 +170,16 @@ def validate(args):

     device = torch.device(args.device)

+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_autocast = suppress
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             assert args.amp_dtype == 'float16'
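The two precision paths the script now distinguishes are mutually exclusive: --amp keeps float32 weights and relies on autocast, while --model-dtype casts the whole model and its inputs once. A rough CPU-runnable sketch of the difference (model and shapes are illustrative, not the script itself):

    import torch
    from contextlib import suppress
    from functools import partial

    device = torch.device('cpu')

    # --amp path: float32 weights, autocast handles per-op casting inside the context.
    amp_autocast = partial(torch.autocast, device_type=device.type, dtype=torch.bfloat16)
    fp32_model = torch.nn.Linear(8, 8).to(device)
    with amp_autocast():
        y_amp = fp32_model(torch.randn(2, 8, device=device))

    # --model-dtype path: weights and inputs share one dtype, no autocast needed
    # (contextlib.suppress() is the script's no-op stand-in for the autocast context).
    bf16_model = torch.nn.Linear(8, 8).to(device=device, dtype=torch.bfloat16)
    with suppress():
        y_bf16 = bf16_model(torch.randn(2, 8, device=device, dtype=torch.bfloat16))

    print(y_amp.dtype, y_bf16.dtype)  # torch.bfloat16 torch.bfloat16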
@@ -184,7 +192,7 @@ def validate(args):
             amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
             _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info('Validating in float32. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')

     if args.fuser:
         set_jit_fuser(args.fuser)
@@ -231,7 +239,7 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)

-    model = model.to(device)
+    model = model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model = model.to(memory_format=torch.channels_last)

@@ -299,6 +307,7 @@ def validate(args):
         crop_border_pixels=args.crop_border_pixels,
         pin_memory=args.pin_mem,
         device=device,
+        img_dtype=model_dtype,
         tf_preprocessing=args.tf_preprocessing,
     )

@@ -310,7 +319,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
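As a side note on the warmup batch: it is still allocated in float32 and cast afterwards; allocating directly in the target dtype gives an equivalent warmup tensor without the intermediate copy (purely illustrative, not a change made in the PR):

    import torch

    a = torch.randn((8, 3, 224, 224)).to(device='cpu', dtype=torch.bfloat16)
    b = torch.randn((8, 3, 224, 224), device='cpu', dtype=torch.bfloat16)
    assert a.dtype == b.dtype == torch.bfloat16
    assert a.shape == b.shape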
@@ -319,8 +328,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.to(device)
-                input = input.to(device)
+                target = target.to(device=device)
+                input = input.to(device=device, dtype=model_dtype)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
