Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Tweaking tanh scheduler, senet weight init (for BN), transform defaults
commit b5255960d9 (parent 48360625f2)
@@ -104,6 +104,18 @@ pretrained_config = {
 }


+def _weight_init(m, n='', ll=''):
+    print(m, n, ll)
+    if isinstance(m, nn.Conv2d):
+        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+    elif isinstance(m, nn.BatchNorm2d):
+        if ll and n == ll:
+            nn.init.constant_(m.weight, 0.)
+        else:
+            nn.init.constant_(m.weight, 1.)
+        nn.init.constant_(m.bias, 0.)
+
+
 class SEModule(nn.Module):

     def __init__(self, channels, reduction):
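Note: the `ll` ("last layer") argument lets callers zero the affine weight (gamma) of the final BatchNorm in a residual branch. With gamma and beta both zero the branch contributes nothing at initialization, so each block starts as an identity mapping, the residual-friendly BN init popularized for large-batch training. A minimal sketch of the effect (toy branch, names hypothetical, not code from this repo):

import torch
import torch.nn as nn

# a stand-in residual branch ending in a BatchNorm, like conv3/bn3 above
branch = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8))
nn.init.constant_(branch[1].weight, 0.)  # gamma = 0, what _weight_init does when n == ll
nn.init.constant_(branch[1].bias, 0.)    # beta = 0

x = torch.randn(2, 8, 16, 16)
assert torch.equal(branch(x), torch.zeros_like(x))  # branch is silenced at init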
@@ -116,6 +128,9 @@ class SEModule(nn.Module):
             channels // reduction, channels, kernel_size=1, padding=0)
         self.sigmoid = nn.Sigmoid()

+        for m in self.modules():
+            _weight_init(m)
+
     def forward(self, x):
         module_input = x
         x = self.avg_pool(x)
@@ -176,6 +191,9 @@ class SEBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride

+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn3')
+

 class SEResNetBottleneck(Bottleneck):
     """
@@ -201,6 +219,9 @@ class SEResNetBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride

+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn3')
+

 class SEResNeXtBottleneck(Bottleneck):
     """
@@ -225,6 +246,9 @@ class SEResNeXtBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride

+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn3')
+

 class SEResNetBlock(nn.Module):
     expansion = 1
@@ -242,6 +266,9 @@ class SEResNetBlock(nn.Module):
         self.downsample = downsample
         self.stride = stride

+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn2')
+
     def forward(self, x):
         residual = x

@@ -378,6 +405,12 @@ class SENet(nn.Module):
         self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
         self.last_linear = nn.Linear(512 * block.expansion, num_classes)

+        for n, m in self.named_children():
+            if n == 'layer0':
+                m.apply(_weight_init)
+            else:
+                _weight_init(m)
+
     def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                     downsample_kernel_size=1, downsample_padding=0):
         downsample = None
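Note: `nn.Module.apply(fn)` recurses through every submodule, while a plain `_weight_init(m)` call only inspects the one module it is handed. The loop above therefore initializes `layer0` deeply but leaves the residual stages to the per-block init already done in their own constructors. A quick sketch of the difference (toy modules, not repo code):

import torch.nn as nn

def visit(m):
    print(type(m).__name__)

seq = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
seq.apply(visit)  # prints Conv2d, BatchNorm2d, Sequential: full recursion, children first
visit(seq)        # prints Sequential only: what a direct _weight_init(m) call sees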
@@ -21,7 +21,7 @@ class LeNormalize(object):
         return tensor


-def transforms_imagenet_train(model_name, img_size=224, scale=(0.08, 1.0), color_jitter=(0.3, 0.3, 0.3)):
+def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
     if 'dpn' in model_name:
         normalize = transforms.Normalize(
             mean=IMAGENET_DPN_MEAN,
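Note: the new defaults make the random crop keep at least 10% of the source image area (scale lower bound 0.08 -> 0.1) and nudge color jitter up (0.3 -> 0.333 for brightness/contrast/saturation). A hedged sketch of how such defaults typically feed torchvision transforms; the full pipeline body is not shown in this hunk, so the exact composition is assumed:

from torchvision import transforms

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.1, 1.0)),  # crop keeps >= 10% of the image
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.333, 0.333, 0.333),          # brightness, contrast, saturation
    transforms.ToTensor(),
])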
@@ -23,14 +23,20 @@ class TanhLRScheduler(Scheduler):
                  t_mul: float = 1.,
                  lr_min: float = 0.,
                  decay_rate: float = 1.,
-                 warmup_updates=0,
+                 warmup_t=0,
                  warmup_lr_init=0,
+                 warmup_prefix=False,
                  cycle_limit=0,
+                 t_in_epochs=False,
                  initialize=True) -> None:
         super().__init__(optimizer, param_group_field="lr", initialize=initialize)

         assert t_initial > 0
         assert lr_min >= 0
+        assert lb < ub
+        assert cycle_limit >= 0
+        assert warmup_t >= 0
+        assert warmup_lr_init >= 0
         self.lb = lb
         self.ub = ub
         self.t_initial = t_initial
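Note: the new `warmup_prefix` flag decides whether the warmup period is prepended to the decay timeline (the curve's clock starts after warmup) or overlaps it (warmup ramps toward whatever value the curve already has at t = warmup_t, which is why the constructor below calls `self._get_lr(self.warmup_t)` when `warmup_prefix` is False). A tiny sketch mirroring the clock logic that lands in `_get_lr` (helper name hypothetical):

def decay_clock(t, warmup_t, warmup_prefix):
    # the step index actually fed into the tanh curve
    return t - warmup_t if warmup_prefix else t

print(decay_clock(250, 100, True))   # 150: curve starts after warmup ends
print(decay_clock(250, 100, False))  # 250: warmup overlapped the curve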
@@ -38,33 +44,33 @@ class TanhLRScheduler(Scheduler):
         self.lr_min = lr_min
         self.decay_rate = decay_rate
         self.cycle_limit = cycle_limit
-        self.warmup_updates = warmup_updates
+        self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
-        if self.warmup_updates:
-            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
+        self.warmup_prefix = warmup_prefix
+        self.t_in_epochs = t_in_epochs
+        if self.warmup_t:
+            t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
+            print(t_v)
+            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
+            super().update_groups(self.warmup_lr_init)
         else:
             self.warmup_steps = [1 for _ in self.base_values]
-        if self.warmup_lr_init:
-            super().update_groups(self.warmup_lr_init)

-    def get_epoch_values(self, epoch: int):
-        # this scheduler doesn't update on epoch
-        return None
-
-    def get_update_values(self, num_updates: int):
-        if num_updates < self.warmup_updates:
-            lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
+    def _get_lr(self, t):
+        if t < self.warmup_t:
+            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
-            curr_updates = num_updates - self.warmup_updates
+            if self.warmup_prefix:
+                t = t - self.warmup_t

             if self.t_mul != 1:
-                i = math.floor(math.log(1 - curr_updates / self.t_initial * (1 - self.t_mul), self.t_mul))
+                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
                 t_i = self.t_mul ** i * self.t_initial
-                t_curr = curr_updates - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
+                t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
             else:
-                i = curr_updates // self.t_initial
+                i = t // self.t_initial
                 t_i = self.t_initial
-                t_curr = curr_updates - (self.t_initial * i)
+                t_curr = t - (self.t_initial * i)

             if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
                 gamma = self.decay_rate ** i
@@ -78,5 +84,16 @@ class TanhLRScheduler(Scheduler):
             ]
         else:
             lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]

         return lrs
+
+    def get_epoch_values(self, epoch: int):
+        if self.t_in_epochs:
+            return self._get_lr(epoch)
+        else:
+            return None
+
+    def get_update_values(self, num_updates: int):
+        if not self.t_in_epochs:
+            return self._get_lr(num_updates)
+        else:
+            return None
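Note: the hunk above elides the in-cycle LR computation, but the shape of this schedule is a shifted tanh sweep driven by the `lb`/`ub` bounds asserted in the constructor, falling from near lr_max at the start of a cycle of length t_i to near lr_min at its end. A hedged sketch of that formula (bounds lb=-6, ub=4 assumed as the likely defaults; cycle decay via gamma omitted):

import math

def tanh_lr(t_curr, t_i, lr_max, lr_min=1e-5, lb=-6.0, ub=4.0):
    tr = t_curr / t_i  # fractional progress through the current cycle
    return lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(lb + (ub - lb) * tr))

print(tanh_lr(0, 1000, 0.1))     # ~0.1: starts near lr_max (tanh(-6) ~ -1)
print(tanh_lr(1000, 1000, 0.1))  # ~1e-5: ends near lr_min (tanh(4) ~ +1)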
train.py (33 changed lines)
@@ -162,7 +162,7 @@ def main():
     if args.opt.lower() == 'sgd':
         optimizer = optim.SGD(
             model.parameters(), lr=args.lr,
-            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
+            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
     elif args.opt.lower() == 'adam':
         optimizer = optim.Adam(
             model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
@@ -183,15 +183,27 @@ def main():
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)

+    updates_per_epoch = len(loader_train)
     if args.sched == 'cosine':
         lr_scheduler = scheduler.CosineLRScheduler(
             optimizer,
-            t_initial=13 * len(loader_train),
-            t_mul=2.0,
+            t_initial=100 * updates_per_epoch,
+            t_mul=1.0,
             lr_min=0,
             decay_rate=0.5,
             warmup_lr_init=1e-4,
-            warmup_updates=len(loader_train)
+            warmup_updates=1 * updates_per_epoch
+        )
+    elif args.sched == 'tanh':
+        lr_scheduler = scheduler.TanhLRScheduler(
+            optimizer,
+            t_initial=80 * updates_per_epoch,
+            t_mul=1.0,
+            lr_min=1e-5,
+            decay_rate=0.5,
+            warmup_lr_init=.001,
+            warmup_t=5 * updates_per_epoch,
+            cycle_limit=1
         )
     else:
         lr_scheduler = scheduler.StepLRScheduler(
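Note: with `t_in_epochs=False` (the new default above), the tanh/cosine schedules count optimizer updates rather than epochs, which is why `t_initial` and `warmup_t` are expressed as multiples of `updates_per_epoch`. A self-contained sketch of the driving loop (stand-in scheduler and loader; the `step`/`step_update` hook names are assumed from the Scheduler base class wrapping `get_epoch_values`/`get_update_values`):

class ToyScheduler:
    def step(self, epoch):               # epoch-based hook
        pass
    def step_update(self, num_updates):  # update-based hook
        pass

loader_train = range(390)  # stand-in for the training DataLoader
lr_scheduler = ToyScheduler()

num_updates = 0
for epoch in range(2):
    for batch in loader_train:
        # forward / backward / optimizer.step() would go here
        num_updates += 1
        lr_scheduler.step_update(num_updates)  # live when t_in_epochs=False
    lr_scheduler.step(epoch + 1)               # live when t_in_epochs=True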
@@ -354,7 +366,7 @@ def validate(model, loader, loss_fn, args):
         losses_m.update(loss.item(), input.size(0))

         # metrics
-        prec1, prec5 = accuracy(output, target, topk=(1, 3))
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
         prec1_m.update(prec1.item(), output.size(0))
         prec5_m.update(prec5.item(), output.size(0))
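Note: validation moves from top-3 to the conventional ImageNet top-5 metric. A hedged sketch of what an `accuracy(output, target, topk=...)` helper typically computes (the repo's own helper lives elsewhere and may differ):

import torch

def topk_accuracy(output, target, topk=(1, 5)):
    _, pred = output.topk(max(topk), dim=1)  # (batch, maxk) predicted class indices
    correct = pred.eq(target.view(-1, 1))    # True where the label appears in a column
    return [correct[:, :k].any(dim=1).float().mean().item() * 100 for k in topk]

logits = torch.randn(8, 1000)
labels = torch.randint(0, 1000, (8,))
top1, top5 = topk_accuracy(logits, labels)  # percentages in [0, 100]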
@@ -375,16 +387,5 @@ def validate(model, loader, loss_fn, args):
     return metrics


-def update_summary(epoch, train_metrics, eval_metrics, output_dir, write_header=False):
-    rowd = OrderedDict(epoch=epoch)
-    rowd.update(train_metrics)
-    rowd.update(eval_metrics)
-    with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
-        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
-        if write_header:  # first iteration (epoch == 1 can't be used)
-            dw.writeheader()
-        dw.writerow(rowd)
-
-
 if __name__ == '__main__':
     main()