From d8f5fcfe87e781714375329d565d9e4923fdd017 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 14 Nov 2020 14:39:46 +0100 Subject: [PATCH] Improved FLOPS computation (#1398) * Improved FLOPS computation * update comment --- models/yolo.py | 4 ++-- utils/torch_utils.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/models/yolo.py b/models/yolo.py index 0080056a4..2ef1574a8 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -192,8 +192,8 @@ class Model(nn.Module): copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes return m - def info(self, verbose=False): # print model information - model_info(self, verbose) + def info(self, verbose=False, img_size=640): # print model information + model_info(self, verbose, img_size) def parse_model(d, ch): # model_dict, input_channels(3) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index cdd21b519..e5ef2607a 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -139,8 +139,8 @@ def fuse_conv_and_bn(conv, bn): return fusedconv -def model_info(model, verbose=False): - # Plots a line-by-line description of a PyTorch model +def model_info(model, verbose=False, img_size=640): + # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320] n_p = sum(x.numel() for x in model.parameters()) # number parameters n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients if verbose: @@ -152,8 +152,10 @@ def model_info(model, verbose=False): try: # FLOPS from thop import profile - flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2 - fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS + stride = int(model.stride.max()) + flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2 + img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float + fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS except ImportError: fs = ''