Tweaking tanh scheduler, senet weight init (for BN), transform defaults
parent
48360625f2
commit
b5255960d9
|
@ -104,6 +104,18 @@ pretrained_config = {
|
|||
}
|
||||
|
||||
|
||||
def _weight_init(m, n='', ll=''):
|
||||
print(m, n, ll)
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
if ll and n == ll:
|
||||
nn.init.constant_(m.weight, 0.)
|
||||
else:
|
||||
nn.init.constant_(m.weight, 1.)
|
||||
nn.init.constant_(m.bias, 0.)
|
||||
|
||||
|
||||
class SEModule(nn.Module):
|
||||
|
||||
def __init__(self, channels, reduction):
|
||||
|
@ -116,6 +128,9 @@ class SEModule(nn.Module):
|
|||
channels // reduction, channels, kernel_size=1, padding=0)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
for m in self.modules():
|
||||
_weight_init(m)
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
|
@ -176,6 +191,9 @@ class SEBottleneck(Bottleneck):
|
|||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn3')
|
||||
|
||||
|
||||
class SEResNetBottleneck(Bottleneck):
|
||||
"""
|
||||
|
@ -201,6 +219,9 @@ class SEResNetBottleneck(Bottleneck):
|
|||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn3')
|
||||
|
||||
|
||||
class SEResNeXtBottleneck(Bottleneck):
|
||||
"""
|
||||
|
@ -225,6 +246,9 @@ class SEResNeXtBottleneck(Bottleneck):
|
|||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn3')
|
||||
|
||||
|
||||
class SEResNetBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
@ -242,6 +266,9 @@ class SEResNetBlock(nn.Module):
|
|||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn2')
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
|
@ -378,6 +405,12 @@ class SENet(nn.Module):
|
|||
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
|
||||
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for n, m in self.named_children():
|
||||
if n == 'layer0':
|
||||
m.apply(_weight_init)
|
||||
else:
|
||||
_weight_init(m)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
|
||||
downsample_kernel_size=1, downsample_padding=0):
|
||||
downsample = None
|
||||
|
|
|
@ -21,7 +21,7 @@ class LeNormalize(object):
|
|||
return tensor
|
||||
|
||||
|
||||
def transforms_imagenet_train(model_name, img_size=224, scale=(0.08, 1.0), color_jitter=(0.3, 0.3, 0.3)):
|
||||
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
|
||||
if 'dpn' in model_name:
|
||||
normalize = transforms.Normalize(
|
||||
mean=IMAGENET_DPN_MEAN,
|
||||
|
|
|
@ -23,14 +23,20 @@ class TanhLRScheduler(Scheduler):
|
|||
t_mul: float = 1.,
|
||||
lr_min: float = 0.,
|
||||
decay_rate: float = 1.,
|
||||
warmup_updates=0,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
cycle_limit=0,
|
||||
t_in_epochs=False,
|
||||
initialize=True) -> None:
|
||||
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
assert lb < ub
|
||||
assert cycle_limit >= 0
|
||||
assert warmup_t >= 0
|
||||
assert warmup_lr_init >= 0
|
||||
self.lb = lb
|
||||
self.ub = ub
|
||||
self.t_initial = t_initial
|
||||
|
@ -38,33 +44,33 @@ class TanhLRScheduler(Scheduler):
|
|||
self.lr_min = lr_min
|
||||
self.decay_rate = decay_rate
|
||||
self.cycle_limit = cycle_limit
|
||||
self.warmup_updates = warmup_updates
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
if self.warmup_updates:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
|
||||
self.warmup_prefix = warmup_prefix
|
||||
self.t_in_epochs = t_in_epochs
|
||||
if self.warmup_t:
|
||||
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
|
||||
print(t_v)
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
else:
|
||||
self.warmup_steps = [1 for _ in self.base_values]
|
||||
if self.warmup_lr_init:
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
# this scheduler doesn't update on epoch
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if num_updates < self.warmup_updates:
|
||||
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
|
||||
def _get_lr(self, t):
|
||||
if t < self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||
else:
|
||||
curr_updates = num_updates - self.warmup_updates
|
||||
if self.warmup_prefix:
|
||||
t = t - self.warmup_t
|
||||
|
||||
if self.t_mul != 1:
|
||||
i = math.floor(math.log(1 - curr_updates / self.t_initial * (1 - self.t_mul), self.t_mul))
|
||||
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
|
||||
t_i = self.t_mul ** i * self.t_initial
|
||||
t_curr = curr_updates - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
|
||||
t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
|
||||
else:
|
||||
i = curr_updates // self.t_initial
|
||||
i = t // self.t_initial
|
||||
t_i = self.t_initial
|
||||
t_curr = curr_updates - (self.t_initial * i)
|
||||
t_curr = t - (self.t_initial * i)
|
||||
|
||||
if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
|
||||
gamma = self.decay_rate ** i
|
||||
|
@ -78,5 +84,16 @@ class TanhLRScheduler(Scheduler):
|
|||
]
|
||||
else:
|
||||
lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]
|
||||
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
|
33
train.py
33
train.py
|
@ -162,7 +162,7 @@ def main():
|
|||
if args.opt.lower() == 'sgd':
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(), lr=args.lr,
|
||||
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
|
||||
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
|
||||
elif args.opt.lower() == 'adam':
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
|
@ -183,15 +183,27 @@ def main():
|
|||
if optimizer_state is not None:
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
updates_per_epoch = len(loader_train)
|
||||
if args.sched == 'cosine':
|
||||
lr_scheduler = scheduler.CosineLRScheduler(
|
||||
optimizer,
|
||||
t_initial=13 * len(loader_train),
|
||||
t_mul=2.0,
|
||||
t_initial=100 * updates_per_epoch,
|
||||
t_mul=1.0,
|
||||
lr_min=0,
|
||||
decay_rate=0.5,
|
||||
warmup_lr_init=1e-4,
|
||||
warmup_updates=len(loader_train)
|
||||
warmup_updates=1 * updates_per_epoch
|
||||
)
|
||||
elif args.sched == 'tanh':
|
||||
lr_scheduler = scheduler.TanhLRScheduler(
|
||||
optimizer,
|
||||
t_initial=80 * updates_per_epoch,
|
||||
t_mul=1.0,
|
||||
lr_min=1e-5,
|
||||
decay_rate=0.5,
|
||||
warmup_lr_init=.001,
|
||||
warmup_t=5 * updates_per_epoch,
|
||||
cycle_limit=1
|
||||
)
|
||||
else:
|
||||
lr_scheduler = scheduler.StepLRScheduler(
|
||||
|
@ -354,7 +366,7 @@ def validate(model, loader, loss_fn, args):
|
|||
losses_m.update(loss.item(), input.size(0))
|
||||
|
||||
# metrics
|
||||
prec1, prec5 = accuracy(output, target, topk=(1, 3))
|
||||
prec1, prec5 = accuracy(output, target, topk=(1, 5))
|
||||
prec1_m.update(prec1.item(), output.size(0))
|
||||
prec5_m.update(prec5.item(), output.size(0))
|
||||
|
||||
|
@ -375,16 +387,5 @@ def validate(model, loader, loss_fn, args):
|
|||
return metrics
|
||||
|
||||
|
||||
def update_summary(epoch, train_metrics, eval_metrics, output_dir, write_header=False):
|
||||
rowd = OrderedDict(epoch=epoch)
|
||||
rowd.update(train_metrics)
|
||||
rowd.update(eval_metrics)
|
||||
with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
|
||||
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
||||
if write_header: # first iteration (epoch == 1 can't be used)
|
||||
dw.writeheader()
|
||||
dw.writerow(rowd)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue