mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Change reduce_bn to distribute_bn, add ability to choose between broadcast and reduce (mean). Add crop_pct arg to allow selecting validation crop while training.
This commit is contained in:
parent
3bff2b21dc
commit
a435ea1327
@ -210,12 +210,17 @@ def reduce_tensor(tensor, n):
|
|||||||
return rt
|
return rt
|
||||||
|
|
||||||
|
|
||||||
def reduce_bn(model, world_size):
|
def distribute_bn(model, world_size, reduce=False):
|
||||||
# ensure every node has the same running bn stats
|
# ensure every node has the same running bn stats
|
||||||
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
|
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
|
||||||
if ('running_mean' in bn_name) or ('running_var' in bn_name):
|
if ('running_mean' in bn_name) or ('running_var' in bn_name):
|
||||||
torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
|
if reduce:
|
||||||
bn_buf /= float(world_size)
|
# average bn stats across whole group
|
||||||
|
torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
|
||||||
|
bn_buf /= float(world_size)
|
||||||
|
else:
|
||||||
|
# broadcast bn stats from rank 0 to whole group
|
||||||
|
torch.distributed.broadcast(bn_buf, 0)
|
||||||
|
|
||||||
|
|
||||||
class ModelEma:
|
class ModelEma:
|
||||||
|
19
train.py
19
train.py
@ -55,6 +55,8 @@ parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
|
|||||||
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
|
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
|
||||||
parser.add_argument('--img-size', type=int, default=None, metavar='N',
|
parser.add_argument('--img-size', type=int, default=None, metavar='N',
|
||||||
help='Image patch size (default: None => model default)')
|
help='Image patch size (default: None => model default)')
|
||||||
|
parser.add_argument('--crop-pct', default=None, type=float,
|
||||||
|
metavar='N', help='Input image center crop percent (for validation only)')
|
||||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||||
help='Override mean pixel value of dataset')
|
help='Override mean pixel value of dataset')
|
||||||
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
||||||
@ -121,6 +123,10 @@ parser.add_argument('--bn-momentum', type=float, default=None,
|
|||||||
help='BatchNorm momentum override (if not None)')
|
help='BatchNorm momentum override (if not None)')
|
||||||
parser.add_argument('--bn-eps', type=float, default=None,
|
parser.add_argument('--bn-eps', type=float, default=None,
|
||||||
help='BatchNorm epsilon override (if not None)')
|
help='BatchNorm epsilon override (if not None)')
|
||||||
|
parser.add_argument('--sync-bn', action='store_true',
|
||||||
|
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
|
||||||
|
parser.add_argument('--dist-bn', type=str, default='',
|
||||||
|
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
|
||||||
# Model Exponential Moving Average
|
# Model Exponential Moving Average
|
||||||
parser.add_argument('--model-ema', action='store_true', default=False,
|
parser.add_argument('--model-ema', action='store_true', default=False,
|
||||||
help='Enable tracking moving average of model weights')
|
help='Enable tracking moving average of model weights')
|
||||||
@ -143,10 +149,6 @@ parser.add_argument('--save-images', action='store_true', default=False,
|
|||||||
help='save images of input bathes every log interval for debugging')
|
help='save images of input bathes every log interval for debugging')
|
||||||
parser.add_argument('--amp', action='store_true', default=False,
|
parser.add_argument('--amp', action='store_true', default=False,
|
||||||
help='use NVIDIA amp for mixed precision training')
|
help='use NVIDIA amp for mixed precision training')
|
||||||
parser.add_argument('--sync-bn', action='store_true',
|
|
||||||
help='enabling apex sync BN.')
|
|
||||||
parser.add_argument('--reduce-bn', action='store_true',
|
|
||||||
help='average BN running stats across all distributed nodes between train and validation.')
|
|
||||||
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
||||||
help='disable fast prefetcher')
|
help='disable fast prefetcher')
|
||||||
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
||||||
@ -349,6 +351,7 @@ def main():
|
|||||||
std=data_config['std'],
|
std=data_config['std'],
|
||||||
num_workers=args.workers,
|
num_workers=args.workers,
|
||||||
distributed=args.distributed,
|
distributed=args.distributed,
|
||||||
|
crop_pct=data_config['crop_pct'],
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.mixup > 0.:
|
if args.mixup > 0.:
|
||||||
@ -390,16 +393,16 @@ def main():
|
|||||||
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
||||||
use_amp=use_amp, model_ema=model_ema)
|
use_amp=use_amp, model_ema=model_ema)
|
||||||
|
|
||||||
if args.distributed and args.reduce_bn:
|
if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info("Averaging bn running means and vars")
|
logging.info("Distributing BatchNorm running means and vars")
|
||||||
reduce_bn(model, args.world_size)
|
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
||||||
|
|
||||||
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
||||||
|
|
||||||
if model_ema is not None and not args.model_ema_force_cpu:
|
if model_ema is not None and not args.model_ema_force_cpu:
|
||||||
if args.distributed and args.reduce_bn:
|
if args.distributed and args.reduce_bn:
|
||||||
reduce_bn(model_ema, args.world_size)
|
distribute_bn(model_ema, args.world_size)
|
||||||
|
|
||||||
ema_eval_metrics = validate(
|
ema_eval_metrics = validate(
|
||||||
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
|
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user