GFLOPs computation fix for classification models (#8954)
* GFLOPs computation fix for classification models: improved robustness in reading the input channel count
* Update torch_utils.py
* Update torch_utils.py

pull/8957/head
parent f1214f237d
commit 6aed0a7c00
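Background for the fix: classification models built directly from a backbone typically have no YOLO-style `model.yaml` attribute, so the old `model.yaml.get('ch', 3)` lookup raised inside the FLOPs try-block and the GFLOPs figure was silently dropped (the except branch sets `fs = ''`). The input channel count can instead be read from the first parameter's weight shape. A minimal illustrative sketch, not part of this commit (assumes torchvision is installed; resnet18 is an arbitrary stand-in model):

# Illustration only: read the input channel count from the weights, no model.yaml needed.
import torchvision

model = torchvision.models.resnet18()
p = next(model.parameters())   # first conv weight, shape (out_channels, in_channels, k, k)
print(p.shape[1])              # -> 3, the input channel count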
@@ -199,12 +199,11 @@ def sparsity(model):
 def prune(model, amount=0.3):
     # Prune model to requested global sparsity
     import torch.nn.utils.prune as prune
-    print('Pruning model... ', end='')
     for name, m in model.named_modules():
         if isinstance(m, nn.Conv2d):
             prune.l1_unstructured(m, name='weight', amount=amount)  # prune
             prune.remove(m, 'weight')  # make permanent
-    print(' %.3g global sparsity' % sparsity(model))
+    LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')


 def fuse_conv_and_bn(conv, bn):
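As an aside on the prune() helper kept as context in the hunk above, here is a self-contained usage sketch of the same pattern (the toy Sequential model and the sparsity arithmetic are illustrative, not repo code):

import torch.nn as nn
import torch.nn.utils.prune as prune

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        prune.l1_unstructured(m, name='weight', amount=0.3)  # zero the 30% smallest-magnitude weights
        prune.remove(m, 'weight')                            # fold the pruning mask into the tensor
zeros = sum((p == 0).sum().item() for p in net.parameters())
total = sum(p.numel() for p in net.parameters())
print(f'global sparsity: {zeros / total:.3g}')               # roughly 0.3 for this conv-heavy toy model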
@@ -230,7 +229,7 @@ def fuse_conv_and_bn(conv, bn):
     return fusedconv


-def model_info(model, verbose=False, img_size=640):
+def model_info(model, verbose=False, imgsz=640):
     # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
     n_p = sum(x.numel() for x in model.parameters())  # number parameters
     n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
@@ -242,12 +241,12 @@ def model_info(model, verbose=False, img_size=640):
                   (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

     try:  # FLOPs
-        from thop import profile
-        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
-        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
-        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
-        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
-        fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs
+        p = next(model.parameters())
+        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
+        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
+        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
+        fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs'  # 640x640 GFLOPs
     except Exception:
         fs = ''

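For reference, the estimate above probes the model once at a small stride-sized input and rescales quadratically to the target image size. A hedged standalone sketch of that arithmetic, assuming `thop` is installed and again using a torchvision resnet18 as a stand-in for a real classification checkpoint:

from copy import deepcopy

import thop
import torch
import torchvision

model = torchvision.models.resnet18()
p = next(model.parameters())
stride = 32                                           # probe size; classifiers have no model.stride
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # GFLOPs at 32x32
imgsz = 640
print(f'{flops * (imgsz / stride) ** 2:.1f} GFLOPs')  # quadratic rescale to 640x640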