Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
Add loss scale arg, initial distributed loss scale. Maybe fix FX for the model.
This commit is contained in:
parent 6675590264
commit 13e0f3a4a3
@@ -8,6 +8,8 @@ from .dataset_info import DatasetInfo, CustomDatasetInfo
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
+from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_loader import create_naflex_loader
 from .naflex_transforms import (
     ResizeToSequence,
     CenterCropToSequence,
@@ -356,10 +356,9 @@ def create_attention_mask(
     """
     patch_valid = patch_valid.bool()
     B = patch_valid.shape[0]
-    device = patch_valid.device

     if num_prefix_tokens > 0:
-        prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)

     mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
@@ -390,10 +389,9 @@ def create_attention_mask2(
     """
     patch_valid = patch_valid.bool()
     B, kv_len = patch_valid.shape
-    device = patch_valid.device

     if num_prefix_tokens > 0:
-        prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
         kv_len = patch_valid.shape[1]

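Note on the change above: switching from torch.ones(..., device=device, dtype=torch.bool) to patch_valid.new_ones(...) inherits device and dtype from patch_valid directly, removing the captured device local; given the commit message's "Maybe fix FX for the model", this likely also helps torch.fx tracing. A standalone sketch (not part of the commit) of the equivalence:

    import torch

    patch_valid = torch.zeros(2, 5, dtype=torch.bool)  # stand-in for a real patch mask
    # Old style: device/dtype threaded through explicitly from captured locals.
    a = torch.ones((2, 3), device=patch_valid.device, dtype=torch.bool)
    # New style: new_ones inherits both device and dtype from the source tensor.
    b = patch_valid.new_ones((2, 3))
    assert a.equal(b) and a.dtype == b.dtype and a.device == b.device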
train.py (73 lines changed)
@@ -33,7 +33,8 @@ import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

 from timm import utils
-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \
+    Mixup, FastCollateMixup, AugMixDataset
 from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
 from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
@@ -403,7 +404,8 @@ group.add_argument('--naflex-train-seq-lens', type=int, nargs='+', default=[128,
                    help='Sequence lengths to use for NaFlex loader')
 group.add_argument('--naflex-max-seq-len', type=int, default=576,
                    help='Fixed maximum sequence length for NaFlex loader (validation)')
+group.add_argument('--naflex-loss-scale', default='linear', type=str,
+                   help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")')


 def _parse_args():
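The new --naflex-loss-scale flag selects how the per-rank loss is rescaled when NaFlex batch sizes vary with sequence length. A hedged sketch of the mapping the flag implies (the helper name here is mine, not the commit's):

    def naflex_local_scale(batch_size: int, base_batch_size: int, mode: str = 'linear') -> float:
        # 'none' leaves the loss untouched; 'linear' weights by relative batch size;
        # 'sqrt' softens that weighting.
        if not mode or mode == 'none':
            return 1.0
        scale = batch_size / base_batch_size
        return scale ** 0.5 if mode == 'sqrt' else scale

    assert naflex_local_scale(512, 256, 'linear') == 2.0
    assert naflex_local_scale(512, 256, 'sqrt') == 2.0 ** 0.5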
@@ -762,11 +764,12 @@ def main():
         worker_seeding=args.worker_seeding,
     )

+    naflex_mode = False
     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')

+        naflex_mode = True
         loader_train = create_naflex_loader(
             dataset=dataset_train,
             patch_size=16,  # Could be derived from model config
@@ -804,7 +807,6 @@ def main():
     )

     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         # Use largest sequence length for validation
         loader_eval = create_naflex_loader(
             dataset=dataset_eval,
@@ -950,6 +952,7 @@ def main():
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
+                naflex_mode=naflex_mode,
             )

             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -1052,6 +1055,7 @@ def train_one_epoch(
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
+        naflex_mode=False,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -1097,10 +1101,10 @@ def train_one_epoch(
         def _forward():
             with amp_autocast():
                 output = model(input)
-                loss = loss_fn(output, target)
+                _loss = loss_fn(output, target)
             if accum_steps > 1:
-                loss /= accum_steps
-            return loss
+                _loss /= accum_steps
+            return _loss

         def _backward(_loss):
             if loss_scaler is not None:
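The rename to _loss keeps the closure's local distinct from the loss variable that the surrounding loop assigns from _forward()'s return value (and then rescales, see the next hunk). A trivial standalone illustration of the naming convention, not code from the commit:

    def make_step():
        def _forward():
            _loss = 1.0    # clearly a different binding from the outer `loss`
            _loss /= 2     # e.g. gradient-accumulation scaling
            return _loss
        loss = _forward()  # outer name holds the returned value only
        return loss

    assert make_step() == 0.5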
@@ -1124,18 +1128,48 @@ def train_one_epoch(
                 )
                 optimizer.step()

-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-                _backward(loss)
-        else:
-            loss = _forward()
-            _backward(loss)
-
-        if isinstance(input, dict):
+        if naflex_mode:
+            assert isinstance(input, dict)
             batch_size = input['patches'].shape[0]
+
+            # scale gradient vs the minimum batch size (for max seq len)
+            if not args.naflex_loss_scale or args.naflex_loss_scale == 'none':
+                local_scale = 1.0
+            else:
+                local_scale = batch_size / args.batch_size
+                if args.naflex_loss_scale == 'sqrt':
+                    local_scale = local_scale ** 0.5
+
+            if args.distributed:
+                # scale gradient between distributed ranks, each one can have a different batch size
+                global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1)  # SUM
+                dist_scale = args.world_size * batch_size / global_batch_size
+            else:
+                dist_scale = None
+
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    scaled_loss = local_scale * loss
+                    if dist_scale is not None:
+                        scaled_loss *= dist_scale
+                    _backward(scaled_loss)
+            else:
+                loss = _forward()
+                scaled_loss = local_scale * loss
+                if dist_scale is not None:
+                    scaled_loss *= dist_scale
+                _backward(scaled_loss)
         else:
             batch_size = input.shape[0]
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    _backward(loss)
+            else:
+                loss = _forward()
+                _backward(loss)

         losses_m.update(loss.item() * accum_steps, batch_size)
         update_sample_count += batch_size
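In the distributed branch above, each rank can receive a different NaFlex batch size; utils.reduce_tensor sums the per-rank sizes, and dist_scale = world_size * batch_size / global_batch_size compensates for DDP's uniform 1/world_size gradient averaging so each rank contributes in proportion to its sample count. A hedged numeric check (plain Python, no process group):

    # Two ranks with uneven NaFlex batches: 128 and 384 samples.
    rank_batch_sizes = [128, 384]
    world_size = len(rank_batch_sizes)
    global_batch_size = sum(rank_batch_sizes)  # what reduce_tensor(..., 1) computes (SUM)

    dist_scales = [world_size * bs / global_batch_size for bs in rank_batch_sizes]  # [0.5, 1.5]
    # DDP averages gradients as (g0 + g1) / world_size; after scaling, each rank's
    # gradient is weighted by bs_i / global_batch_size, a true sample-weighted average.
    weights = [ds / world_size for ds in dist_scales]
    assert weights == [bs / global_batch_size for bs in rank_batch_sizes]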
@@ -1154,7 +1188,8 @@ def train_one_epoch(
             elif device.type == 'npu':
                 torch.npu.synchronize()
         time_now = time.time()
-        update_time_m.update(time.time() - update_start_time)
+        update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count)
         update_start_time = time_now

         if update_idx % args.log_interval == 0:
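update_time_m now records seconds per sample (weighted by update_sample_count) rather than seconds per optimizer update, so the meter stays comparable across NaFlex batches of varying size; the log line in the next hunk accordingly prints throughput as 1 / update_time_m.val. A quick arithmetic check of the new semantics:

    update_seconds, update_sample_count = 0.6, 384
    per_sample = update_seconds / update_sample_count  # what the meter now stores
    assert abs(1 / per_sample - 640.0) < 1e-9          # 640 samples/s printed in the log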
@@ -1173,8 +1208,8 @@ def train_one_epoch(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
                     f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
-                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
-                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+                    f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
                     f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
|
Loading…
x
Reference in New Issue
Block a user