Merge branch 'main' into caojiaolong-main
commit c173886e75

@@ -77,18 +77,18 @@ class PrefetchLoader:
 
     def __init__(
             self,
-            loader,
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD,
-            channels=3,
-            device=torch.device('cuda'),
-            img_dtype=torch.float32,
-            fp16=False,
-            re_prob=0.,
-            re_mode='const',
-            re_count=1,
-            re_num_splits=0):
+            loader: torch.utils.data.DataLoader,
+            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
+            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
+            channels: int = 3,
+            device: torch.device = torch.device('cuda'),
+            img_dtype: Optional[torch.dtype] = None,
+            fp16: bool = False,
+            re_prob: float = 0.,
+            re_mode: str = 'const',
+            re_count: int = 1,
+            re_num_splits: int = 0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
         normalization_shape = (1, channels, 1, 1)
@@ -98,7 +98,7 @@ class PrefetchLoader:
         if fp16:
             # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
             img_dtype = torch.float16
-        self.img_dtype = img_dtype
+        self.img_dtype = img_dtype or torch.float32
         self.mean = torch.tensor(
             [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
         self.std = torch.tensor(
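For readers skimming the diff, a minimal standalone sketch of the dtype precedence the reworked __init__ now encodes (an illustration, not timm's class): the deprecated fp16 flag still wins, an explicit img_dtype comes next, and float32 is the fallback now that the argument defaults to None.

from typing import Optional

import torch


def resolve_img_dtype(img_dtype: Optional[torch.dtype] = None, fp16: bool = False) -> torch.dtype:
    # fp16 is deprecated but overrides img_dtype for backward compatibility,
    # mirroring the branch in PrefetchLoader.__init__ above
    if fp16:
        img_dtype = torch.float16
    # img_dtype now defaults to None, so fall back to float32
    return img_dtype or torch.float32


assert resolve_img_dtype() is torch.float32
assert resolve_img_dtype(torch.bfloat16) is torch.bfloat16
assert resolve_img_dtype(torch.bfloat16, fp16=True) is torch.float16
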
train.py (56 changes)
@@ -178,6 +178,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -436,10 +438,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
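To make the new flag concrete, a small hedged sketch of the same resolution step in isolation (only the --model-dtype name comes from the diff; everything else is illustrative): the dtype string from the CLI is validated and mapped to the matching torch dtype object with getattr, and stays None when the flag is not given, which keeps the AMP assert above satisfied.

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument('--model-dtype', default=None, type=str)
args = parser.parse_args(['--model-dtype', 'bfloat16'])

model_dtype = None
if args.model_dtype:
    # validate the string, then map it to the torch dtype object of the same name
    assert args.model_dtype in ('float32', 'float16', 'bfloat16')
    model_dtype = getattr(torch, args.model_dtype)

assert model_dtype is torch.bfloat16
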
@@ -519,7 +529,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
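Because model_dtype stays None unless --model-dtype is set, the changed call is behaviour-preserving by default: Module.to() with dtype=None only moves the model and leaves parameter dtypes alone. A quick sketch, using CPU purely so it runs anywhere:

import torch
import torch.nn as nn

model = nn.Linear(8, 4)

# dtype=None leaves parameters in their current dtype, so the default float32 path is unchanged
model.to(device=torch.device('cpu'), dtype=None)
assert model.weight.dtype is torch.float32

# with --model-dtype bfloat16 the same call casts every parameter and buffer
model.to(device=torch.device('cpu'), dtype=torch.bfloat16)
assert model.weight.dtype is torch.bfloat16
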
@@ -589,7 +599,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
 
     # optionally resume from a checkpoint
     resume_epoch = None
@@ -734,6 +744,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -758,6 +769,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
     )
@@ -822,21 +834,21 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 
-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            assert not args.wandb_resume_id or args.resume
-            wandb.init(
-                project=args.wandb_project,
-                name=args.experiment,
-                config=args,
-                tags=args.wandb_tags,
-                resume="must" if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None,
-            )
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
+        if args.log_wandb:
+            if has_wandb:
+                assert not args.wandb_resume_id or args.resume
+                wandb.init(
+                    project=args.wandb_project,
+                    name=exp_name,
+                    config=args,
+                    tags=args.wandb_tags,
+                    resume="must" if args.wandb_resume_id else None,
+                    id=args.wandb_resume_id if args.wandb_resume_id else None,
+                )
+            else:
+                _logger.warning(
+                    "You've requested to log metrics to wandb but package not found. "
+                    "Metrics not being logged to wandb, try `pip install wandb`")
 
     # setup learning rate schedule and starting epoch
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
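The resume gating above is standard wandb usage: resume="must" together with an explicit id tells wandb to continue that exact run (or fail if it cannot), and it is only allowed here when training itself resumes from a checkpoint. A hedged sketch of just that gating, with placeholder values and the actual wandb.init call left commented out since it needs a configured wandb environment:

# placeholder values; in train.py these come from --wandb-resume-id, --resume, and exp_name
wandb_resume_id = 'abc123'                                 # hypothetical wandb run id
resume_checkpoint = 'output/train/some-exp/last.pth.tar'   # hypothetical --resume path

# resuming a wandb run only makes sense when training itself is resuming
assert not wandb_resume_id or resume_checkpoint

init_kwargs = dict(
    project='my-project',                          # stand-in for args.wandb_project
    name='my-experiment',                          # stand-in for exp_name
    resume="must" if wandb_resume_id else None,    # "must": continue the given run or error out
    id=wandb_resume_id if wandb_resume_id else None,
)
# wandb.init(**init_kwargs) would then append new metrics to the original run
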
@@ -886,6 +898,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -904,6 +917,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )
 
                 if model_ema is not None and not args.model_ema_force_cpu:
@@ -986,6 +1000,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1022,7 +1037,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps
 
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
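Worth noting in this hunk: only the image batch is cast to the model dtype; targets are moved to the device but keep their integer dtype, which the loss function expects. A small sketch with illustrative shapes and a CPU device:

import torch

device = torch.device('cpu')   # stand-in for the training device
model_dtype = torch.bfloat16   # what --model-dtype bfloat16 resolves to

input = torch.randn(2, 3, 224, 224)
target = torch.randint(0, 1000, (2,))

# images are cast to the model dtype; with model_dtype=None this is a pure device move
input = input.to(device=device, dtype=model_dtype)
# labels keep their integer dtype for the loss function
target = target.to(device=device)

assert input.dtype is torch.bfloat16 and target.dtype is torch.int64
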
@@ -1149,6 +1164,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1164,8 +1180,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
validate.py (19 changes)
@@ -123,6 +123,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--amp-impl', default='native', type=str,
                     help='AMP impl to use, "native" or "apex" (default: native)')
+parser.add_argument('--model-dtype', default=None, type=str,
+                    help='Model dtype override (non-AMP) (default: float32)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -168,10 +170,16 @@ def validate(args):
 
     device = torch.device(args.device)
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_autocast = suppress
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             assert args.amp_dtype == 'float16'
@@ -184,7 +192,7 @@ def validate(args):
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
         _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
-        _logger.info('Validating in float32. AMP not enabled.')
+        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')
 
     if args.fuser:
         set_jit_fuser(args.fuser)
@@ -231,7 +239,7 @@ def validate(args):
     if args.test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config)
 
-    model = model.to(device)
+    model = model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model = model.to(memory_format=torch.channels_last)
@@ -299,6 +307,7 @@ def validate(args):
         crop_border_pixels=args.crop_border_pixels,
         pin_memory=args.pin_mem,
         device=device,
+        img_dtype=model_dtype,
         tf_preprocessing=args.tf_preprocessing,
     )
@@ -310,7 +319,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
@@ -319,8 +328,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.to(device)
-                input = input.to(device)
+                target = target.to(device=device)
+                input = input.to(device=device, dtype=model_dtype)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
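Taken together, the validate.py changes let the whole evaluation path run in a reduced-precision dtype without AMP. A rough sketch of that mode (model name, shapes, and CPU device are illustrative; bfloat16 inference also needs a backend that supports it):

import timm
import torch

device = torch.device('cpu')
model_dtype = torch.bfloat16   # what --model-dtype bfloat16 resolves to

model = timm.create_model('resnet18', pretrained=False, num_classes=1000)
model = model.to(device=device, dtype=model_dtype)
model.eval()

with torch.no_grad():
    # mirror the loader/warmup path: inputs are created or cast in the model dtype
    input = torch.randn(2, 3, 224, 224).to(device=device, dtype=model_dtype)
    output = model(input)

assert output.dtype is torch.bfloat16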