mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Profile() feature addition (#1673)
* Profile() feature addition * cleanup
This commit is contained in:
parent
94a7f55c4e
commit
ada90e3901
@ -1,18 +1,22 @@
|
|||||||
# PyTorch utils
|
# PyTorch utils
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import math
|
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
|
try:
|
||||||
|
import thop # for FLOPS computation
|
||||||
|
except ImportError:
|
||||||
|
thop = None
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -66,10 +70,45 @@ def select_device(device='', batch_size=None):
|
|||||||
|
|
||||||
|
|
||||||
def time_synchronized():
|
def time_synchronized():
|
||||||
|
# pytorch-accurate time
|
||||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def profile(x, ops, n=100, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
|
||||||
|
# profile a pytorch module or list of modules. Example usage:
|
||||||
|
# x = torch.randn(16, 3, 640, 640) # input
|
||||||
|
# m1 = lambda x: x * torch.sigmoid(x)
|
||||||
|
# m2 = nn.SiLU()
|
||||||
|
# profile(x, [m1, m2], n=100) # profile speed over 100 iterations
|
||||||
|
|
||||||
|
x = x.to(device)
|
||||||
|
x.requires_grad = True
|
||||||
|
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
|
||||||
|
print(f"\n{'Params':>12s}{'FLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
|
||||||
|
for m in ops if isinstance(ops, list) else [ops]:
|
||||||
|
m = m.to(device) if hasattr(m, 'to') else m
|
||||||
|
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
|
||||||
|
try:
|
||||||
|
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
|
||||||
|
except:
|
||||||
|
flops = 0
|
||||||
|
|
||||||
|
for _ in range(n):
|
||||||
|
t[0] = time_synchronized()
|
||||||
|
y = m(x)
|
||||||
|
t[1] = time_synchronized()
|
||||||
|
_ = y.sum().backward()
|
||||||
|
t[2] = time_synchronized()
|
||||||
|
dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||||
|
dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||||
|
|
||||||
|
s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
|
||||||
|
s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
|
||||||
|
p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
|
||||||
|
print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||||
|
|
||||||
|
|
||||||
def is_parallel(model):
|
def is_parallel(model):
|
||||||
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user