yolov5/utils/torch_utils.py

287 lines
11 KiB
Python
Raw Normal View History

# PyTorch utils
2020-08-21 12:17:40 +08:00
import logging
import math
2020-05-30 08:04:54 +08:00
import os
import time
from contextlib import contextmanager
2020-05-30 08:04:54 +08:00
from copy import deepcopy
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
2020-10-06 21:09:24 +08:00
import torchvision
2020-05-30 08:04:54 +08:00
# Optional dependency: thop is only needed for FLOPS computation; stays None when not installed.
thop = None
try:
    import thop
except ImportError:
    pass

logger = logging.getLogger(__name__)  # module-level logger used by the helpers in this file
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets the local master (rank 0) run the wrapped code before other ranks.

    Ranks other than -1 and 0 block on a barrier before entering the body, and are released once
    rank 0 reaches its own barrier after finishing the body. local_rank == -1 (non-distributed)
    never touches torch.distributed.

    Args:
        local_rank: this process's local rank, or -1 when not running distributed.
    """
    if local_rank not in (-1, 0):
        torch.distributed.barrier()  # wait for the local master to finish the body first
    try:
        yield
    finally:
        if local_rank == 0:
            # Release the waiting ranks even if the body raised; otherwise they would block forever.
            torch.distributed.barrier()
def init_torch_seeds(seed=0):
    """Seed torch's RNG and choose the cuDNN speed/reproducibility trade-off.

    seed == 0 selects deterministic, non-benchmarking cuDNN (slower, more reproducible); any other
    seed enables cuDNN benchmark mode (faster, less reproducible).
    See https://pytorch.org/docs/stable/notes/randomness.html
    """
    torch.manual_seed(seed)
    reproducible = bool(seed == 0)
    cudnn.deterministic = reproducible
    cudnn.benchmark = not reproducible
def select_device(device='', batch_size=None):
    """Pick the torch device to run on and log what was selected.

    Args:
        device: 'cpu', a CUDA index string such as '0', or a comma-separated list such as
            '0,1,2,3'. '' or None means "CUDA if available, otherwise CPU". Surrounding
            whitespace is ignored and the 'cpu' check is case-insensitive.
        batch_size: if given and more than one GPU is visible, it must be a multiple of the GPU count.

    Returns:
        torch.device('cuda:0') when CUDA is used, otherwise torch.device('cpu').

    Raises:
        AssertionError: specific GPU(s) were requested but CUDA is unavailable, or batch_size is
            not a multiple of the visible GPU count.
    """
    device = '' if device is None else str(device).strip()
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # specific GPU(s) requested
        # NOTE(review): CUDA generally reads this only at initialisation; setting it later may have no effect.
        os.environ['CUDA_VISIBLE_DEVICES'] = device
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu_request and torch.cuda.is_available()
    if cuda:
        c = 1024 ** 2  # bytes to MB
        ng = torch.cuda.device_count()
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, f'batch-size {batch_size} not multiple of GPU count {ng}'
        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
        # With no explicit device every visible GPU is in use, so report all of them.
        devices = device.split(',') if device else [str(i) for i in range(ng)]
        s = f'Using torch {torch.__version__} '
        for i, d in enumerate(devices):
            if i == 1:
                s = ' ' * len(s)  # align continuation lines under the first one
            logger.info(f"{s}CUDA:{d} ({x[i].name}, {x[i].total_memory / c}MB)")
    else:
        logger.info(f'Using torch {torch.__version__} CPU')

    logger.info('')  # skip a line
    return torch.device('cuda:0' if cuda else 'cpu')
def time_synchronized():
    """Return time.time() after waiting for any queued CUDA work to finish.

    CUDA kernel launches are asynchronous. Synchronising first makes wall-clock differences
    around GPU ops meaningful. On machines without CUDA this is just time.time().
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def profile(x, ops, n=100, device=None):
    """Print parameter count, GFLOPS and mean forward/backward time (ms) for one or more ops.

    Example:
        x = torch.randn(16, 3, 640, 640)  # input
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(x, [m1, m2], n=100)  # profile speed over 100 iterations

    Args:
        x: input tensor. It is moved to `device` and marked requires_grad so backward can be timed.
        ops: a callable / nn.Module, or a list or tuple of them.
        n: number of timed iterations averaged per op.
        device: torch.device or device string; defaults to cuda:0 when available, else cpu.

    GFLOPS is reported as 0 when thop is unavailable or cannot profile the op. Backward time is
    nan when calling backward on the output fails.
    """
    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if isinstance(device, str):  # accept 'cpu' / 'cuda:0' as well as torch.device
        device = torch.device(device)
    x = x.to(device)
    x.requires_grad = True
    print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
    print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
    for m in ops if isinstance(ops, (list, tuple)) else [ops]:
        m = m.to(device) if hasattr(m, 'to') else m  # device
        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
        try:
            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
        except Exception:  # thop not installed (None) or op not profilable by thop
            flops = 0
        for _ in range(n):
            t[0] = time_synchronized()
            y = m(x)
            t[1] = time_synchronized()
            try:
                _ = y.sum().backward()
                t[2] = time_synchronized()
            except Exception:  # no backward method
                t[2] = float('nan')
            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
        p = sum(param.numel() for param in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
        print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
def is_parallel(model):
    """Return True if `model` is wrapped in DataParallel or DistributedDataParallel (or a subclass).

    Such wrappers expose the underlying network as `model.module`. isinstance is used so that
    subclasses of these wrappers, which also carry `.module`, are recognised.
    """
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
def intersect_dicts(da, db, exclude=()):
    """Return the entries of `da` whose key also appears in `db` with an equal `.shape`.

    Keys containing any substring listed in `exclude` are dropped. Values come from `da` and
    insertion order follows `da`.
    """
    matched = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(token in key for token in exclude):
            continue
        if value.shape != db[key].shape:
            continue
        matched[key] = value
    return matched
def initialize_weights(model):
    """Adjust BatchNorm2d hyper-parameters and make common activations in-place, for every submodule.

    Matching is on the exact type (subclasses are not affected). nn.Conv2d layers keep PyTorch's
    default initialisation; a kaiming_normal_ init is left disabled.
    """
    inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)
    for module in model.modules():
        kind = type(module)
        if kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif kind in inplace_activations:
            module.inplace = True
def find_modules(model, mclass=nn.Conv2d):
    """Return the indices in `model.module_list` of layers that are instances of `mclass`."""
    matches = []
    for index, layer in enumerate(model.module_list):
        if isinstance(layer, mclass):
            matches.append(index)
    return matches
def sparsity(model):
    """Return the global fraction of parameter elements in `model` that are exactly zero.

    Counts across every tensor yielded by model.parameters(). The result is a 0-d tensor when the
    model has parameters. A parameterless model returns 0.0 instead of dividing by zero.
    """
    total, zeros = 0., 0.
    for p in model.parameters():
        total += p.numel()
        zeros += (p == 0).sum()
    if total == 0:  # no parameters at all: nothing is sparse, avoid ZeroDivisionError
        return 0.0
    return zeros / total
def prune(model, amount=0.3):
    """Apply L1 unstructured pruning to every Conv2d weight in `model`, then print the global sparsity.

    Each Conv2d weight is pruned independently by `amount` (a fraction if float, a count if int,
    per torch.nn.utils.prune), and the pruning is made permanent. The printed figure is the
    resulting sparsity over all model parameters.
    """
    import torch.nn.utils.prune as torch_prune  # aliased so it does not shadow this function's name
    print('Pruning model... ', end='')
    conv_layers = (module for _, module in model.named_modules() if isinstance(module, nn.Conv2d))
    for layer in conv_layers:
        torch_prune.l1_unstructured(layer, name='weight', amount=amount)  # prune
        torch_prune.remove(layer, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d into the Conv2d that precedes it and return a new fused Conv2d.

    Uses bn's running statistics, so the fused layer reproduces bn(conv(x)) with bn in eval mode.
    All geometry of `conv` (kernel, stride, padding, dilation, groups, padding_mode) is carried
    over. The fused layer is placed on conv's device and has requires_grad disabled.
    Reference: https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          padding_mode=conv.padding_mode,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    with torch.no_grad():  # fused params are constants; keep them out of the autograd graph
        # prepare filters: scale each output filter by gamma / sqrt(running_var + eps)
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias: scaled conv bias plus beta - gamma * mean / sqrt(var + eps)
        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
def model_info(model, verbose=False, img_size=640):
    """Log a one-line model summary: module count, parameters, trainable parameters and GFLOPS.

    Args:
        model: an nn.Module. Optional `stride` and `yaml` (dict with 'ch') attributes size the FLOPS
            probe input; defaults are stride 32 and 3 input channels.
        verbose: also print a per-parameter table (name, requires_grad, numel, shape, mean, std).
        img_size: int, or [h, w] / (h, w), that the GFLOPS figure is scaled to.

    GFLOPS is measured with thop on a 1 x ch x stride x stride zero input and scaled to img_size.
    It is omitted when thop is not installed or profiling fails.
    """
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPS
        from thop import profile as thop_profile  # local name avoids shadowing this module's profile()
        stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
        ch = getattr(model, 'yaml', {}).get('ch', 3)
        img = torch.zeros((1, ch, stride, stride), device=next(model.parameters()).device)  # input
        flops = thop_profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
        img_size = img_size if isinstance(img_size, (list, tuple)) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # img_size GFLOPS
    except Exception:  # thop missing (ImportError) or model could not be profiled: skip FLOPS
        fs = ''

    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
def load_classifier(name='resnet101', n=2):
    """Build a pretrained torchvision classifier `name` and replace its `fc` head with an n-class one.

    The new head's weight and bias are zero-initialised trainable Parameters.
    Expected input for the ResNet family: 3x224x224 RGB in [0, 1], normalised with
    mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225].
    """
    constructor = torchvision.models.__dict__[name]
    model = constructor(pretrained=True)

    head = model.fc
    in_features = head.weight.shape[1]
    head.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    head.weight = nn.Parameter(torch.zeros(n, in_features), requires_grad=True)
    head.out_features = n
    return model
def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
2020-05-30 08:04:54 +08:00
# scales img(bs,3,y,x) by ratio
if ratio == 1.0:
return img
else:
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
if not same_shape: # pad/crop img
gs = 32 # (pixels) grid size
h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
2020-05-30 08:04:54 +08:00
2020-07-12 03:35:21 +08:00
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from `b` onto `a`.

    Names starting with '_' are never copied. If `include` is non-empty, only the names it lists
    are copied. Names listed in `exclude` are always skipped.
    """
    for name, value in b.__dict__.items():
        allowed = len(include) == 0 or name in include
        if allowed and not name.startswith('_') and name not in exclude:
            setattr(a, name, value)
class ModelEMA:
    """Exponential moving average (EMA) of a model's state_dict, covering parameters and buffers.

    Adapted from https://github.com/rwightman/pytorch-image-models; the idea mirrors
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage. Some training
    schemes need these smoothed weights to perform well.

    The EMA is a deepcopy of the model taken at construction time. Where this object is created
    relative to model init, GPU placement and DP/DDP wrapping therefore matters.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Snapshot the unwrapped model in eval mode. NOTE: converting the EMA to half precision on
        # non-CPU devices was considered upstream but is left disabled.
        bare = model.module if is_parallel(model) else model
        self.ema = deepcopy(bare).eval()
        self.updates = updates  # number of EMA updates applied so far

        def ramped_decay(step):
            # Starts near 0 and approaches `decay` as step grows, so early updates track the model closely.
            return decay * (1 - math.exp(-step / 2000))

        self.decay = ramped_decay
        for param in self.ema.parameters():
            param.requires_grad_(False)  # EMA weights are never trained directly

    def update(self, model):
        """Blend `model`'s current state into the EMA: ema = d * ema + (1 - d) * model, floats only."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            bare = model.module if is_parallel(model) else model
            model_state = bare.state_dict()
            for key, ema_value in self.ema.state_dict().items():
                if not ema_value.dtype.is_floating_point:
                    continue  # non-floating entries are left untouched
                ema_value *= d
                ema_value += (1. - d) * model_state[key].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public attributes of `model` onto the EMA model, filtered by `include` / `exclude`."""
        copy_attr(self.ema, model, include, exclude)