mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Fix memory_efficient mode for DenseNets. Add AntiAliasing (Blur) support for DenseNets and create one test model. Add lr cycle/mul params to train args.

This commit is contained in:
parent 7df83258c9
commit 6441e9cc1b
timm/models/densenet.py

@@ -14,7 +14,7 @@ from torch.jit.annotations import List

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
+from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
 from .registry import register_model

 __all__ = ['DenseNet']
@@ -71,9 +71,9 @@ class DenseLayer(nn.Module):
     def call_checkpoint_bottleneck(self, x):
         # type: (List[torch.Tensor]) -> torch.Tensor
         def closure(*xs):
-            return self.bottleneck_fn(*xs)
+            return self.bottleneck_fn(xs)

-        return cp.checkpoint(closure, x)
+        return cp.checkpoint(closure, *x)

     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
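The two changed lines move where the tensor list is unpacked: torch.utils.checkpoint only tracks and re-feeds arguments that are individual tensors, so the list must be splatted into checkpoint itself, with the closure re-bundling the tensors into the single sequence bottleneck_fn expects. A standalone sketch of the corrected pattern, with a toy bottleneck_fn standing in for the real bottleneck (shapes and names here are illustrative, not the timm code):

import torch
import torch.utils.checkpoint as cp

# Toy stand-in for DenseLayer.bottleneck_fn: concatenates the accumulated
# feature tensors and applies a linear map.
weight = torch.randn(8, 8, requires_grad=True)

def bottleneck_fn(xs):              # expects one sequence of tensors
    return torch.cat(xs, dim=1) @ weight.t()

def closure(*xs):                   # checkpoint re-feeds tensors as positionals,
    return bottleneck_fn(xs)        # so re-bundle them for bottleneck_fn

features = [torch.randn(2, 4, requires_grad=True) for _ in range(2)]

# The fix: unpack the list so each tensor is its own positional arg that
# checkpoint can detach, stash, and recompute against during backward.
out = cp.checkpoint(closure, *features)
out.sum().backward()
print(features[0].grad.shape)       # gradients now reach every input tensor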
@@ -132,12 +132,15 @@ class DenseBlock(nn.ModuleDict):


 class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d):
+    def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d, aa_layer=None):
         super(DenseTransition, self).__init__()
         self.add_module('norm', norm_act_layer(num_input_features))
         self.add_module('conv', nn.Conv2d(
             num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
-        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+        if aa_layer is not None:
+            self.add_module('pool', aa_layer(num_output_features, stride=2))
+        else:
+            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


 class DenseNet(nn.Module):
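The transition keeps its 1x1 channel-reduction conv; only the downsampling step changes. When an aa_layer such as timm's BlurPool2d is supplied, the 2x2 average pool is replaced by a blur-then-subsample, the anti-aliasing technique from Zhang's "Making Convolutional Networks Shift-Invariant Again". A quick shape check; note that DenseTransition is an internal class, so importing it directly is an assumption:

import torch
from timm.models.densenet import DenseTransition  # internal class, path assumed
from timm.models.layers import BlurPool2d

# Default: 1x1 conv 256->128, then AvgPool2d halves the resolution.
plain = DenseTransition(num_input_features=256, num_output_features=128)

# Anti-aliased: same conv, but BlurPool2d low-pass filters before
# subsampling, reducing the grid artifacts of naive strided pooling.
blurred = DenseTransition(
    num_input_features=256, num_output_features=128, aa_layer=BlurPool2d)

x = torch.randn(1, 256, 28, 28)
print(plain(x).shape)    # torch.Size([1, 128, 14, 14])
print(blurred(x).shape)  # torch.Size([1, 128, 14, 14])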
@@ -301,6 +304,17 @@ def densenet121(pretrained=False, **kwargs):
     return model


+@register_model
+def densenetblur121d(pretrained=False, **kwargs):
+    r"""Densenet-121 model from
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    """
+    model = _densenet(
+        'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
+        aa_layer=BlurPool2d, **kwargs)
+    return model
+
+
 @register_model
 def densenet121d(pretrained=False, **kwargs):
     r"""Densenet-121 model from
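Once registered, the test model is reachable through the normal factory path. A minimal smoke test (whether pretrained weights exist for this variant at this point is not shown in the diff, hence pretrained=False):

import torch
import timm

# 'densenetblur121d' = DenseNet-121 with a deep stem ('d') plus
# BlurPool2d anti-aliased downsampling in the transitions.
model = timm.create_model('densenetblur121d', pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])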
timm/scheduler/scheduler_factory.py

@@ -23,12 +23,12 @@ def create_scheduler(args, optimizer):
         lr_scheduler = CosineLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=1.0,
+            t_mul=args.lr_cycle_mul,
             lr_min=args.min_lr,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=1,
+            cycle_limit=args.lr_cycle_limit,
             t_in_epochs=True,
             noise_range_t=noise_range,
             noise_pct=args.lr_noise_pct,
@@ -40,11 +40,11 @@ def create_scheduler(args, optimizer):
         lr_scheduler = TanhLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=1.0,
+            t_mul=args.lr_cycle_mul,
             lr_min=args.min_lr,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=1,
+            cycle_limit=args.lr_cycle_limit,
             t_in_epochs=True,
             noise_range_t=noise_range,
             noise_pct=args.lr_noise_pct,
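Both schedulers previously hard-coded a single fixed-length cycle; the two new args expose warm restarts. t_mul scales the length of each successive cycle and cycle_limit caps how many cycles run before the LR parks at lr_min. A sketch with illustrative values, not taken from the diff (later timm versions rename t_mul to cycle_mul; at this commit it is still t_mul):

import torch
from timm.scheduler import CosineLRScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

# First cycle 10 epochs, each restart twice as long, at most 3 cycles:
# cycles cover epochs [0, 10), [10, 30), [30, 70); afterwards lr == lr_min.
sched = CosineLRScheduler(
    opt, t_initial=10, t_mul=2.0, lr_min=1e-5,
    cycle_limit=3, t_in_epochs=True)

for epoch in (0, 9, 10, 30, 69, 70):
    print(epoch, sched.get_epoch_values(epoch)[0])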
train.py
@@ -111,6 +111,10 @@ parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                     help='learning rate noise limit percent (default: 0.67)')
 parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                     help='learning rate noise std-dev (default: 1.0)')
+parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
+                    help='learning rate cycle len multiplier (default: 1.0)')
+parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
+                    help='learning rate cycle limit')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
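Together with the scheduler change, the new flags let a run request cosine warm restarts from the command line. A hypothetical invocation; the dataset path and hyperparameter values are placeholders:

python train.py /data/imagenet --model densenetblur121d \
    --sched cosine --epochs 70 --lr 0.1 \
    --lr-cycle-mul 2.0 --lr-cycle-limit 3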