Add loss scale arg, initial distributed loss scale. Maybe fix FX for the model.

pull/2466/head
Ross Wightman 2025-04-08 20:47:57 -07:00
parent 6675590264
commit 13e0f3a4a3
3 changed files with 58 additions and 23 deletions


@@ -8,6 +8,8 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
+ from .naflex_dataset import VariableSeqMapWrapper
+ from .naflex_loader import create_naflex_loader
from .naflex_transforms import (
ResizeToSequence,
CenterCropToSequence,
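
The two added lines re-export the NaFlex dataset wrapper and loader factory from the data package root; this is what lets the training script further down drop its inline `from timm.data.naflex_loader import create_naflex_loader` imports. A minimal sketch of the resulting import path (illustrative, not part of the commit):

# Both symbols now resolve from the package root rather than their submodules.
from timm.data import VariableSeqMapWrapper, create_naflex_loader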


@@ -356,10 +356,9 @@
"""
patch_valid = patch_valid.bool()
B = patch_valid.shape[0]
- device = patch_valid.device
if num_prefix_tokens > 0:
- prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+ prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
@@ -390,10 +389,9 @@
"""
patch_valid = patch_valid.bool()
B, kv_len = patch_valid.shape
- device = patch_valid.device
if num_prefix_tokens > 0:
- prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+ prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
kv_len = patch_valid.shape[1]
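
Both attention-mask helpers now build the prefix-token mask with `Tensor.new_ones` instead of `torch.ones(..., device=..., dtype=...)`, dropping the explicit `device` lookup. This is presumably the "maybe fix FX" part of the commit message: `new_ones` inherits dtype and device from `patch_valid`, so nothing in the helper touches `.device` directly, which tends to be friendlier to `torch.fx` symbolic tracing. A minimal check of the equivalence (illustrative, not from the commit):

import torch

# new_ones() inherits dtype and device from the source tensor, so the explicit
# device=/dtype= keyword arguments (and the device lookup) can be dropped.
patch_valid = torch.ones(2, 7, dtype=torch.bool)
old_prefix = torch.ones((2, 1), device=patch_valid.device, dtype=torch.bool)
new_prefix = patch_valid.new_ones((2, 1))
assert new_prefix.dtype == old_prefix.dtype == torch.bool
assert new_prefix.device == old_prefix.device
assert torch.equal(new_prefix, old_prefix)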


@@ -33,7 +33,8 @@ import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm import utils
- from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+ from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \
+ Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
@@ -403,7 +404,8 @@ group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128,
help='Sequence lengths to use for NaFlex loader')
group.add_argument('--naflex-max-seq-len', type=int, default=576,
help='Fixed maximum sequence length for NaFlex loader (validation)')
+ group.add_argument('--naflex-loss-scale', default='linear', type=str,
+ help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")')
def _parse_args():
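
The new `--naflex-loss-scale` flag controls how each batch's loss is reweighted before backward. NaFlex batches built for shorter sequence lengths pack more samples than the reference `--batch-size` (which corresponds to the maximum sequence length), so their loss is scaled up linearly or by the square root of the ratio. A sketch of the intended mapping (the helper name is illustrative, not from the commit):

def naflex_local_loss_scale(batch_size: int, ref_batch_size: int, mode: str = 'linear') -> float:
    # Map --naflex-loss-scale to a multiplier relative to the reference --batch-size.
    if not mode or mode == 'none':
        return 1.0
    scale = batch_size / ref_batch_size
    if mode == 'sqrt':
        scale = scale ** 0.5
    return scale

# e.g. a short-sequence batch packing 4x the samples of the max-seq-len batch:
assert naflex_local_loss_scale(1024, 256, 'linear') == 4.0
assert naflex_local_loss_scale(1024, 256, 'sqrt') == 2.0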
@@ -762,11 +764,12 @@
worker_seeding=args.worker_seeding,
)
+ naflex_mode = False
if args.naflex_loader:
- from timm.data.naflex_loader import create_naflex_loader
if utils.is_primary(args):
_logger.info('Using NaFlex loader')
+ naflex_mode = True
loader_train = create_naflex_loader(
dataset=dataset_train,
patch_size=16, # Could be derived from model config
@@ -804,7 +807,6 @@
)
if args.naflex_loader:
- from timm.data.naflex_loader import create_naflex_loader
# Use largest sequence length for validation
loader_eval = create_naflex_loader(
dataset=dataset_eval,
@@ -950,6 +952,7 @@
model_ema=model_ema,
mixup_fn=mixup_fn,
num_updates_total=num_epochs * updates_per_epoch,
+ naflex_mode=naflex_mode,
)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -1052,6 +1055,7 @@
model_ema=None,
mixup_fn=None,
num_updates_total=None,
+ naflex_mode=False,
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
@@ -1097,10 +1101,10 @@
def _forward():
with amp_autocast():
output = model(input)
- loss = loss_fn(output, target)
+ _loss = loss_fn(output, target)
if accum_steps > 1:
- loss /= accum_steps
- return loss
+ _loss /= accum_steps
+ return _loss
def _backward(_loss):
if loss_scaler is not None:
@@ -1124,18 +1128,48 @@
)
optimizer.step()
- if has_no_sync and not need_update:
- with model.no_sync():
- loss = _forward()
- _backward(loss)
- else:
- loss = _forward()
- _backward(loss)
- if isinstance(input, dict):
+ if naflex_mode:
+ assert isinstance(input, dict)
batch_size = input['patches'].shape[0]
+ # scale the loss relative to the minimum batch size (args.batch_size, used at the max seq len)
+ if not args.naflex_loss_scale or args.naflex_loss_scale == 'none':
+ local_scale = 1.0
+ else:
+ local_scale = (batch_size / args.batch_size)
+ if args.naflex_loss_scale == 'sqrt':
+ local_scale = local_scale ** 0.5
+ if args.distributed:
+ # scale the gradient between distributed ranks, as each rank can have a different batch size
+ global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1) # SUM
+ dist_scale = args.world_size * batch_size / global_batch_size
+ else:
+ dist_scale = None
+ if has_no_sync and not need_update:
+ with model.no_sync():
+ loss = _forward()
+ scaled_loss = local_scale * loss
+ if dist_scale is not None:
+ scaled_loss *= dist_scale
+ _backward(scaled_loss)
+ else:
+ loss = _forward()
+ scaled_loss = local_scale * loss
+ if dist_scale is not None:
+ scaled_loss *= dist_scale
+ _backward(scaled_loss)
else:
batch_size = input.shape[0]
+ if has_no_sync and not need_update:
+ with model.no_sync():
+ loss = _forward()
+ _backward(loss)
+ else:
+ loss = _forward()
+ _backward(loss)
losses_m.update(loss.item() * accum_steps, batch_size)
update_sample_count += batch_size
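
Under DDP each rank can end up with a different NaFlex batch size, so the loss is additionally weighted by the rank's share of the global batch: batch sizes are summed across ranks and the scale averages out to 1.0 over ranks, turning DDP's plain gradient average into a batch-size-weighted one. A sketch of that correction, assuming `utils.reduce_tensor(t, 1)` amounts to an all-reduce sum (the helper name is illustrative):

import torch
import torch.distributed as dist

def naflex_dist_loss_scale(local_batch_size: int, world_size: int, device) -> torch.Tensor:
    # Weight this rank's loss by its share of the global batch so ranks holding
    # larger NaFlex batches contribute proportionally more to the averaged gradient.
    bs = torch.tensor(float(local_batch_size), device=device)
    dist.all_reduce(bs, op=dist.ReduceOp.SUM)  # global batch size across all ranks
    return world_size * local_batch_size / bs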
@@ -1154,7 +1188,8 @@
elif device.type == 'npu':
torch.npu.synchronize()
time_now = time.time()
- update_time_m.update(time.time() - update_start_time)
+ update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count)
update_start_time = time_now
if update_idx % args.log_interval == 0:
@@ -1173,8 +1208,8 @@
f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
- f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
- f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+ f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s '
+ f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) '
f'LR: {lr:.3e} '
f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
)
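
With `update_time_m` now accumulating seconds per sample (weighted by the sample count), the logged rate becomes the reciprocal of the meter value rather than `update_sample_count / update_time_m.val`. A small worked example of the new bookkeeping (numbers are made up):

# One optimizer update took 0.125s and consumed 256 samples.
elapsed = 0.125
update_sample_count = 256

secs_per_sample = elapsed / update_sample_count  # what update_time_m.update() now receives
throughput = 1.0 / secs_per_sample               # 2048.0 samples/s, as printed in the log line
assert throughput == 2048.0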