diff --git a/timm/data/loader.py b/timm/data/loader.py
index 3b4a6d0e..31344057 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -77,18 +77,18 @@ class PrefetchLoader:
 
     def __init__(
             self,
-            loader,
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD,
-            channels=3,
-            device=torch.device('cuda'),
-            img_dtype=torch.float32,
-            fp16=False,
-            re_prob=0.,
-            re_mode='const',
-            re_count=1,
-            re_num_splits=0):
-
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            fp16: bool = False,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
         normalization_shape = (1, channels, 1, 1)
@@ -98,7 +98,7 @@ class PrefetchLoader:
         if fp16:
             # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
             img_dtype = torch.float16
-        self.img_dtype = img_dtype
+        self.img_dtype = img_dtype or torch.float32
         self.mean = torch.tensor(
             [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
         self.std = torch.tensor(
diff --git a/train.py b/train.py
index a9e63f45..7d8e9898 100755
--- a/train.py
+++ b/train.py
@@ -178,6 +178,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -434,10 +436,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
@@ -517,7 +527,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
 
@@ -587,7 +597,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
 
     # optionally resume from a checkpoint
     resume_epoch = None
@@ -732,6 +742,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -756,6 +767,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
@@ -823,9 +835,13 @@ def main():
     if utils.is_primary(args) and args.log_wandb:
         if has_wandb:
             assert not args.wandb_resume_id or args.resume
-            wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
-                       resume='must' if args.wandb_resume_id else None,
-                       id=args.wandb_resume_id if args.wandb_resume_id else None)
+            wandb.init(
+                project=args.experiment,
+                config=args,
+                tags=args.wandb_tags,
+                resume='must' if args.wandb_resume_id else None,
+                id=args.wandb_resume_id if args.wandb_resume_id else None,
+            )
         else:
             _logger.warning(
                 "You've requested to log metrics to wandb but package not found. "
@@ -879,6 +895,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -897,6 +914,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )
 
             if model_ema is not None and not args.model_ema_force_cpu:
@@ -979,6 +997,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1015,7 +1034,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps
 
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -1142,6 +1161,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1157,8 +1177,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
 
diff --git a/validate.py b/validate.py
index 159bd0b1..37fefaa6 100755
--- a/validate.py
+++ b/validate.py
@@ -123,6 +123,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--amp-impl', default='native', type=str,
                     help='AMP impl to use, "native" or "apex" (default: native)')
+parser.add_argument('--model-dtype', default=None, type=str,
+                    help='Model dtype override (non-AMP) (default: float32)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -168,10 +170,16 @@ def validate(args):
 
     device = torch.device(args.device)
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_autocast = suppress
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             assert args.amp_dtype == 'float16'
@@ -184,7 +192,7 @@ def validate(args):
             amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
             _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info('Validating in float32. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')
 
     if args.fuser:
         set_jit_fuser(args.fuser)
@@ -231,7 +239,7 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)
 
-    model = model.to(device)
+    model = model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model = model.to(memory_format=torch.channels_last)
 
@@ -299,6 +307,7 @@ def validate(args):
         crop_border_pixels=args.crop_border_pixels,
         pin_memory=args.pin_mem,
         device=device,
+        img_dtype=model_dtype,
         tf_preprocessing=args.tf_preprocessing,
     )
 
@@ -310,7 +319,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
        if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
@@ -319,8 +328,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.to(device)
-                input = input.to(device)
+                target = target.to(device=device)
+                input = input.to(device=device, dtype=model_dtype)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
 
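
Usage note (not part of the diff): from the CLI, the new flag is simply passed to the scripts, e.g. `python validate.py --data-dir /path/to/imagenet --model resnet50 --model-dtype bfloat16` (the dataset path and model name here are placeholders). The sketch below shows the same dtype plumbing used programmatically, assuming a CUDA device, a local ImageNet-style folder, and the loader/model APIs as they exist in timm; it is an illustration, not part of the change set.

# Minimal sketch: pair a non-AMP model dtype with the loader's img_dtype,
# mirroring what train.py/validate.py now do with --model-dtype.
# Assumptions: CUDA available; '/path/to/imagenet' and 'resnet50' are placeholders.
import torch
import timm
from timm.data import create_dataset, create_loader, resolve_data_config

model_dtype = torch.bfloat16  # equivalent of --model-dtype bfloat16

model = timm.create_model('resnet50', pretrained=True)
model = model.to(device='cuda', dtype=model_dtype)  # same device + dtype move as the scripts
model.eval()

data_config = resolve_data_config({}, model=model)
dataset = create_dataset('', root='/path/to/imagenet', split='validation')
loader = create_loader(
    dataset,
    input_size=data_config['input_size'],
    batch_size=64,
    use_prefetcher=True,
    interpolation=data_config['interpolation'],
    mean=data_config['mean'],
    std=data_config['std'],
    crop_pct=data_config['crop_pct'],
    img_dtype=model_dtype,       # PrefetchLoader now yields inputs in the model dtype
    device=torch.device('cuda'),
)

with torch.no_grad():
    for input, target in loader:
        output = model(input)    # no autocast needed; weights and inputs share a dtype
        break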