mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Initial AGC impl. Still testing.
This commit is contained in:
parent
5f9aff395c
commit
4f49b94311
@ -31,7 +31,7 @@ from .xception import *
|
||||
from .xception_aligned import *
|
||||
|
||||
from .factory import create_model
|
||||
from .helpers import load_checkpoint, resume_checkpoint
|
||||
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
|
||||
from .layers import TestTimePoolHead, apply_test_time_pool
|
||||
from .layers import convert_splitbn_model
|
||||
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
|
||||
|
@ -113,10 +113,9 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
|
||||
digits of the SHA256 hash of the contents of the file. The hash is used to
|
||||
ensure unique names and to verify the contents of the file. Default: False
|
||||
"""
|
||||
if cfg is None:
|
||||
cfg = getattr(model, 'default_cfg')
|
||||
if cfg is None or 'url' not in cfg or not cfg['url']:
|
||||
_logger.warning("Pretrained model URL does not exist, using random initialization.")
|
||||
cfg = cfg or getattr(model, 'default_cfg')
|
||||
if cfg is None or not cfg.get('url', None):
|
||||
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
||||
return
|
||||
url = cfg['url']
|
||||
|
||||
@ -174,9 +173,8 @@ def adapt_input_conv(in_chans, conv_weight):
|
||||
|
||||
|
||||
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
|
||||
if cfg is None:
|
||||
cfg = getattr(model, 'default_cfg')
|
||||
if cfg is None or 'url' not in cfg or not cfg['url']:
|
||||
cfg = cfg or getattr(model, 'default_cfg')
|
||||
if cfg is None or not cfg.get('url', None):
|
||||
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
||||
return
|
||||
|
||||
@ -376,3 +374,11 @@ def build_model_with_cfg(
|
||||
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def model_parameters(model, exclude_head=False):
|
||||
if exclude_head:
|
||||
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
|
||||
return [p for p in model.parameters()][:-2]
|
||||
else:
|
||||
return model.parameters()
|
||||
|
@ -1,4 +1,6 @@
|
||||
from .agc import adaptive_clip_grad
|
||||
from .checkpoint_saver import CheckpointSaver
|
||||
from .clip_grad import dispatch_clip_grad
|
||||
from .cuda import ApexScaler, NativeScaler
|
||||
from .distributed import distribute_bn, reduce_tensor
|
||||
from .jit import set_jit_legacy
|
||||
|
42
timm/utils/agc.py
Normal file
42
timm/utils/agc.py
Normal file
@ -0,0 +1,42 @@
|
||||
""" Adaptive Gradient Clipping
|
||||
|
||||
An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
|
||||
|
||||
@article{brock2021high,
|
||||
author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
|
||||
title={High-Performance Large-Scale Image Recognition Without Normalization},
|
||||
journal={arXiv preprint arXiv:},
|
||||
year={2021}
|
||||
}
|
||||
|
||||
Code references:
|
||||
* Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
|
||||
* Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
|
||||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def unitwise_norm(x, norm_type=2.0):
|
||||
if x.ndim <= 1:
|
||||
return x.norm(norm_type)
|
||||
else:
|
||||
# works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor
|
||||
# might need special cases for other weights (possibly MHA) where this may not be true
|
||||
return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
|
||||
|
||||
|
||||
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
for p in parameters:
|
||||
if p.grad is None:
|
||||
continue
|
||||
p_data = p.detach()
|
||||
g_data = p.grad.detach()
|
||||
max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
|
||||
grad_norm = unitwise_norm(g_data, norm_type=norm_type)
|
||||
clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
|
||||
new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
|
||||
p.grad.detach().copy_(new_grads)
|
23
timm/utils/clip_grad.py
Normal file
23
timm/utils/clip_grad.py
Normal file
@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
from timm.utils.agc import adaptive_clip_grad
|
||||
|
||||
|
||||
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
|
||||
""" Dispatch to gradient clipping method
|
||||
|
||||
Args:
|
||||
parameters (Iterable): model parameters to clip
|
||||
value (float): clipping value/factor/norm, mode dependant
|
||||
mode (str): clipping mode, one of 'norm', 'value', 'agc'
|
||||
norm_type (float): p-norm, default 2.0
|
||||
"""
|
||||
if mode == 'norm':
|
||||
torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
|
||||
elif mode == 'value':
|
||||
torch.nn.utils.clip_grad_value_(parameters, value)
|
||||
elif mode == 'agc':
|
||||
adaptive_clip_grad(parameters, value, norm_type=norm_type)
|
||||
else:
|
||||
assert False, f"Unknown clip mode ({mode})."
|
||||
|
@ -11,15 +11,17 @@ except ImportError:
|
||||
amp = None
|
||||
has_apex = False
|
||||
|
||||
from .clip_grad import dispatch_clip_grad
|
||||
|
||||
|
||||
class ApexScaler:
|
||||
state_dict_key = "amp"
|
||||
|
||||
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
|
||||
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
scaled_loss.backward(create_graph=create_graph)
|
||||
if clip_grad is not None:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
|
||||
dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
|
||||
optimizer.step()
|
||||
|
||||
def state_dict(self):
|
||||
@ -37,12 +39,12 @@ class NativeScaler:
|
||||
def __init__(self):
|
||||
self._scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
|
||||
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
|
||||
self._scaler.scale(loss).backward(create_graph=create_graph)
|
||||
if clip_grad is not None:
|
||||
assert parameters is not None
|
||||
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
||||
torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
||||
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
|
||||
self._scaler.step(optimizer)
|
||||
self._scaler.update()
|
||||
|
||||
|
11
train.py
11
train.py
@ -29,7 +29,7 @@ import torchvision.utils
|
||||
from torch.nn.parallel import DistributedDataParallel as NativeDDP
|
||||
|
||||
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
||||
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
|
||||
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
|
||||
from timm.utils import *
|
||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
||||
from timm.optim import create_optimizer
|
||||
@ -637,11 +637,16 @@ def train_one_epoch(
|
||||
optimizer.zero_grad()
|
||||
if loss_scaler is not None:
|
||||
loss_scaler(
|
||||
loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
|
||||
loss, optimizer,
|
||||
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
|
||||
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
|
||||
create_graph=second_order)
|
||||
else:
|
||||
loss.backward(create_graph=second_order)
|
||||
if args.clip_grad is not None:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
|
||||
dispatch_clip_grad(
|
||||
model_parameters(model, exclude_head='agc' in args.clip_mode),
|
||||
value=args.clip_grad, mode=args.clip_mode)
|
||||
optimizer.step()
|
||||
|
||||
if model_ema is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user