Merge branch 'main' into caojiaolong-main

Ross Wightman 2025-01-08 09:11:50 -08:00
commit c173886e75
3 changed files with 63 additions and 38 deletions

View File

@@ -77,18 +77,18 @@ class PrefetchLoader:
     def __init__(
             self,
-            loader,
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD,
-            channels=3,
-            device=torch.device('cuda'),
-            img_dtype=torch.float32,
-            fp16=False,
-            re_prob=0.,
-            re_mode='const',
-            re_count=1,
-            re_num_splits=0):
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            fp16: bool = False,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
         normalization_shape = (1, channels, 1, 1)
@@ -98,7 +98,7 @@ class PrefetchLoader:
         if fp16:
             # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
             img_dtype = torch.float16
-        self.img_dtype = img_dtype
+        self.img_dtype = img_dtype or torch.float32
         self.mean = torch.tensor(
             [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
         self.std = torch.tensor(

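The loader change above makes img_dtype optional: None now resolves to float32, while the deprecated fp16 flag still wins for backward compatibility. A minimal sketch of that resolution order, using an illustrative helper (resolve_img_dtype is not part of timm):

import torch

def resolve_img_dtype(img_dtype=None, fp16=False):
    # deprecated fp16 flag overrides img_dtype for backward compat
    if fp16:
        img_dtype = torch.float16
    # None falls back to the old float32 default
    return img_dtype or torch.float32

assert resolve_img_dtype() == torch.float32              # new default (img_dtype=None)
assert resolve_img_dtype(torch.bfloat16) == torch.bfloat16  # explicit override wins
assert resolve_img_dtype(fp16=True) == torch.float16      # legacy flag still honored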
View File

@@ -178,6 +178,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -436,10 +438,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
@@ -519,7 +529,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
@@ -589,7 +599,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
     # optionally resume from a checkpoint
     resume_epoch = None
@@ -734,6 +744,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -758,6 +769,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
@@ -822,21 +834,21 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            assert not args.wandb_resume_id or args.resume
-            wandb.init(
-                project=args.wandb_project,
-                name=args.experiment,
-                config=args,
-                tags=args.wandb_tags,
-                resume="must" if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None,
-            )
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
+        if args.log_wandb:
+            if has_wandb:
+                assert not args.wandb_resume_id or args.resume
+                wandb.init(
+                    project=args.wandb_project,
+                    name=exp_name,
+                    config=args,
+                    tags=args.wandb_tags,
+                    resume="must" if args.wandb_resume_id else None,
+                    id=args.wandb_resume_id if args.wandb_resume_id else None,
+                )
+            else:
+                _logger.warning(
+                    "You've requested to log metrics to wandb but package not found. "
+                    "Metrics not being logged to wandb, try `pip install wandb`")
     # setup learning rate schedule and starting epoch
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
@@ -886,6 +898,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -904,6 +917,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )
                 if model_ema is not None and not args.model_ema_force_cpu:
@@ -986,6 +1000,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1022,7 +1037,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -1149,6 +1164,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1164,8 +1180,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)

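Taken together, the train.py hunks thread a single model_dtype through argument parsing, the AMP guard, model placement, and the non-prefetcher input path. A condensed, standalone sketch of that flow (a CPU device and a toy nn.Linear stand in for the real script's model and resolved device; this is not the script itself):

import argparse
import torch
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument('--model-dtype', default=None, type=str)
parser.add_argument('--amp', action='store_true', default=False)
args = parser.parse_args(['--model-dtype', 'bfloat16'])

# resolve the string flag to a torch.dtype, as the new block in main() does
model_dtype = None
if args.model_dtype:
    assert args.model_dtype in ('float32', 'float16', 'bfloat16')
    model_dtype = getattr(torch, args.model_dtype)

# AMP keeps float32 master weights, hence the new guard in both scripts
if args.amp:
    assert model_dtype is None or model_dtype == torch.float32

device = torch.device('cpu')  # the real script uses the resolved accelerator device
model = nn.Linear(8, 2).to(device=device, dtype=model_dtype)

# without the prefetcher, each batch is cast the same way;
# .to(dtype=None) is a no-op, so model_dtype=None preserves the old float32 behavior
input = torch.randn(4, 8).to(device=device, dtype=model_dtype)
print(model(input).dtype)  # torch.bfloat16 for the flag parsed above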
View File

@@ -123,6 +123,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--amp-impl', default='native', type=str,
                     help='AMP impl to use, "native" or "apex" (default: native)')
+parser.add_argument('--model-dtype', default=None, type=str,
+                    help='Model dtype override (non-AMP) (default: float32)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -168,10 +170,16 @@ def validate(args):
     device = torch.device(args.device)
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_autocast = suppress
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             assert args.amp_dtype == 'float16'
@@ -184,7 +192,7 @@ def validate(args):
             amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
             _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info('Validating in float32. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')
     if args.fuser:
         set_jit_fuser(args.fuser)
@@ -231,7 +239,7 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)
-    model = model.to(device)
+    model = model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model = model.to(memory_format=torch.channels_last)
@@ -299,6 +307,7 @@ def validate(args):
         crop_border_pixels=args.crop_border_pixels,
         pin_memory=args.pin_mem,
         device=device,
+        img_dtype=model_dtype,
         tf_preprocessing=args.tf_preprocessing,
     )
@@ -310,7 +319,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
@@ -319,8 +328,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.to(device)
-                input = input.to(device)
+                target = target.to(device=device)
+                input = input.to(device=device, dtype=model_dtype)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
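Both scripts also pass img_dtype=model_dtype into create_loader, so with the prefetcher enabled batches already arrive in the model dtype and the explicit .to(dtype=...) casts above only matter on the no-prefetcher path. A rough, hypothetical stand-in for what PrefetchLoader does per batch once img_dtype is set (prefetch_normalize is illustrative only; the real loader does this on the target device with non-blocking copies):

import torch

def prefetch_normalize(batch_u8, mean, std, img_dtype=torch.float32, device='cpu'):
    # cast uint8 images to the requested dtype, then normalize in that dtype,
    # mirroring the mean/std tensors PrefetchLoader builds in __init__
    shape = (1, batch_u8.shape[1], 1, 1)
    mean_t = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(shape)
    std_t = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(shape)
    return batch_u8.to(device=device, dtype=img_dtype).sub_(mean_t).div_(std_t)

imgs = torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8)
out = prefetch_normalize(imgs, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), img_dtype=torch.bfloat16)
print(out.dtype)  # torch.bfloat16: batches reach the model already in model_dtype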