Initial AGC impl. Still testing.

Ross Wightman 2021-02-15 23:22:44 -08:00
parent 5f9aff395c
commit 4f49b94311
7 changed files with 95 additions and 15 deletions

timm/models/__init__.py

@@ -31,7 +31,7 @@ from .xception import *
 from .xception_aligned import *

 from .factory import create_model
-from .helpers import load_checkpoint, resume_checkpoint
+from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit

timm/models/helpers.py

@@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
             digits of the SHA256 hash of the contents of the file. The hash is used to
             ensure unique names and to verify the contents of the file. Default: False
     """
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
+        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
     url = cfg['url']
@@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight):
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
-    if cfg is None:
-        cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    cfg = cfg or getattr(model, 'default_cfg')
+    if cfg is None or not cfg.get('url', None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
@@ -376,3 +374,11 @@ def build_model_with_cfg(
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
     return model
+
+
+def model_parameters(model, exclude_head=False):
+    if exclude_head:
+        # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
+        return [p for p in model.parameters()][:-2]
+    else:
+        return model.parameters()
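
Note: the new model_parameters helper relies purely on registration order, dropping the last two parameters (assumed to be the classifier head's weight and bias) when exclude_head=True. A minimal sketch of that behaviour on a toy model (the toy model below is illustrative, not part of this commit):

# Illustrative sketch (not from this commit): exclude_head=True simply drops
# the last two registered parameters, assumed to be the head's weight and bias.
import torch.nn as nn

from timm.models import model_parameters

toy = nn.Sequential(
    nn.Conv2d(3, 8, 3),   # "body": conv weight + bias
    nn.Flatten(),
    nn.Linear(128, 10),   # "head": weight + bias are registered last
)

print(len(list(toy.parameters())))                    # 4
print(len(model_parameters(toy, exclude_head=True)))  # 2 (conv weight + bias only)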

timm/utils/__init__.py

@@ -1,4 +1,6 @@
+from .agc import adaptive_clip_grad
 from .checkpoint_saver import CheckpointSaver
+from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
 from .distributed import distribute_bn, reduce_tensor
 from .jit import set_jit_legacy

timm/utils/agc.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+""" Adaptive Gradient Clipping
+
+An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
+
+@article{brock2021high,
+  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
+  title={High-Performance Large-Scale Image Recognition Without Normalization},
+  journal={arXiv preprint arXiv:2102.06171},
+  year={2021}
+}
+
+Code references:
+  * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
+  * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+
+
+def unitwise_norm(x, norm_type=2.0):
+    if x.ndim <= 1:
+        return x.norm(norm_type)
+    else:
+        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
+        # might need special cases for other weights (possibly MHA) where this may not be true
+        return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
+
+
+def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    for p in parameters:
+        if p.grad is None:
+            continue
+        p_data = p.detach()
+        g_data = p.grad.detach()
+        max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
+        grad_norm = unitwise_norm(g_data, norm_type=norm_type)
+        clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
+        new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
+        p.grad.detach().copy_(new_grads)
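
In short, AGC works unit-wise: each output unit's gradient is rescaled to clip_factor * max(||W||, eps) / ||g|| whenever its norm exceeds clip_factor * max(||W||, eps). A minimal usage sketch (the toy model and optimizer below are illustrative, not from this commit):

# Minimal usage sketch (toy model/optimizer are illustrative, not from this commit).
import torch
import torch.nn as nn

from timm.utils.agc import adaptive_clip_grad

model = nn.Linear(16, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()

adaptive_clip_grad(model.parameters(), clip_factor=0.01)  # clips p.grad in place, before the step
opt.step()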

timm/utils/clip_grad.py (new file, 23 lines)

@@ -0,0 +1,23 @@
+import torch
+
+from timm.utils.agc import adaptive_clip_grad
+
+
+def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
+    """ Dispatch to gradient clipping method
+
+    Args:
+        parameters (Iterable): model parameters to clip
+        value (float): clipping value/factor/norm, mode dependant
+        mode (str): clipping mode, one of 'norm', 'value', 'agc'
+        norm_type (float): p-norm, default 2.0
+    """
+    if mode == 'norm':
+        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
+    elif mode == 'value':
+        torch.nn.utils.clip_grad_value_(parameters, value)
+    elif mode == 'agc':
+        adaptive_clip_grad(parameters, value, norm_type=norm_type)
+    else:
+        assert False, f"Unknown clip mode ({mode})."
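
Note that the meaning of `value` depends on the mode: a max global grad norm for 'norm', a per-element bound for 'value', and the AGC clip factor for 'agc'. A small usage sketch (the toy parameter and values below are illustrative):

# Usage sketch (toy parameter; values are illustrative).
import torch

from timm.utils import dispatch_clip_grad

w = torch.nn.Parameter(torch.randn(10, 10))
(w.sum() * 3).backward()                             # fills w.grad

dispatch_clip_grad([w], value=1.0, mode='norm')      # global-norm clipping
# dispatch_clip_grad([w], value=0.5, mode='value')   # per-element clipping
# dispatch_clip_grad([w], value=0.01, mode='agc')    # adaptive gradient clipping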

timm/utils/cuda.py

@@ -11,15 +11,17 @@ except ImportError:
     amp = None
     has_apex = False

+from .clip_grad import dispatch_clip_grad
+

 class ApexScaler:
     state_dict_key = "amp"

-    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
         if clip_grad is not None:
-            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
+            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
         optimizer.step()

     def state_dict(self):
@@ -37,12 +39,12 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()

-    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
         self._scaler.scale(loss).backward(create_graph=create_graph)
         if clip_grad is not None:
             assert parameters is not None
             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
         self._scaler.step(optimizer)
         self._scaler.update()
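
With the extra clip_mode argument, both scalers now route clipping through dispatch_clip_grad between backward and the optimizer step. A sketch of one mixed-precision step via NativeScaler (requires CUDA; the toy model, optimizer, and clip settings below are illustrative, not from this commit):

# Sketch of one AMP step with the new clip_mode argument (requires CUDA).
import torch
import torch.nn as nn

from timm.utils import NativeScaler

model = nn.Linear(16, 4).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.randn(8, 16, device='cuda')).pow(2).mean()

opt.zero_grad()
# scales the loss, runs backward, unscales grads, applies AGC, then steps via GradScaler
loss_scaler(loss, opt, clip_grad=0.01, clip_mode='agc', parameters=model.parameters())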

train.py

@@ -29,7 +29,7 @@ import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
@@ -637,11 +637,16 @@ def train_one_epoch(
         optimizer.zero_grad()
         if loss_scaler is not None:
             loss_scaler(
-                loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
+                loss, optimizer,
+                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
+                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                create_graph=second_order)
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
+                dispatch_clip_grad(
+                    model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    value=args.clip_grad, mode=args.clip_mode)
             optimizer.step()

         if model_ema is not None:
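
For reference, the non-AMP branch above boils down to: backward, clip with the selected mode (excluding the classifier head when AGC is used), then step. A self-contained sketch with a toy model and hypothetical settings standing in for args.clip_grad / args.clip_mode:

# Self-contained sketch of the FP32 branch above (toy model; hypothetical settings).
import torch
import torch.nn as nn

from timm.models import model_parameters
from timm.utils import dispatch_clip_grad

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clip_grad, clip_mode = 0.01, 'agc'

loss = model(torch.randn(8, 16)).pow(2).mean()

optimizer.zero_grad()
loss.backward()
if clip_grad is not None:
    dispatch_clip_grad(
        model_parameters(model, exclude_head='agc' in clip_mode),  # head excluded for AGC
        value=clip_grad, mode=clip_mode)
optimizer.step()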