parent 0c26c4e831
commit d8f5fcfe87
@@ -192,8 +192,8 @@ class Model(nn.Module):
         copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
         return m
 
-    def info(self, verbose=False):  # print model information
-        model_info(self, verbose)
+    def info(self, verbose=False, img_size=640):  # print model information
+        model_info(self, verbose, img_size)
 
 
 def parse_model(d, ch):  # model_dict, input_channels(3)
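The first hunk simply forwards the new img_size argument from Model.info to model_info. A hedged usage sketch (model is assumed to be an already-constructed YOLOv5 Model instance; these calls are illustrative, not part of the commit):

    model.info()                       # FLOPS reported for the default 640x640 input
    model.info(verbose=True)           # also prints the per-parameter table
    model.info(img_size=[640, 320])    # img_size may be a two-element list for rectangular inputs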
@@ -139,8 +139,8 @@ def fuse_conv_and_bn(conv, bn):
     return fusedconv
 
 
-def model_info(model, verbose=False):
-    # Plots a line-by-line description of a PyTorch model
+def model_info(model, verbose=False, img_size=640):
+    # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
     n_p = sum(x.numel() for x in model.parameters())  # number parameters
     n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
     if verbose:
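The second hunk only changes the signature and docstring; the parameter and gradient counts in the surrounding context lines are plain PyTorch and can be reproduced in isolation. A minimal, self-contained sketch (the toy model is illustrative only):

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())  # toy model for illustration
    n_p = sum(x.numel() for x in model.parameters())                     # number of parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number of gradient-carrying parameters
    print(f'{n_p} parameters, {n_g} gradients')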
@@ -152,8 +152,10 @@ def model_info(model, verbose=False):
 
     try:  # FLOPS
         from thop import profile
-        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
-        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 FLOPS
+        stride = int(model.stride.max())
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
+        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
+        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 FLOPS
     except ImportError:
         fs = ''
 
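The third hunk is the substantive fix: FLOPS are now profiled on a stride x stride input and rescaled to the requested image size, rather than profiled at 64x64 and multiplied by a fixed 100 (which was only correct for 640x640 inputs). A hedged sketch of that scaling, assuming thop is available and the model exposes a stride tensor as YOLOv5 models do (the helper name gflops_at is hypothetical):

    from copy import deepcopy
    import torch

    def gflops_at(model, img_size=640):
        # Profile once at the smallest meaningful input (stride x stride), then scale by input area.
        try:
            from thop import profile
            stride = int(model.stride.max())  # assumes the model exposes a stride tensor, as YOLOv5 models do
            flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
            img_size = img_size if isinstance(img_size, list) else [img_size, img_size]
            return flops * img_size[0] / stride * img_size[1] / stride  # FLOPS grow linearly with input area
        except ImportError:
            return None  # thop not installed; mirrors the diff's silent fallback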