Add loss scale arg, initial distributed loss scale. Maybe fix FX for the model.
parent 6675590264, commit 13e0f3a4a3

timm/data/__init__.py

@@ -8,6 +8,8 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
+from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_loader import create_naflex_loader
 from .naflex_transforms import (
     ResizeToSequence,
     CenterCropToSequence,

(attention-mask helpers; file path not captured)

@@ -356,10 +356,9 @@ def create_attention_mask(
     """
     patch_valid = patch_valid.bool()
     B = patch_valid.shape[0]
-    device = patch_valid.device
 
     if num_prefix_tokens > 0:
-        prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
 
     mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
@@ -390,10 +389,9 @@ def create_attention_mask2(
     """
     patch_valid = patch_valid.bool()
     B, kv_len = patch_valid.shape
-    device = patch_valid.device
 
     if num_prefix_tokens > 0:
-        prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
         kv_len = patch_valid.shape[1]
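
Switching from `torch.ones(..., device=device, dtype=torch.bool)` to `patch_valid.new_ones(...)` drops the explicit `.device` read, which is presumably the "maybe fix FX" part of the commit message: `new_ones` inherits device and dtype from its source tensor, a pattern that tends to survive `torch.fx` symbolic tracing better than passing `device=` taken from an intermediate tensor attribute. A minimal standalone sketch of the pattern (the helper name is illustrative, not timm code):

```python
import torch

def prepend_prefix_valid(patch_valid: torch.Tensor, num_prefix_tokens: int) -> torch.Tensor:
    """Prepend always-valid prefix tokens (class/register tokens) to a validity mask.

    new_ones() inherits device and dtype from patch_valid, so no explicit
    device= argument is needed, which is friendlier to torch.fx tracing.
    """
    B = patch_valid.shape[0]
    prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
    return torch.cat([prefix_valid, patch_valid], dim=1)

# 2 images, 4 patches each (second row fully valid), 1 class token
pv = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]], dtype=torch.bool)
print(prepend_prefix_valid(pv, 1).shape)  # torch.Size([2, 5])
```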

train.py (73 changes)

@@ -33,7 +33,8 @@ import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm import utils
-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \
+    Mixup, FastCollateMixup, AugMixDataset
 from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
 from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
@@ -403,7 +404,8 @@ group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128,
                    help='Sequence lengths to use for NaFlex loader')
 group.add_argument('--naflex-max-seq-len', type=int, default=576,
                    help='Fixed maximum sequence length for NaFlex loader (validation)')
-
+group.add_argument('--naflex-loss-scale', default='linear', type=str,
+                   help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")')
 
 
 def _parse_args():
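
Read together with the training-loop changes further down, `--naflex-loss-scale` selects how the per-rank loss is rescaled when a NaFlex batch is larger than the base batch size: 'none' leaves it alone, 'linear' multiplies by `batch_size / args.batch_size`, and 'sqrt' uses the square root of that ratio. A standalone sketch of the mapping (the function name is illustrative, not timm's API):

```python
def naflex_local_loss_scale(batch_size: int, base_batch_size: int, mode: str = 'linear') -> float:
    """Loss scale factor for a NaFlex batch. mode: 'none' | 'sqrt' | 'linear'."""
    if not mode or mode == 'none':
        return 1.0
    scale = batch_size / base_batch_size
    if mode == 'sqrt':
        scale = scale ** 0.5
    return scale

assert naflex_local_loss_scale(256, 128, 'none') == 1.0
assert naflex_local_loss_scale(256, 128, 'linear') == 2.0
assert abs(naflex_local_loss_scale(256, 128, 'sqrt') - 2 ** 0.5) < 1e-9
```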
@@ -762,11 +764,12 @@ def main():
             worker_seeding=args.worker_seeding,
         )
 
+    naflex_mode = False
     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')
 
+        naflex_mode = True
         loader_train = create_naflex_loader(
             dataset=dataset_train,
             patch_size=16,  # Could be derived from model config
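
In NaFlex mode a train batch is a dict of tensors rather than a single tensor; the training loop below sizes its bookkeeping from `input['patches']`. A rough sketch of the shape convention (only the `patches` key is confirmed by this diff; the mask key and shapes are assumptions for illustration):

```python
import torch

batch = {
    'patches': torch.zeros(64, 256, 768),                  # (batch, seq_len, patch_dim)
    'patch_valid': torch.ones(64, 256, dtype=torch.bool),  # assumed padding mask
}
batch_size = batch['patches'].shape[0]  # 64; what train_one_epoch reads in naflex mode
```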
@@ -804,7 +807,6 @@ def main():
     )
 
     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         # Use largest sequence length for validation
         loader_eval = create_naflex_loader(
             dataset=dataset_eval,
@@ -950,6 +952,7 @@ def main():
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
+                naflex_mode=naflex_mode,
             )
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -1052,6 +1055,7 @@ def train_one_epoch(
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
+        naflex_mode=False,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -1097,10 +1101,10 @@ def train_one_epoch(
         def _forward():
             with amp_autocast():
                 output = model(input)
-                loss = loss_fn(output, target)
+                _loss = loss_fn(output, target)
             if accum_steps > 1:
-                loss /= accum_steps
-            return loss
+                _loss /= accum_steps
+            return _loss
 
         def _backward(_loss):
             if loss_scaler is not None:
@@ -1124,18 +1128,48 @@ def train_one_epoch(
                         )
                     optimizer.step()
 
-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-                _backward(loss)
-        else:
-            loss = _forward()
-            _backward(loss)
-
-        if isinstance(input, dict):
+        if naflex_mode:
+            assert isinstance(input, dict)
             batch_size = input['patches'].shape[0]
+
+            # scale gradient vs the minimum batch size (for max seq len)
+            if not args.naflex_loss_scale or args.naflex_loss_scale == 'none':
+                local_scale = 1.0
+            else:
+                local_scale = (batch_size / args.batch_size)
+                if args.naflex_loss_scale == 'sqrt':
+                    local_scale = local_scale ** 0.5
+
+            if args.distributed:
+                # scale gradient btw distributed ranks, each one can have different batch size
+                global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1)  # SUM
+                dist_scale = args.world_size * batch_size / global_batch_size
+            else:
+                dist_scale = None
+
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    scaled_loss = local_scale * loss
+                    if dist_scale is not None:
+                        scaled_loss *= dist_scale
+                    _backward(scaled_loss)
+            else:
+                loss = _forward()
+                scaled_loss = local_scale * loss
+                if dist_scale is not None:
+                    scaled_loss *= dist_scale
+                _backward(scaled_loss)
         else:
             batch_size = input.shape[0]
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    _backward(loss)
+            else:
+                loss = _forward()
+                _backward(loss)
 
         losses_m.update(loss.item() * accum_steps, batch_size)
         update_sample_count += batch_size
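
Because NaFlex batch sizes can differ across ranks at a given step, the distributed branch all-reduces the per-rank batch sizes (`utils.reduce_tensor(..., 1)` with divisor 1 is a plain SUM) and scales each rank's loss by `world_size * batch_size / global_batch_size`. Since DDP averages gradients uniformly over ranks, this turns that average into a per-sample weighted one. A worked example with assumed numbers:

```python
# Two ranks; rank 0 got 96 samples this step, rank 1 got 160.
world_size = 2
global_batch_size = 96 + 160                    # all-reduce SUM of per-rank sizes
scale0 = world_size * 96 / global_batch_size    # 0.75
scale1 = world_size * 160 / global_batch_size   # 1.25
# DDP averages gradients over ranks: (scale0 * g0 + scale1 * g1) / 2
# = (96 * g0 + 160 * g1) / 256, i.e. every sample contributes equally.
```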
@@ -1154,7 +1188,8 @@ def train_one_epoch(
             elif device.type == 'npu':
                 torch.npu.synchronize()
             time_now = time.time()
-            update_time_m.update(time.time() - update_start_time)
+
+            update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count)
             update_start_time = time_now
 
             if update_idx % args.log_interval == 0:
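
`update_time_m` now tracks seconds per sample, weighting each meter update by the step's sample count, so throughput stays comparable across variable-size NaFlex batches; the log hunk below reports `1 / update_time_m.val` accordingly. An illustration using timm's `AverageMeter`, whose `update(val, n)` adds `val * n` to the running sum:

```python
from timm.utils import AverageMeter

m = AverageMeter()
m.update(0.50 / 128, 128)   # a 128-sample step that took 0.50s
m.update(0.75 / 256, 256)   # a 256-sample step that took 0.75s
print(f'{1 / m.val:.0f}/s (last step)')    # ~341
print(f'{1 / m.avg:.0f}/s (running avg)')  # ~307 = 384 samples / 1.25s
```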
@@ -1173,8 +1208,8 @@ def train_one_epoch(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
                     f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
-                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
-                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+                    f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
                     f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
|
Loading…
Reference in New Issue