deep-person-reid/torchreid/utils/torchtools.py

from __future__ import absolute_import
from __future__ import division
import torch
import torch.nn as nn


def adjust_learning_rate(optimizer, base_lr, epoch, stepsize, gamma=0.1):
    """Decays the learning rate by 'gamma' every 'stepsize' epochs."""
    lr = base_lr * (gamma ** (epoch // stepsize))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
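
# Illustrative usage sketch (an assumption, not part of the original file):
# call once at the start of every epoch; `optimizer` can be any torch.optim
# optimizer, and `max_epoch` is a hypothetical variable here.
#
#     for epoch in range(max_epoch):
#         adjust_learning_rate(optimizer, base_lr=0.1, epoch=epoch, stepsize=20)
#         ...  # run one training epoch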


def set_bn_to_eval(m):
    """Freezes BatchNorm statistics by switching the layer to eval mode.

    1. running mean and var are no longer updated
    2. the scale and shift parameters remain trainable
    """
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
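
# Illustrative usage sketch (an assumption, not part of the original file):
# nn.Module.apply() calls the given function on every submodule recursively,
# so this freezes the statistics of all BatchNorm layers in `model`:
#
#     model.apply(set_bn_to_eval)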


def count_num_param(model):
    """Counts the number of model parameters, in millions."""
    num_param = sum(p.numel() for p in model.parameters()) / 1e+06
    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
        # the classifier is excluded because it is unused at test time
        num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06
    return num_param
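
# Illustrative usage sketch (an assumption, not part of the original file):
# works with any nn.Module; a model exposing its head as `model.classifier`
# has that head excluded from the count.
#
#     model = build_model()  # hypothetical constructor
#     print('Model size: {:.3f} M'.format(count_num_param(model)))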