Improved FLOPS computation (#1398)

* Improved FLOPS computation

* update comment
pull/1399/head
Glenn Jocher 2020-11-14 14:39:46 +01:00 committed by GitHub
parent 0c26c4e831
commit d8f5fcfe87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 6 deletions

View File

@ -192,8 +192,8 @@ class Model(nn.Module):
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
return m
def info(self, verbose=False): # print model information
model_info(self, verbose)
def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose, img_size)
def parse_model(d, ch): # model_dict, input_channels(3)

View File

@ -139,8 +139,8 @@ def fuse_conv_and_bn(conv, bn):
return fusedconv
def model_info(model, verbose=False):
# Plots a line-by-line description of a PyTorch model
def model_info(model, verbose=False, img_size=640):
# Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
n_p = sum(x.numel() for x in model.parameters()) # number parameters
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
if verbose:
@ -152,8 +152,10 @@ def model_info(model, verbose=False):
try: # FLOPS
from thop import profile
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
stride = int(model.stride.max())
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
except ImportError:
fs = ''